Mojo struct
PagedRowIndices
struct PagedRowIndices[BN: Int, page_size: Int, pair_cta: Bool = False, is_leader: Bool = True]
Pre-computed physical row indices for a BN-row range of paged KV cache.
BN is V's tile row count. MHAOperand.populate (or its
PagedKVCache override) fills indices for the full BN range (so
V can reuse them); K's TMA (tma_copy_k) covers only a subset
when pair_cta=True (the BN/2 rows owned by this CTA). The K
half is selected at comptime from Self.is_leader: when
num_pages >= 2 the peer shifts its index into rows[] by
num_pages/2; when num_pages == 1 (e.g. page_size >= BN) the
peer reuses rows[0] but adds BN/2 to the issued row.
When page_size >= BN (or page_size == 0 for non-paged), stores a
single entry with zero overhead compared to a single row_idx call.
Under pair_cta=True, K's TMA covers num_pages // 2 entries
(the CTA-rank-specific half) when num_pages >= 2, or the full
single entry when num_pages == 1; V's TMA covers all num_pages.
Storage is sized to V (num_pages = BN / eff_page) regardless of
pair_cta: K populates the full range so V can reuse the rows
without any lazy LUT lookup.
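As a rough illustration of the sizing and peer-half selection described above, the sketch below recomputes eff_page, num_pages, and the peer's entry/row offsets for one example configuration. sketch_kv_sub_tile_rows is a hypothetical stand-in for kv_sub_tile_rows, assumed to behave as described; it is not the library implementation.

```mojo
# Hypothetical stand-in for kv_sub_tile_rows (assumption): rows covered by
# one rows[] entry.
fn sketch_kv_sub_tile_rows(BN: Int, page_size: Int) -> Int:
    if page_size == 0 or page_size >= BN:
        return BN          # non-paged, or a single page covers the BN range
    return page_size       # one entry per page

fn main():
    var BN = 128
    var page_size = 64
    var eff_page = sketch_kv_sub_tile_rows(BN, page_size)  # 64
    var num_pages = BN // eff_page                          # 2
    # pair_cta K half: the leader starts at rows[0]; with num_pages >= 2 the
    # peer shifts its entry index by num_pages/2, and with num_pages == 1 it
    # reuses rows[0] but adds BN/2 to the issued row instead.
    var peer_entry_offset = num_pages // 2 if num_pages >= 2 else 0
    var peer_row_offset = 0 if num_pages >= 2 else BN // 2
    print(eff_page, num_pages, peer_entry_offset, peer_row_offset)  # 64 2 1 0
```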
Fields
- rows (InlineArray[UInt32, PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages]):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
cta_group
comptime cta_group = 2 if pair_cta else 1
eff_page
comptime eff_page = kv_sub_tile_rows(BN, page_size)
num_pages
comptime num_pages = (BN // PagedRowIndices[BN, page_size, pair_cta, is_leader].eff_page)
Methods
__init__
__init__(out self)
get_row
get_row(self, offset: UInt32) -> UInt32
Physical row for an arbitrary offset within the BN range.
For sub-tile loads: get_row(sub_tile_idx * eff_page).
For depth-512 V: get_row(pv_stage * BK1) avoids re-reading the LUT.
Requires the base kv_row that was passed to populate to be
page-aligned (guaranteed by mask alignment).
Returns: The physical KV-cache row index for the given offset.
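A hedged sketch of the offset math this describes, using a plain list as a stand-in for the pre-populated rows[] table; the row values, eff_page, and BK1 below are illustrative assumptions, not values from the library.

```mojo
# Stand-in for the described lookup: the entry for the page containing
# `offset`, plus the intra-page remainder (assumes a page-aligned base).
fn sketch_get_row(rows: List[Int], offset: Int, eff_page: Int) -> Int:
    return rows[offset // eff_page] + offset % eff_page

fn main():
    var rows = List[Int](1024, 4096)  # two 64-row pages covering BN = 128
    var eff_page = 64
    var BK1 = 64                      # BN / 2 for the depth-512 V split
    # Sub-tile loads: get_row(sub_tile_idx * eff_page).
    print(sketch_get_row(rows, 0 * eff_page, eff_page))  # 1024
    print(sketch_get_row(rows, 1 * eff_page, eff_page))  # 4096
    # Depth-512 V: get_row(pv_stage * BK1) reuses the same table.
    print(sketch_get_row(rows, 1 * BK1, eff_page))       # 4096
```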
tma_copy_v
tma_copy_v[dtype: DType, tile_shape: IndexList[3], desc_shape: IndexList[3], //, *, needs_partial: Bool, num_v_sub_tiles: Int = 1, v_sub_tile_idx: Int = 0, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = -1](self, tma_op: TMATensorTile[dtype, 3, tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, num_valid_pages: UInt32 = (PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages // num_v_sub_tiles), depth_offset: UInt32 = UInt32(0))
TMA-copy a V sub-tile, with comptime partial switch.
Consumes pre-populated rows from an earlier MHAOperand.populate
call. In pair_cta mode, that call populates the full num_pages
range (both CTAs' halves) so V can reuse them directly without
any lazy LUT lookup.
num_v_sub_tiles / v_sub_tile_idx select a row sub-range of
the BN tile when V is split across multiple SMEM slots (e.g.
depth512's num_pv_stages=2 split: BK1 = BN/2 rows per
slot). Default (1, 0) loads the full Self.BN rows into a
single SMEM slot of row stride Self.BN, byte-identical to
fa4's previous behavior.
With num_v_sub_tiles > 1:
- v_rows_per_sub_tile = Self.BN // num_v_sub_tiles is the SMEM depth-chunk stride (rows per slot).
- v_tma_tile_rows = kv_sub_tile_rows(v_rows_per_sub_tile, Self.page_size) is the TMA's tile-row count per issue.
- When Self.num_pages >= num_v_sub_tiles: sub-tile s loads rows[s * v_pages_per_sub_tile ..).
- When Self.num_pages == 1 < num_v_sub_tiles (the single page covers the full BN): all sub-tiles share rows[0] and add v_sub_tile_idx * v_rows_per_sub_tile as the intra-page row offset (both cases are sketched below).
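A minimal sketch of the two cases just listed, as plain index math; sketch_v_sub_tile_row is a hypothetical helper that mirrors the names in the docstring and is not the function's actual body.

```mojo
# Base physical row for V sub-tile `v_sub_tile_idx`, entry `page_entry`
# within that sub-tile (illustrative; assumes the definitions listed above).
fn sketch_v_sub_tile_row(
    rows: List[Int],
    num_pages: Int,
    BN: Int,
    num_v_sub_tiles: Int,
    v_sub_tile_idx: Int,
    page_entry: Int,
) -> Int:
    var v_rows_per_sub_tile = BN // num_v_sub_tiles
    if num_pages >= num_v_sub_tiles:
        # Each sub-tile owns its own slice of rows[].
        var v_pages_per_sub_tile = num_pages // num_v_sub_tiles
        return rows[v_sub_tile_idx * v_pages_per_sub_tile + page_entry]
    # num_pages == 1 < num_v_sub_tiles: the single page covers the whole BN
    # range, so sub-tiles share rows[0] and offset into it by whole sub-tiles.
    return rows[0] + v_sub_tile_idx * v_rows_per_sub_tile
```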
needs_partial=False: comptime-unrolled over num_iters
sub-tile entries (default v_pages_per_sub_tile).
needs_partial=True: comptime-unrolls a runtime dispatch that
tests num_valid_pages against each _p in [1, v_pages_per_sub_tile) and tail-calls the needs_partial=False
form with num_iters=_p so the actual TMA issues always emit
as a straight-line, fully static unroll of exactly
num_valid_pages issues. Callers must guarantee
1 <= num_valid_pages <= v_pages_per_sub_tile.
num_iters is an internal dispatch knob: -1 (default) means
"unroll v_pages_per_sub_tile iterations"; any other value
fully unrolls exactly that many. Only the needs_partial=True
wrapper sets it, when it recurses.
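The partial-dispatch shape can be sketched with hypothetical stand-ins (issue_one, issue_static, and issue_partial are illustrative names, not the library's API): a @parameter for unrolls the runtime comparison at compile time, and each branch calls a fully static unroll.

```mojo
fn issue_one(p: Int):
    print("issue entry", p)  # stand-in for one TMA issue

fn issue_static[num_iters: Int]():
    # needs_partial=False shape: straight-line, fully static unroll.
    @parameter
    for p in range(num_iters):
        issue_one(p)

fn issue_partial[max_pages: Int](num_valid_pages: Int):
    # needs_partial=True shape: comptime-unrolled runtime dispatch that
    # tail-calls the static form with the matching count.
    @parameter
    for p in range(1, max_pages):
        if num_valid_pages == p:
            issue_static[p]()
            return
    issue_static[max_pages]()

fn main():
    issue_partial[4](3)  # emits exactly 3 stand-in issues
```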
elect is the raw Int32 returned by elect(). Each
cp_async_bulk_tensor_shared_cluster_global_elect call predicates
its TMA issue in-PTX on elect, so no Mojo-level if elect != 0:
branch is needed here; all lanes follow the same PTX control
flow and only the elected lane actually issues the TMA.
tma_copy_k
tma_copy_k[dtype: DType, tile_shape: IndexList[3], desc_shape: IndexList[3], //, *, needs_partial: Bool, smem_BN: Int = BN, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = -1](self, tma_op: TMATensorTile[dtype, 3, tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, k_num_valid_pages: UInt32 = (PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages // 2) if pair_cta else PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages, depth_offset: UInt32 = UInt32(0))
TMA-copy K-side rows into scattered smem positions.
K counterpart to tma_copy_v. Loops over
k_pages_per_cta = num_pages // 2 if pair_cta else num_pages
entries, using self.rows[k_idx_offset_ct + _p_k] as the source
row (the index offset is comptime-derived from Self.is_leader
and Self.pair_cta). Smem destination packs the K subset into
the first k_pages_per_cta page slots.
Non-pair-CTA / pair-CTA leader load from entry 0 with no
intra-page offset; pair-CTA peer with num_pages >= 2 shifts
the entry index by num_pages/2; pair-CTA peer with
num_pages == 1 reuses rows[0] but adds BN/2 to the issued
row so it covers the second half of the single page.
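The leader/peer selection reduces to two small offsets, which the hypothetical helper below computes for illustration (it mirrors the k_idx_offset_ct behavior described here; the real derivation is comptime inside tma_copy_k).

```mojo
# Prints (entry shift into rows[], intra-page row offset) for the K side,
# following the leader/peer rules above (illustrative only).
fn sketch_k_offsets(pair_cta: Bool, is_leader: Bool, num_pages: Int, BN: Int):
    var entry_offset = 0  # entry shift into rows[]
    var row_offset = 0    # added to the issued row
    if pair_cta and not is_leader:
        if num_pages >= 2:
            entry_offset = num_pages // 2  # peer owns the second half of rows[]
        else:
            row_offset = BN // 2           # single page: peer offsets inside rows[0]
    print(entry_offset, row_offset)

fn main():
    sketch_k_offsets(False, True, 2, 128)  # non-pair or leader: 0 0
    sketch_k_offsets(True, False, 2, 128)  # pair peer, 2 pages: 1 0
    sketch_k_offsets(True, False, 1, 128)  # pair peer, 1 page:  0 64
```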
smem_BN controls the depth-chunk stride: it is smem_BN * swizzle_gran.
Defaults to Self.BN (fa4 layout); depth512 passes Self.BN // 2 = BK1.
needs_partial=False: comptime-unrolled over num_iters
entries (default k_pages_per_cta); k_num_valid_pages is
unused.
needs_partial=True: comptime-unrolls a runtime dispatch that
tests k_num_valid_pages against each _p_k in [1, k_pages_per_cta) and tail-calls the needs_partial=False
form with num_iters=_p_k so the actual TMA issues always
emit as a straight-line, fully static unroll of exactly
k_num_valid_pages issues. Callers must guarantee
1 <= k_num_valid_pages <= k_pages_per_cta.
num_iters is an internal dispatch knob: -1 (default) means
"unroll k_pages_per_cta iterations"; any other value fully
unrolls exactly that many. Only the needs_partial=True
wrapper sets it, when it recurses.
In non-pair_cta mode, k_pages_per_cta == num_pages and the
comptime offsets are zero, giving full-range behavior.
elect is the raw Int32 returned by elect(). Each
cp_async_bulk_tensor_shared_cluster_global_elect call predicates
its TMA issue in-PTX on elect, so no Mojo-level if elect != 0:
branch is needed; all lanes follow the same PTX control flow and
only the elected lane actually issues the TMA.