Mojo struct
KVMmaOp
struct KVMmaOp[in_type: DType, mma_shape: IndexList[3], num_mmas: Int, num_k_mmas: Int, num_k_tiles: Int, BN: Int, BK: Int, transpose_b: Bool = True, swizzle: Optional[Swizzle] = None, out_type: DType = get_accum_type[in_type]()]
Owns the K or V operand register tile and its SMEM→reg load logic.
Attention has two sequential GEMMs (P = Q @ K^T, O += P @ V). Instantiate
one KVMmaOp per operand role. This keeps KVBuffer focused on SMEM
storage + DMA and moves MMA-side concerns (reg layout, frag size,
fragment loads) here.
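As a rough illustration of the two operand roles, the sketch below runs both GEMMs on toy matrices in plain Python. The shapes and values are assumptions chosen for readability, the softmax between the two products is left out (as in the description above), and nothing here reflects how the kernel tiles the work. It only shows that K is consumed transposed while V is consumed as stored.

```python
def matmul(a, b):
    """Plain row-major matrix product of two lists of lists."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


def transpose(m):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


# Toy shapes (assumptions): 2 queries, 3 keys, depth 2.
Q = [[1, 0], [0, 1]]            # [num_queries, depth]
K = [[1, 2], [3, 4], [5, 6]]    # [num_keys, depth]
V = [[1, 0], [0, 1], [1, 1]]    # [num_keys, depth]

P = matmul(Q, transpose(K))     # first GEMM: K is read transposed
O = matmul(P, V)                # second GEMM: V is read as stored (O starts at zero)
```

The first product corresponds to the transpose_b = True role and the second to the transpose_b = False role.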
The register layout is organized as
[num_k_tiles][num_k_mmas][num_mmas] x input_frag_size: each BK strip
holds num_k_mmas * num_mmas fragments back-to-back.
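The layout described above can be read as a row-major flattening with num_mmas as the fastest-moving fragment index. The Python sketch below computes the flat element offset of one fragment under that reading. The helper name fragment_offset and the example sizes are hypothetical, and the actual register tile may be addressed differently.

```python
def fragment_offset(bk_tile, k_mma, mma, num_k_mmas, num_mmas, input_frag_size):
    """Flat element offset of one fragment, assuming a row-major
    [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size layout."""
    frag_index = (bk_tile * num_k_mmas + k_mma) * num_mmas + mma
    return frag_index * input_frag_size


# Example sizes (assumptions): 2 k-MMAs per strip, 4 MMAs, 8 elements per fragment.
strip_elems = 2 * 4 * 8  # elements held by one BK strip
first_of_strip_1 = fragment_offset(1, 0, 0, num_k_mmas=2, num_mmas=4, input_frag_size=8)
inner_fragment = fragment_offset(0, 1, 2, num_k_mmas=2, num_mmas=4, input_frag_size=8)
```

With these sizes, each BK strip spans num_k_mmas * num_mmas * input_frag_size elements, so the first fragment of strip 1 starts exactly one strip past the start of the tile.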
Parameters
- in_type (DType): Operand element type (bfloat16 or float8_e4m3fn).
- mma_shape (IndexList): MMA instruction shape [M, N, K].
- num_mmas (Int): MMA tiles along the warp's M or N axis (WN/MMA_M for K).
- num_k_mmas (Int): MMA tiles along K within a single BK strip.
- num_k_tiles (Int): Number of BK strips across the full depth.
- BN (Int): KV block height (needed by V load methods for SMEM offset math).
- BK (Int): KV block width (needed by V load methods for SMEM offset math).
- transpose_b (Bool): True for K (transposed load), False for V.
- swizzle (Optional): Optional SMEM swizzle: vector-space for prefill, element-space for decode.
- out_type (DType): Accumulator data type (defaults to accum(in_type)).
Fields
- reg_tile (TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]): Register tile holding the K or V operand fragments.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
input_frag_size
comptime input_frag_size = ((KVMmaOp[in_type, mma_shape, num_mmas, num_k_mmas, num_k_tiles, BN, BK, transpose_b, swizzle, out_type].MMA_K * KVMmaOp[in_type, mma_shape, num_mmas, num_k_mmas, num_k_tiles, BN, BK, transpose_b, swizzle, out_type].MMA_N) // WARP_SIZE)
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
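The input_frag_size formula divides the MMA_K x MMA_N elements of one B-operand MMA tile evenly across the lanes of a warp. The Python sketch below evaluates it for a few shapes. The shapes and warp sizes are illustrative assumptions, not values taken from this page, and the real WARP_SIZE depends on the target GPU.

```python
def input_frag_size(mma_shape, warp_size):
    """Elements of one B-operand MMA tile held by each lane:
    (MMA_K * MMA_N) // WARP_SIZE."""
    mma_m, mma_n, mma_k = mma_shape  # mma_shape is [M, N, K]
    return (mma_k * mma_n) // warp_size


# Illustrative shapes and warp sizes (assumptions).
frag_a = input_frag_size((32, 32, 16), warp_size=64)
frag_b = input_frag_size((16, 16, 32), warp_size=64)
frag_c = input_frag_size((16, 16, 16), warp_size=32)
```

MMA_M does not enter the formula, since the fragment belongs to the B operand.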
Methods
__init__
__init__(out self)
load_prefill
load_prefill[bk_tile: Int](self, warp_smem: TileTensor[in_type, warp_smem.LayoutType, warp_smem.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem.linear_idx_type, element_size=warp_smem.element_size])
Load the bk_tile-th BK strip of K fragments from SMEM.
Delegates to TiledMmaLoader.load_b (M-outer iteration,
vector-space swizzle). Handles both BF16 (single load per MMA
tile) and FP8 (two-half-K load + join) via the num_packs
branch inside load_b.
mma_tile_at
mma_tile_at[bk_tile: Int, kg: Int](self) -> TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Sub-view of the reg tile for a given (bk_tile, k-group) pair.
Returns: TileTensor: the register sub-view for the requested (bk_tile, kg) pair.
load_v_bf16
load_v_bf16[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED])
Load the bk_tile-th BK strip of BF16 V fragments from SMEM.
V SMEM is blocked (num_repeats × BN × BK, row-major within each
block). For each (k, i) ∈ [num_k_mmas] × [num_mmas], build an
MMA sub-tile view with the correct (BK, 1) row stride (not
(MMA_M, 1); see the smem_mma_subtile header), call
TiledMmaLoader.load_b_tr, and write the fragment into the
reg slot mma_tile_at[bk_tile, k][i].
Only valid when transpose_b == False and in_type == bfloat16.
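The row-stride remark can be made concrete with a small Python sketch of the offset math for a blocked, row-major buffer. The helper name element_offset, the sub-tile origin, and the block sizes are hypothetical. How the method maps (k, i) to a sub-tile origin is not stated on this page, so the origin is passed in directly.

```python
def element_offset(block, row0, col0, r, c, BN, BK):
    """Offset of element (r, c) of a sub-tile whose origin is (row0, col0)
    inside block `block` of a (num_repeats x BN x BK) row-major buffer."""
    block_base = block * BN * BK
    # Rows of the sub-tile are rows of the enclosing block, so the row stride
    # is the full block width BK, not the width of the sub-tile itself.
    return block_base + (row0 + r) * BK + (col0 + c)


# Example sizes (assumptions): BN = 64, BK = 32, sub-tile origin (16, 8) in block 1.
top_left = element_offset(1, 16, 8, 0, 0, BN=64, BK=32)
next_row = element_offset(1, 16, 8, 1, 0, BN=64, BK=32)
row_stride = next_row - top_left
```

Stepping one row down inside the sub-tile advances the address by BK elements, which is why a view built with a narrower stride would read the wrong rows.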
load_v_fp8_strip
load_v_fp8_strip[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int)
Load the bk_tile-th BK strip of FP8 V fragments from SMEM.
Iterates dt over the depth direction and calls
TiledMmaLoader.load_v_fp8_strip per (bk_tile, dt) pair,
writing the joined 32-element SIMD into
mma_tile_at[bk_tile, 0][dt].
Only valid when transpose_b == False and in_type is FP8.
The caller precomputes the lane-only coords (rel_key,
hw_key_shift, depth_base) once before a multi-bk loop, since
they don't depend on bk_tile or dt.
mma
mma[swap_a_b: Bool = False](self, a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size], bk_tile: Int, kg: Int)
Compute C += A * B using this op's reg tile as the B operand.