
Mojo struct

PagedRowIndices

struct PagedRowIndices[BN: Int, page_size: Int, pair_cta: Bool = False, is_leader: Bool = True]

Pre-computed physical row indices for a BN-row range of paged KV cache.

BN is V's tile row count. MHAOperand.populate (or its PagedKVCache override) fills indices for the full BN range (so V can reuse them); K's TMA (tma_copy_k) covers only a subset when pair_cta=True (the BN/2 rows owned by this CTA). The K half is selected at comptime from Self.is_leader: when num_pages >= 2 the peer shifts its index into rows[] by num_pages/2; when num_pages == 1 (e.g. page_size >= BN) the peer reuses rows[0] but adds BN/2 to the issued row.

When page_size >= BN (or page_size == 0 for non-paged), stores a single entry: zero overhead compared to a single row_idx call.

Under pair_cta=True, K's TMA covers num_pages // 2 entries (the CTA-rank-specific half) when num_pages >= 2, or the full single entry when num_pages == 1; V's TMA covers all num_pages. Storage is sized to V (num_pages = BN / eff_page) regardless of pair_cta: K populates the full range so V can reuse the rows without any lazy LUT lookup.
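The sizing arithmetic can be modeled in a few lines of Python. This is a sketch only: kv_sub_tile_rows below is an assumed reconstruction of the Mojo helper's semantics (collapse to one entry when non-paged or when one page covers the whole tile), not the actual implementation.

```python
def kv_sub_tile_rows(BN: int, page_size: int) -> int:
    # Assumed semantics: non-paged (page_size == 0) or a page covering the
    # whole BN tile collapses to a single entry; otherwise one entry per page.
    if page_size == 0 or page_size >= BN:
        return BN
    return page_size

def storage_and_coverage(BN: int, page_size: int, pair_cta: bool):
    eff_page = kv_sub_tile_rows(BN, page_size)
    num_pages = BN // eff_page  # storage is always sized to V
    # K's TMA covers only this CTA's half under pair_cta (when splittable);
    # V's TMA always covers all num_pages entries.
    k_pages = num_pages // 2 if (pair_cta and num_pages >= 2) else num_pages
    return eff_page, num_pages, k_pages

# page_size >= BN: a single entry, zero overhead vs. one row_idx call.
assert storage_and_coverage(128, 128, pair_cta=False) == (128, 1, 1)
# Small pages under pair_cta: 8 entries stored; each CTA's K half covers 4.
assert storage_and_coverage(128, 16, pair_cta=True) == (16, 8, 4)
```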

Fields

  • rows (InlineArray[UInt32, PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages])

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

cta_group

comptime cta_group = 2 if pair_cta else 1

eff_page

comptime eff_page = kv_sub_tile_rows(BN, page_size)

num_pages

comptime num_pages = (BN // PagedRowIndices[BN, page_size, pair_cta, is_leader].eff_page)
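For a quick sanity check, a few concrete instantiations of these members, using the same assumed kv_sub_tile_rows semantics as the sketch above (illustrative values, not taken from the Mojo source):

```python
# Hypothetical model: (BN, page_size, pair_cta) -> (cta_group, eff_page, num_pages).
def members(BN: int, page_size: int, pair_cta: bool):
    eff_page = BN if page_size == 0 or page_size >= BN else page_size
    return (2 if pair_cta else 1), eff_page, BN // eff_page

assert members(128,  0, False) == (1, 128, 1)  # non-paged: single entry
assert members(128, 64, False) == (1,  64, 2)  # two pages per BN tile
assert members(128, 32, True)  == (2,  32, 4)  # pair_cta: cta_group = 2
```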

Methods

__init__

__init__(out self)

get_row

get_row(self, offset: UInt32) -> UInt32

Physical row for an arbitrary offset within the BN range.

For sub-tile loads: get_row(sub_tile_idx * eff_page). For depth-512 V: get_row(pv_stage * BK1) avoids re-reading the LUT. Requires the base kv_row that was passed to populate to be page-aligned (guaranteed by mask alignment).

Returns:

UInt32
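A plausible Python model of this lookup, assuming each rows[i] holds the physical row of the i-th eff_page-sized page under a page-aligned base (a reconstruction of the described arithmetic, not the Mojo source):

```python
def get_row(rows: list[int], eff_page: int, offset: int) -> int:
    # Page-aligned base: offset // eff_page picks the LUT entry,
    # offset % eff_page is the row within that physical page.
    return rows[offset // eff_page] + offset % eff_page

rows = [4096, 512, 2048, 7680]  # hypothetical physical page rows, eff_page = 32
assert get_row(rows, 32, 0) == 4096       # sub-tile 0: get_row(0 * eff_page)
assert get_row(rows, 32, 2 * 32) == 2048  # sub-tile 2: get_row(2 * eff_page)
assert get_row(rows, 32, 33) == 512 + 1   # arbitrary offset within the BN range
```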

tma_copy_v

tma_copy_v[dtype: DType, tile_shape: IndexList[3], desc_shape: IndexList[3], //, *, needs_partial: Bool, num_v_sub_tiles: Int = 1, v_sub_tile_idx: Int = 0, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = -1](self, tma_op: TMATensorTile[dtype, 3, tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, num_valid_pages: UInt32 = (PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages // num_v_sub_tiles), depth_offset: UInt32 = UInt32(0))

TMA-copy a V sub-tile, with a comptime partial switch.

Consumes pre-populated rows from an earlier MHAOperand.populate call. In pair_cta mode, that call populates the full num_pages range (both CTAs' halves) so V can reuse them directly without any lazy LUT lookup.

num_v_sub_tiles / v_sub_tile_idx select a row sub-range of the BN tile when V is split across multiple SMEM slots (e.g. depth512's num_pv_stages=2 split: BK1 = BN/2 rows per slot). Default (1, 0) loads the full Self.BN rows into a single SMEM slot of row stride Self.BN, byte-identical to fa4's previous behavior.

With num_v_sub_tiles > 1:

  • v_rows_per_sub_tile = Self.BN // num_v_sub_tiles is the SMEM depth-chunk stride (rows per slot).
  • v_tma_tile_rows = kv_sub_tile_rows(v_rows_per_sub_tile, Self.page_size) is the TMA's tile-row count per issue.
  • When Self.num_pages >= num_v_sub_tiles: sub-tile s loads rows[s * v_pages_per_sub_tile .. (s + 1) * v_pages_per_sub_tile), where v_pages_per_sub_tile = Self.num_pages // num_v_sub_tiles (see the sketch after this list).
  • When Self.num_pages == 1 < num_v_sub_tiles (the page covers the full BN): all sub-tiles share rows[0] and add v_sub_tile_idx * v_rows_per_sub_tile as the intra-page row offset.
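A minimal Python model of this sub-tile selection (the slicing is a reconstruction of the behavior described above, not the Mojo source; with num_v_sub_tiles = 1 it degenerates to the full-range default):

```python
def v_sub_tile_selection(rows, BN, num_pages, num_v_sub_tiles, v_sub_tile_idx):
    # Returns (LUT entries to issue, intra-page row offset) for one V sub-tile.
    v_rows_per_sub_tile = BN // num_v_sub_tiles
    if num_pages >= num_v_sub_tiles:
        v_pages_per_sub_tile = num_pages // num_v_sub_tiles
        start = v_sub_tile_idx * v_pages_per_sub_tile
        return rows[start:start + v_pages_per_sub_tile], 0
    # num_pages == 1 < num_v_sub_tiles: one page covers the full BN tile,
    # so every sub-tile reuses rows[0] plus an intra-page row offset.
    return [rows[0]], v_sub_tile_idx * v_rows_per_sub_tile

rows = [100, 200, 300, 400]  # hypothetical physical page rows
# depth512-style split (num_v_sub_tiles = 2): sub-tile 1 gets the back half.
assert v_sub_tile_selection(rows, 128, 4, 2, 1) == ([300, 400], 0)
# Single large page: sub-tile 1 reuses rows[0] at intra-page offset BN/2.
assert v_sub_tile_selection([100], 128, 1, 2, 1) == ([100], 64)
```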

needs_partial=False: comptime-unrolled over num_iters sub-tile entries (default v_pages_per_sub_tile).

needs_partial=True: comptime-unrolls a runtime dispatch that tests num_valid_pages against each _p in [1, v_pages_per_sub_tile) and tail-calls the needs_partial=False form with num_iters=_p, so the actual TMA issues always emit as a straight-line, fully static unroll of exactly num_valid_pages issues. Callers must guarantee 1 <= num_valid_pages <= v_pages_per_sub_tile.

num_iters is an internal dispatch knob: -1 (default) means "unroll v_pages_per_sub_tile iterations"; any other value fully unrolls exactly that many. Only the needs_partial=True wrapper sets it when it recurses.
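The dispatch pattern can be sketched in Python as follows. This is illustrative only: in Mojo the inner bodies are comptime-unrolled rather than Python loops, and tma_issue is a hypothetical stand-in for the actual TMA call.

```python
def issue_static(rows, count):
    # Stand-in for the needs_partial=False form: in Mojo this compiles to a
    # straight-line, fully static unroll of exactly `count` TMA issues.
    return [f"tma_issue(row={rows[i]})" for i in range(count)]

def issue_partial(rows, v_pages_per_sub_tile, num_valid_pages):
    # Caller contract: 1 <= num_valid_pages <= v_pages_per_sub_tile.
    # Every _p test is statically enumerated; the matching branch tail-calls
    # the static form, so no data-dependent loop bound reaches the TMA issues.
    for _p in range(1, v_pages_per_sub_tile):
        if num_valid_pages == _p:
            return issue_static(rows, _p)
    return issue_static(rows, v_pages_per_sub_tile)  # full-tile case

assert issue_partial([7, 8, 9, 10], 4, 2) == issue_static([7, 8, 9, 10], 2)
```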

elect is the raw Int32 returned by elect(). Each cp_async_bulk_tensor_shared_cluster_global_elect call predicates its TMA issue in-PTX on elect, so no Mojo-level if elect != 0: branch is needed here; all lanes follow the same PTX control flow and only the elected lane actually issues the TMA.

tma_copy_k

tma_copy_k[dtype: DType, tile_shape: IndexList[3], desc_shape: IndexList[3], //, *, needs_partial: Bool, smem_BN: Int = BN, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_iters: Int = -1](self, tma_op: TMATensorTile[dtype, 3, tile_shape, desc_shape], stage_base: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], ref[AddressSpace._value] mbar: SharedMemBarrier, *, kv_head_idx: UInt32, elect: Int32, k_num_valid_pages: UInt32 = (PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages // 2) if pair_cta else PagedRowIndices[BN, page_size, pair_cta, is_leader].num_pages, depth_offset: UInt32 = UInt32(0))

TMA-copy K-side rows into scattered smem positions.

K counterpart to tma_copy_v. Loops over k_pages_per_cta entries (num_pages // 2 if pair_cta else num_pages), using self.rows[k_idx_offset_ct + _p_k] as the source row (the index offset is comptime-derived from Self.is_leader and Self.pair_cta). The smem destination packs the K subset into the first k_pages_per_cta page slots.

The non-pair-CTA and pair-CTA-leader cases load from entry 0 with no intra-page offset; the pair-CTA peer with num_pages >= 2 shifts the entry index by num_pages/2; the pair-CTA peer with num_pages == 1 reuses rows[0] but adds BN/2 to the issued row so it covers the second half of the single page.
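The comptime selection of the K half can be modeled as below (a reconstruction of the three cases just described; the returned pair is the offset into rows[] and the extra row added to each issued row):

```python
def k_entry_and_row_offset(BN, num_pages, pair_cta, is_leader):
    if not pair_cta or is_leader:
        return 0, 0               # full range / leader half: entry 0, no offset
    if num_pages >= 2:
        return num_pages // 2, 0  # peer: second half of the LUT entries
    return 0, BN // 2             # peer, single page: back half of the page by row

assert k_entry_and_row_offset(128, 4, pair_cta=True,  is_leader=False) == (2, 0)
assert k_entry_and_row_offset(128, 1, pair_cta=True,  is_leader=False) == (0, 64)
assert k_entry_and_row_offset(128, 4, pair_cta=False, is_leader=True)  == (0, 0)
```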

smem_BN controls the depth-chunk stride, which is smem_BN * swizzle_gran. Defaults to Self.BN (the fa4 layout); depth512 passes Self.BN // 2 = BK1.

needs_partial=False: comptime-unrolled over num_iters entries (default k_pages_per_cta); k_num_valid_pages is unused.

needs_partial=True: comptime-unrolls a runtime dispatch that tests k_num_valid_pages against each _p_k in [1, k_pages_per_cta) and tail-calls the needs_partial=False form with num_iters=_p_k, so the actual TMA issues always emit as a straight-line, fully static unroll of exactly k_num_valid_pages issues. Callers must guarantee 1 <= k_num_valid_pages <= k_pages_per_cta.

num_iters is an internal dispatch knob: -1 (default) means "unroll k_pages_per_cta iterations"; any other value fully unrolls exactly that many. Only the needs_partial=True wrapper sets it when it recurses.

In non-pair_cta mode, k_pages_per_cta == num_pages and the comptime offsets are zero, giving full-range behavior.

elect is the raw Int32 returned by elect(). Each cp_async_bulk_tensor_shared_cluster_global_elect call predicates its TMA issue in-PTX on elect, so no Mojo-level if elect != 0: branch is needed; all lanes follow the same PTX control flow and only the elected lane actually issues the TMA.