
Mojo struct

PagedRowIndices

struct PagedRowIndices[BN: Int, page_size: Int, pair_cta: Bool = False, is_leader: Bool = True]

Pre-computed physical row indices for a BN-row range of paged KV cache.

BN is V's tile row count. MHAOperand.populate (or its PagedKVCache override) fills indices for the full BN range (so V can reuse them); K's TMA (tma_copy_k) covers only a subset when pair_cta=True (the BN/2 rows owned by this CTA). The K half is selected at comptime from Self.is_leader: when num_pages >= 2 the peer shifts its index into rows[] by num_pages/2; when num_pages == 1 (e.g. page_size >= BN) the peer reuses rows[0] but adds BN/2 to the issued row.

When page_size >= BN (or page_size == 0 for non-paged), the struct stores a single entry, so it adds no overhead compared to a single row_idx call.

Under pair_cta=True, K's TMA covers num_pages // 2 entries (the CTA-rank-specific half) when num_pages >= 2, or the full single entry when num_pages == 1; V's TMA covers all num_pages. Storage is sized to V (num_pages = BN / eff_page) regardless of pair_cta: K populates the full range so V can reuse the rows without any lazy LUT lookup.
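The sizing rules above can be checked with a small host-side model. The sketch below is plain Python, not Mojo, and `kv_sub_tile_rows_model` is a hypothetical stand-in for `kv_sub_tile_rows`: it assumes the helper returns BN when page_size == 0 or page_size >= BN, and page_size otherwise. That matches the description here, but the real helper may apply further rules.

```python
def kv_sub_tile_rows_model(bn: int, page_size: int) -> int:
    """Hypothetical stand-in for kv_sub_tile_rows (assumed behavior)."""
    # Assumption: a non-paged cache (page_size == 0) or a page that covers
    # the whole tile yields a single BN-row sub-tile.
    if page_size == 0 or page_size >= bn:
        return bn
    return page_size


def entry_counts(bn: int, page_size: int, pair_cta: bool) -> dict:
    """Entries stored in rows[] and entries covered by the K and V copies."""
    eff_page = kv_sub_tile_rows_model(bn, page_size)
    num_pages = bn // eff_page  # storage is sized to V, regardless of pair_cta
    # K covers the CTA-specific half only when there are at least two entries.
    k_entries = num_pages // 2 if (pair_cta and num_pages >= 2) else num_pages
    return {
        "eff_page": eff_page,
        "num_pages": num_pages,
        "k_entries": k_entries,
        "v_entries": num_pages,
    }


if __name__ == "__main__":
    print(entry_counts(128, 16, pair_cta=True))   # many small pages
    print(entry_counts(128, 128, pair_cta=True))  # one page covers the tile
    print(entry_counts(128, 0, pair_cta=False))   # non-paged
```

The model keeps `v_entries` equal to `num_pages` in every configuration, which is the point of sizing storage to V.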

Fields

  • rows (InlineArray[UInt32, (BN // kv_sub_tile_rows(BN, page_size))]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable

comptime members

cta_group

comptime cta_group = Int(2) if pair_cta else Int(1)

eff_page

comptime eff_page = kv_sub_tile_rows(BN, page_size)

num_pages

comptime num_pages = (BN // kv_sub_tile_rows(BN, page_size))

Methods

__init__

def __init__(out self)

get_row

def get_row(self, offset: UInt32) -> UInt32

Physical row for an arbitrary offset within the BN range.

For sub-tile loads: get_row(sub_tile_idx * eff_page). For depth-512 V: get_row(pv_stage * BK1) avoids re-reading the LUT. Requires the base kv_row that was passed to populate to be page-aligned (guaranteed by mask alignment).

Returns:

UInt32
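As a rough illustration of the lookup, the following Python sketch models get_row as "entry for the page that contains the offset, plus the position inside that page". This decomposition is an assumption inferred from the description above, not the confirmed implementation; `get_row_model` and the sample row values are hypothetical.

```python
def get_row_model(rows: list, eff_page: int, offset: int) -> int:
    """Assumed model: rows[i] is the physical row where logical page i starts."""
    # Relies on the base kv_row passed to populate being page-aligned, so the
    # offset splits cleanly into (page index, row inside that page).
    page_idx, intra_page = divmod(offset, eff_page)
    return rows[page_idx] + intra_page


if __name__ == "__main__":
    # BN = 128 with 32-row pages: four entries, pages scattered in memory.
    rows = [640, 32, 896, 0]
    print(get_row_model(rows, eff_page=32, offset=2 * 32))  # sub_tile_idx = 2

    # Single-entry case (page covers BN = 128), depth-512 style: BK1 = 64.
    print(get_row_model([4096], eff_page=128, offset=1 * 64))  # pv_stage = 1
```

In the single-entry case the lookup never touches more than rows[0], which is consistent with the claim that get_row(pv_stage * BK1) avoids re-reading the LUT.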

tma_copy_v

def tma_copy_v[dtype: DType, tile_shape: IndexList[Int(3)], desc_shape: IndexList[Int(3)], //, *, needs_partial: Bool, num_v_sub_tiles: Int = Int(1), v_sub_tile_idx: Int = Int(0), eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = Int(-1), oob_fill_pages: Bool = False](self, tma_op: TMATensorTile[dtype, Int(3), tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, num_valid_pages: UInt32 = SIMD(((BN // kv_sub_tile_rows(BN, page_size)) // num_v_sub_tiles)), depth_offset: UInt32 = UInt32(0))

TMA-copy a V sub-tile, with comptime partial switch.

Consumes pre-populated rows from an earlier MHAOperand.populate call. In pair_cta mode, that call populates the full num_pages range (both CTAs' halves) so V can reuse them directly without any lazy LUT lookup.

num_v_sub_tiles / v_sub_tile_idx select a row sub-range of the BN tile when V is split across multiple SMEM slots (e.g. depth512's num_pv_stages=2 split: BK1 = BN/2 rows per slot). Default (1, 0) loads the full Self.BN rows into a single SMEM slot of row stride Self.BN, which is byte-identical to fa4's previous behavior.

With num_v_sub_tiles > 1:

  • v_rows_per_sub_tile = Self.BN // num_v_sub_tiles is the SMEM depth-chunk stride (rows per slot).
  • v_tma_tile_rows = kv_sub_tile_rows(v_rows_per_sub_tile, Self.page_size) is the TMA's tile-row count per issue.
  • When Self.num_pages >= num_v_sub_tiles: sub-tile s loads rows[s * v_pages_per_sub_tile .. ).
  • When Self.num_pages == 1 < num_v_sub_tiles (page covers the full BN): all sub-tiles share rows[0] and add v_sub_tile_idx * v_rows_per_sub_tile as intra-page row offset.
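The sub-tile rules in this list can be sketched as a host-side calculation. The Python below is a model only: `kv_sub_tile_rows_model` is a hypothetical stand-in for `kv_sub_tile_rows` (assumed to return the tile row count when page_size is 0 or at least that large, else page_size), and treating v_pages_per_sub_tile as 1 in the single-page case is likewise an assumption.

```python
def kv_sub_tile_rows_model(rows: int, page_size: int) -> int:
    """Hypothetical stand-in for kv_sub_tile_rows (assumed behavior)."""
    if page_size == 0 or page_size >= rows:
        return rows
    return page_size


def v_sub_tile_plan(bn: int, page_size: int,
                    num_v_sub_tiles: int, v_sub_tile_idx: int) -> dict:
    """Which rows[] entries a V sub-tile reads, and its intra-page offset."""
    num_pages = bn // kv_sub_tile_rows_model(bn, page_size)
    v_rows_per_sub_tile = bn // num_v_sub_tiles          # rows per SMEM slot
    v_tma_tile_rows = kv_sub_tile_rows_model(v_rows_per_sub_tile, page_size)

    if num_pages >= num_v_sub_tiles:
        # Each sub-tile owns a contiguous run of entries, no intra-page offset.
        v_pages_per_sub_tile = num_pages // num_v_sub_tiles
        first_entry = v_sub_tile_idx * v_pages_per_sub_tile
        intra_page_offset = 0
    else:
        # Only the num_pages == 1 case is described: every sub-tile shares
        # rows[0] and steps through the page by whole sub-tiles.
        assert num_pages == 1
        v_pages_per_sub_tile = 1  # assumption: one issue per sub-tile
        first_entry = 0
        intra_page_offset = v_sub_tile_idx * v_rows_per_sub_tile

    return {
        "v_rows_per_sub_tile": v_rows_per_sub_tile,
        "v_tma_tile_rows": v_tma_tile_rows,
        "v_pages_per_sub_tile": v_pages_per_sub_tile,
        "first_entry": first_entry,
        "intra_page_offset": intra_page_offset,
    }


if __name__ == "__main__":
    print(v_sub_tile_plan(128, 16, num_v_sub_tiles=2, v_sub_tile_idx=1))
    print(v_sub_tile_plan(128, 128, num_v_sub_tiles=2, v_sub_tile_idx=1))
```

With BN = 128 and 16-row pages, the second sub-tile starts at entry 4; with a 128-row page it stays on entry 0 and starts 64 rows into the page.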

needs_partial=False: comptime-unrolled over num_iters sub-tile entries (default v_pages_per_sub_tile).

needs_partial=True: comptime-unrolls a runtime dispatch that tests num_valid_pages against each _p in [1, v_pages_per_sub_tile) and tail-calls the needs_partial=False form with num_iters=_p, so the actual TMA issues always emit as a straight-line, fully static unroll of exactly num_valid_pages issues. Callers must guarantee 1 <= num_valid_pages <= v_pages_per_sub_tile.
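The dispatch pattern can be modeled on the host as a chain of equality tests that each hand off to a fixed-length loop. The Python below is a sketch under assumptions: `issue_unrolled` and `issue_partial` are hypothetical names, a list of page slots stands in for the TMA issues, and the fall-through to the full count when no test matches is inferred from the half-open range [1, v_pages_per_sub_tile).

```python
def issue_unrolled(num_iters: int) -> list:
    """Model of the needs_partial=False form: a fixed run of num_iters issues."""
    return list(range(num_iters))  # page slots that would receive a TMA


def issue_partial(num_valid_pages: int, pages_per_sub_tile: int) -> list:
    """Model of the needs_partial=True wrapper."""
    if not 1 <= num_valid_pages <= pages_per_sub_tile:
        raise ValueError("caller must guarantee 1 <= num_valid_pages <= pages")
    # In the real code this loop is unrolled at compile time, so each branch
    # calls a separately specialized, fully static body.
    for p in range(1, pages_per_sub_tile):
        if num_valid_pages == p:
            return issue_unrolled(p)
    # Assumed fall-through: no partial count matched, so the tile is full.
    return issue_unrolled(pages_per_sub_tile)


if __name__ == "__main__":
    print(issue_partial(3, 4))
    print(issue_partial(4, 4))
```

The only runtime work is the comparison chain; the number of issues inside each branch is fixed when the branch is generated.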

num_iters is an internal dispatch knob: -1 (default) means "unroll v_pages_per_sub_tile iterations"; any other value fully unrolls exactly that many. Only the needs_partial=True wrapper sets it, when it recurses.

oob_fill_pages (consulted only when needs_partial=True): when True, after dispatching the num_valid_pages valid TMAs, also issue OOB TMAs for the remaining [num_valid_pages, v_pages_per_sub_tile) page slots. The TMA descriptor's OOBFill.NONE policy zero-fills SMEM for OOB coordinates, ensuring the full V-tile region holds finite (0) data. Depth-512 FA4 requires this: its O += P * V reads the full BN V-tile and would otherwise propagate 0 * non-finite = NaN from uninitialized SMEM. The bug only materializes when this is the very first write to the SMEM slot, typically when seq_len <= BN so that the only iteration is partial. Callers opting in MUST predicate expect_bytes on the full (non-partial) byte count, because all v_pages_per_sub_tile * num_depth_chunks TMAs arrive at the mbar.
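To make the NaN hazard and the arrival count concrete, here is a small Python model. It is a sketch, not the kernel logic: `v_issue_plan` and `tma_arrivals` are hypothetical helpers, and the slot labels only mark which page slots would receive a valid TMA versus a zero-filling OOB TMA.

```python
import math


def v_issue_plan(num_valid_pages: int, pages_per_sub_tile: int,
                 oob_fill_pages: bool) -> list:
    """Label each page slot that receives a TMA under needs_partial=True."""
    assert 1 <= num_valid_pages <= pages_per_sub_tile
    plan = ["valid"] * num_valid_pages
    if oob_fill_pages:
        # Remaining slots get OOB TMAs, which leave zeros in SMEM.
        plan += ["oob_zero_fill"] * (pages_per_sub_tile - num_valid_pages)
    return plan


def tma_arrivals(plan: list, num_depth_chunks: int) -> int:
    """TMAs that arrive at the barrier: one per slot per depth chunk."""
    return len(plan) * num_depth_chunks


if __name__ == "__main__":
    # Why the fill matters: a zero probability times stale non-finite data
    # is NaN under IEEE 754, while zero times zero stays zero.
    print(math.isnan(0.0 * float("inf")))
    print(0.0 * 0.0)

    filled = v_issue_plan(1, 4, oob_fill_pages=True)
    print(filled, tma_arrivals(filled, num_depth_chunks=2))
```

With the fill enabled, the arrival count no longer depends on num_valid_pages, which is why expect_bytes has to use the full byte count.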

elect is the raw Int32 returned by elect(). Each cp_async_bulk_tensor_shared_cluster_global_elect call predicates its TMA issue in-PTX on elect, so no Mojo-level if elect != 0: branch is needed here; all lanes follow the same PTX control flow and only the elected lane actually issues the TMA.

tma_copy_k

def tma_copy_k[dtype: DType, tile_shape: IndexList[Int(3)], desc_shape: IndexList[Int(3)], //, *, needs_partial: Bool, smem_BN: Int = BN, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = Int(-1)](self, tma_op: TMATensorTile[dtype, Int(3), tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, k_num_valid_pages: UInt32 = SIMD(((BN // kv_sub_tile_rows(BN, page_size)) // Int(2)) if pair_cta else (BN // kv_sub_tile_rows(BN, page_size))), depth_offset: UInt32 = UInt32(0))

TMA-copy K-side rows into scattered smem positions.

K counterpart to tma_copy_v. Loops over k_pages_per_cta = num_pages // 2 if pair_cta else num_pages entries, using self.rows[k_idx_offset_ct + _p_k] as the source row (the index offset is comptime-derived from Self.is_leader and Self.pair_cta). Smem destination packs the K subset into the first k_pages_per_cta page slots.

Non-pair-CTA / pair-CTA leader load from entry 0 with no intra-page offset; pair-CTA peer with num_pages >= 2 shifts the entry index by num_pages/2; pair-CTA peer with num_pages == 1 reuses rows[0] but adds BN/2 to the issued row so it covers the second half of the single page.
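The three cases above reduce to a pair of offsets: one into rows[] and one inside the page. The Python below models that selection on the host. It is a sketch with assumptions: `kv_sub_tile_rows_model` and `k_source` are hypothetical names, the helper is assumed to return BN when page_size is 0 or at least BN, and the single-page case is modeled as one K entry per CTA.

```python
def kv_sub_tile_rows_model(bn: int, page_size: int) -> int:
    """Hypothetical stand-in for kv_sub_tile_rows (assumed behavior)."""
    if page_size == 0 or page_size >= bn:
        return bn
    return page_size


def k_source(bn: int, page_size: int, pair_cta: bool, is_leader: bool) -> dict:
    """Entries K loads for this CTA and where they start."""
    num_pages = bn // kv_sub_tile_rows_model(bn, page_size)

    if not pair_cta or is_leader:
        # Non-pair-CTA and pair-CTA leader: start at entry 0, no row offset.
        idx_offset, row_offset = 0, 0
    elif num_pages >= 2:
        # Peer with several entries: take the second half of rows[].
        idx_offset, row_offset = num_pages // 2, 0
    else:
        # Peer with a single entry: reuse rows[0], start halfway into the page.
        idx_offset, row_offset = 0, bn // 2

    if pair_cta and num_pages >= 2:
        k_pages_per_cta = num_pages // 2
    else:
        k_pages_per_cta = num_pages  # assumption: one entry when num_pages == 1

    return {
        "k_pages_per_cta": k_pages_per_cta,
        "k_idx_offset": idx_offset,
        "intra_page_row_offset": row_offset,
    }


if __name__ == "__main__":
    print(k_source(128, 16, pair_cta=True, is_leader=False))
    print(k_source(128, 128, pair_cta=True, is_leader=False))
    print(k_source(128, 16, pair_cta=False, is_leader=True))
```

In the real struct both offsets are fixed at compile time from Self.is_leader and Self.pair_cta, so none of these branches exist at runtime.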

smem_BN controls the depth-chunk stride: depth-chunk stride is smem_BN * swizzle_gran. Defaults to Self.BN (fa4 layout); depth512 passes Self.BN // 2 = BK1.

needs_partial=False: comptime-unrolled over num_iters entries (default k_pages_per_cta); k_num_valid_pages is unused.

needs_partial=True: comptime-unrolls a runtime dispatch that tests k_num_valid_pages against each _p_k in [1, k_pages_per_cta) and tail-calls the needs_partial=False form with num_iters=_p_k, so the actual TMA issues always emit as a straight-line, fully static unroll of exactly k_num_valid_pages issues. Callers must guarantee 1 <= k_num_valid_pages <= k_pages_per_cta.

num_iters is an internal dispatch knob: -1 (default) means "unroll k_pages_per_cta iterations"; any other value fully unrolls exactly that many. Only the needs_partial=True wrapper sets it, when it recurses.

In non-pair_cta mode, k_pages_per_cta == num_pages and the comptime offsets are zero, which gives full-range behavior.

elect is the raw Int32 returned by elect(). Each cp_async_bulk_tensor_shared_cluster_global_elect call predicates its TMA issue in-PTX on elect, so no Mojo-level if elect != 0: branch is needed; all lanes follow the same PTX control flow and only the elected lane actually issues the TMA.