
Mojo struct

KVMmaOp

struct KVMmaOp[in_type: DType, mma_shape: IndexList[3], num_mmas: Int, num_k_mmas: Int, num_k_tiles: Int, BN: Int, BK: Int, transpose_b: Bool = True, swizzle: Optional[Swizzle] = None, out_type: DType = get_accum_type[in_type]()]

Owns the K or V operand register tile and its SMEM→reg load logic.

Attention has two sequential GEMMs (P = Q @ K^T, O += P @ V). Instantiate one KVMmaOp per operand role. This keeps KVBuffer focused on SMEM storage + DMA and moves MMA-side concerns (reg layout, frag size, fragment loads) here.
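
The two GEMM roles can be sketched in NumPy (illustrative only: the shapes, the softmax step, and the array names are assumptions, not this kernel's actual tiling):

```python
import numpy as np

# Illustrative shapes (assumptions, not the kernel's tile sizes).
seq_q, seq_kv, depth = 8, 16, 32

rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_q, depth))
K = rng.standard_normal((seq_kv, depth))   # K operand: consumed transposed
V = rng.standard_normal((seq_kv, depth))   # V operand: consumed as-is

# First GEMM: scores P = Q @ K^T (the transpose_b=True role).
P = Q @ K.T

# Softmax over keys happens between the two GEMMs in attention.
P = np.exp(P - P.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)

# Second GEMM: output accumulation O += P @ V (the transpose_b=False role).
O = np.zeros((seq_q, depth))
O += P @ V
```

One KVMmaOp instance serves the K role of the first GEMM, another the V role of the second.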

The register layout is organized as [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size: each BK strip holds num_k_mmas * num_mmas fragments back-to-back.
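
The index math implied by that ordering can be sketched in plain Python (the linear-offset formula and the tile counts are assumptions derived from the description above, not the struct's actual implementation):

```python
def frag_offset(bk_tile, kg, i, num_k_mmas, num_mmas, input_frag_size):
    """Flat element offset of fragment (bk_tile, kg, i) in the reg tile,
    assuming the [num_k_tiles][num_k_mmas][num_mmas] ordering above."""
    return ((bk_tile * num_k_mmas + kg) * num_mmas + i) * input_frag_size

# Example: 2 k-MMAs per BK strip, 4 MMAs, 4-element fragments (made up).
num_k_mmas, num_mmas, frag = 2, 4, 4

# Each BK strip holds num_k_mmas * num_mmas fragments back-to-back,
# so the next strip starts exactly one strip's worth of elements later.
strip_elems = num_k_mmas * num_mmas * frag
assert frag_offset(1, 0, 0, num_k_mmas, num_mmas, frag) == strip_elems
```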

Parameters

  • in_type (DType): Operand element type (bfloat16 or float8_e4m3fn).
  • mma_shape (IndexList[3]): MMA instruction shape [M, N, K].
  • num_mmas (Int): MMA tiles along the warp's M or N axis (WN/MMA_M for K).
  • num_k_mmas (Int): MMA tiles along K within a single BK strip.
  • num_k_tiles (Int): Number of BK strips across the full depth.
  • BN (Int): KV block height (needed by V load methods for SMEM offset math).
  • BK (Int): KV block width (needed by V load methods for SMEM offset math).
  • transpose_b (Bool): True for K (transposed load), False for V.
  • swizzle (Optional[Swizzle]): Optional SMEM swizzle: vector-space for prefill, element-space for decode.
  • out_type (DType): Accumulator data type (defaults to get_accum_type[in_type]()).

Fields

  • reg_tile (TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]): The register tile holding this operand's MMA fragments.

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members

input_frag_size

comptime input_frag_size = (Self.MMA_K * Self.MMA_N) // WARP_SIZE

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]
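
As a concrete check of the input_frag_size arithmetic (the 32-lane warp size and the 16×8×16 MMA shape below are assumptions chosen for illustration):

```python
WARP_SIZE = 32  # assumed lanes per warp

def input_frag_size(mma_shape):
    """(MMA_K * MMA_N) // WARP_SIZE: B-operand elements each lane holds
    per MMA instruction, given mma_shape = [M, N, K]."""
    MMA_M, MMA_N, MMA_K = mma_shape
    return (MMA_K * MMA_N) // WARP_SIZE

print(input_frag_size([16, 8, 16]))  # (16 * 8) // 32 = 4 elements per lane
print(input_frag_size([16, 8, 32]))  # (32 * 8) // 32 = 8 elements per lane
```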

Methods

__init__

__init__(out self)

load_prefill

load_prefill[bk_tile: Int](self, warp_smem: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem.linear_idx_type, element_size=warp_smem.element_size])

Load the bk_tile-th BK strip of K fragments from SMEM.

Delegates to TiledMmaLoader.load_b (M-outer iteration, vector-space swizzle). Handles both BF16 (single load per MMA tile) and FP8 (two-half-K load + join) via the num_packs branch inside load_b.

mma_tile_at

mma_tile_at[bk_tile: Int, kg: Int](self) -> TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Sub-view of the reg tile for a given (bk_tile, k-group) pair.

Returns:

TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

load_v_bf16

load_v_bf16[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED])

Load the bk_tile-th BK strip of BF16 V fragments from SMEM.

V SMEM is blocked (num_repeats × BN × BK, row-major within each block). For each (k, i) ∈ [num_k_mmas] × [num_mmas], build an MMA sub-tile view with the correct (BK, 1) row stride (not (MMA_M, 1); see the smem_mma_subtile header), call TiledMmaLoader.load_b_tr, and write the fragment into the reg slot mma_tile_at[bk_tile, k][i].

Only valid when transpose_b == False and in_type == bfloat16.
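
The offset math that motivates the (BK, 1) stride can be sketched in Python (the address formula is an assumption reconstructed from the blocked, row-major-within-block description above; the BN/BK values are made up):

```python
def v_smem_offset(block, row, col, BN, BK):
    """Element offset of V[row, col] inside the block-th (BN x BK) block
    of SMEM, assuming row-major layout within each block."""
    return block * BN * BK + row * BK + col

BN, BK = 64, 32

# Moving down one row inside a block advances by BK elements, so the MMA
# sub-tile view needs row stride (BK, 1), not (MMA_M, 1): each row of a
# block is a contiguous BK-element run.
assert v_smem_offset(0, 1, 0, BN, BK) - v_smem_offset(0, 0, 0, BN, BK) == BK
```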

load_v_fp8_strip

load_v_fp8_strip[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int)

Load the bk_tile-th BK strip of FP8 V fragments from SMEM.

Iterates dt over the depth direction and calls TiledMmaLoader.load_v_fp8_strip per (bk_tile, dt) pair, writing the joined 32-element SIMD into mma_tile_at[bk_tile, 0][dt].

Only valid when transpose_b == False and in_type is FP8. The caller precomputes the lane-only coords (rel_key, hw_key_shift, depth_base) once before a multi-bk loop; they don't depend on bk_tile or dt.

mma

mma[swap_a_b: Bool = False](self, a: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], c: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size], bk_tile: Int, kg: Int)

Compute C += A * B, using this op's reg tile at (bk_tile, kg) as the B operand.