IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

KVMmaOp

struct KVMmaOp[in_type: DType, mma_shape: IndexList[3], num_mmas: Int, num_k_mmas: Int, num_k_tiles: Int, BN: Int, BK: Int, transpose_b: Bool = True, swizzle: Optional[Swizzle] = None, out_type: DType = get_accum_type[in_type]()]

Owns the K or V operand register tile and its SMEM→reg load logic.

Attention has two sequential GEMMs: S = Q @ K^T (softmaxed into P), then O += P @ V. Instantiate one KVMmaOp per operand role. This keeps KVBuffer focused on SMEM storage and DMA, and moves MMA-side concerns (register layout, fragment size, fragment loads) here.
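To make the two operand roles concrete, here is a plain-Python model of one KV block. It illustrates the math only and is not the Mojo kernel. The helper names `matmul`, `softmax_rows`, and `attention_block` are hypothetical, and the model ignores the online-softmax rescaling a real flash-attention kernel applies across KV blocks. K enters the first GEMM as a transposed B operand (`transpose_b=True`), and V enters the second as a non-transposed one (`transpose_b=False`).

```python
import math


def matmul(a, b, transpose_b=False):
    """Return a @ b, or a @ b^T when transpose_b is True (matrices as lists of rows)."""
    if transpose_b:
        b = [list(col) for col in zip(*b)]  # materialize b^T
    cols = list(zip(*b))  # columns of the (possibly transposed) B operand
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def softmax_rows(s):
    """Numerically stable row-wise softmax."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out


def attention_block(q, k, v, o):
    """One KV block, without the online-softmax rescaling used across blocks."""
    s = matmul(q, k, transpose_b=True)  # GEMM 1: S = Q @ K^T (K operand, transposed)
    p = softmax_rows(s)                 # P = softmax(S)
    pv = matmul(p, v)                   # GEMM 2: P @ V (V operand, not transposed)
    # O += P @ V
    return [[x + y for x, y in zip(o_row, pv_row)] for o_row, pv_row in zip(o, pv)]


print(attention_block([[0, 0]], [[1, 0], [0, 1]], [[1, 2], [3, 4]], [[0, 0]]))
```

With equal scores the softmax weights are 0.5 each, so the output row is the average of the two V rows, [2.0, 3.0].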

The register layout is organized as [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size: each BK strip holds num_k_mmas * num_mmas fragments back-to-back.
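The layout can be modeled as a flat array in which fragment (bk_tile, k, i) starts at `((bk_tile * num_k_mmas + k) * num_mmas + i) * input_frag_size`. The sketch below uses hypothetical Python helpers that are not part of the Mojo API. It assumes row-major ordering of the three outer dimensions, which is what the documented `[num_k_tiles][num_k_mmas][num_mmas]` shape and the "back-to-back" wording suggest.

```python
def frag_offset(bk_tile, k, i, num_k_mmas, num_mmas, frag_size):
    """Element offset of fragment (bk_tile, k, i) in the flattened register tile."""
    return ((bk_tile * num_k_mmas + k) * num_mmas + i) * frag_size


def reg_tile_size(num_k_tiles, num_k_mmas, num_mmas, frag_size):
    """Total elements in the register tile (per lane, in this model)."""
    return num_k_tiles * num_k_mmas * num_mmas * frag_size


# Example: 2 BK strips, 2 MMAs along K per strip, 4 MMAs along M/N, 8-element fragments.
offsets = [
    frag_offset(b, k, i, num_k_mmas=2, num_mmas=4, frag_size=8)
    for b in range(2)
    for k in range(2)
    for i in range(4)
]
print(offsets)                    # [0, 8, 16, ..., 120]: fragments are back-to-back
print(reg_tile_size(2, 2, 4, 8))  # 128
```

Iterating in nested order yields offsets with no gaps, and each BK strip spans 2 * 4 * 8 = 64 elements, so strip 1 starts at offset 64.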

Parameters

  • in_type (DType): Operand element type (bfloat16 or float8_e4m3fn).
  • mma_shape (IndexList[3]): MMA instruction shape [M, N, K].
  • num_mmas (Int): MMA tiles along the warp's M or N axis (WN/MMA_M for K).
  • num_k_mmas (Int): MMA tiles along K within a single BK strip.
  • num_k_tiles (Int): Number of BK strips across the full depth.
  • BN (Int): KV block height (needed by V load methods for SMEM offset math).
  • BK (Int): KV block width (needed by V load methods for SMEM offset math).
  • transpose_b (Bool): True for K (transposed load), False for V.
  • swizzle (Optional[Swizzle]): Optional SMEM swizzle: vector-space for prefill, element-space for decode.
  • out_type (DType): Accumulator data type (defaults to get_accum_type[in_type]()).

Fields

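  • reg_tile (TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]): The K or V operand register tile, laid out as [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size.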

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

input_frag_size

comptime input_frag_size = ((KVMmaOp[in_type, mma_shape, num_mmas, num_k_mmas, num_k_tiles, BN, BK, transpose_b, swizzle, out_type].MMA_K * KVMmaOp[in_type, mma_shape, num_mmas, num_k_mmas, num_k_tiles, BN, BK, transpose_b, swizzle, out_type].MMA_N) // WARP_SIZE)

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]
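As a worked example of the `input_frag_size` formula: each MMA's B operand is an MMA_K x MMA_N tile spread across the lanes of a warp, so each lane holds `(MMA_K * MMA_N) // WARP_SIZE` elements. The Python below is a sketch with hypothetical helper names. The `warp_size=64` used in the example is an assumption (64-lane wavefronts), since this page does not state WARP_SIZE.

```python
def mma_dims(mma_shape):
    """Unpack mma_shape [M, N, K] into (MMA_M, MMA_N, MMA_K)."""
    return mma_shape[0], mma_shape[1], mma_shape[2]


def input_frag_size(mma_shape, warp_size):
    """Per-lane B-operand fragment size: (MMA_K * MMA_N) // WARP_SIZE."""
    _, mma_n, mma_k = mma_dims(mma_shape)
    return (mma_k * mma_n) // warp_size


# warp_size=64 is an assumption; WARP_SIZE is not given on this page.
print(input_frag_size([16, 16, 128], warp_size=64))  # 32
print(input_frag_size([32, 32, 16], warp_size=64))   # 8
```

Under that assumption, an FP8 16x16x128 MMA gives 128 * 16 / 64 = 32 elements per lane. This is consistent with the 32-element joined SIMD described for `load_v_fp8_strip`, if that method runs with the same MMA shape.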

Methods

__init__

def __init__(out self)

load_prefill

def load_prefill[bk_tile: Int](self, warp_smem: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem.linear_idx_type, element_size=warp_smem.element_size])

Load the bk_tile-th BK strip of K fragments from SMEM.

Delegates to TiledMmaLoader.load_b (M-outer iteration, vector-space swizzle). Handles both BF16 (a single load per MMA tile) and FP8 (two half-K loads joined into one fragment) via the num_packs branch inside load_b.

load_prefill_split

def load_prefill_split[bk_tile: Int, has_hi: Bool](self, warp_smem_lo: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem_lo.linear_idx_type, element_size=warp_smem_lo.element_size], warp_smem_hi: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem_hi.linear_idx_type, element_size=warp_smem_hi.element_size])

Load the bk_tile-th MMA K=128 strip, composed of two adjacent WN × (MMA_K/2) SMEM blocks.

Used by FP8 16x16x128 MLA decode when the SMEM block width (bk_smem) is MMA_K/2 = 64 instead of BK = 128. The two K-halves of each MMA strip live in two separate BN × 64 SMEM blocks (matching the unpadded K layout at depth=576).

When has_hi == False, the hi half is zero-filled in registers. This handles the partial K-tile at the rope tail: for depth=576, strip 4's lo half holds depth 512..575 from block 8, the hi half has no backing block, and the MMA upper-half lane registers get 0.
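The strip-to-block mapping for the depth=576 case can be checked with a short Python model. `split_strips` is a hypothetical helper; it assumes SMEM blocks are numbered consecutively along depth, each `bk_smem` wide, so that strip s pairs blocks 2s and 2s + 1. It models only which blocks back each strip, not how values are placed in lane registers.

```python
def split_strips(depth, bk_smem, mma_k):
    """Map each MMA_K-wide strip to (lo_block, hi_block or None, has_hi).

    Assumes SMEM blocks are numbered consecutively along depth, each bk_smem
    wide, with bk_smem == mma_k // 2 so a strip spans two adjacent blocks.
    """
    assert 2 * bk_smem == mma_k
    num_blocks = -(-depth // bk_smem)  # ceiling division
    num_strips = -(-depth // mma_k)
    strips = []
    for s in range(num_strips):
        lo, hi = 2 * s, 2 * s + 1
        has_hi = hi < num_blocks  # False when the hi half has no backing block
        strips.append((lo, hi if has_hi else None, has_hi))
    return strips


strips = split_strips(depth=576, bk_smem=64, mma_k=128)
print(strips[4])                   # (8, None, False)
lo = strips[4][0]
print(lo * 64, (lo + 1) * 64 - 1)  # 512 575: depth range held by block 8
```

For depth=576 this yields five strips: strips 0 to 3 pair blocks (0, 1) through (6, 7), and strip 4 has only block 8 as its lo half, which is the has_hi == False case.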

mma_tile_at

def mma_tile_at[bk_tile: Int, kg: Int](self) -> TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Sub-view of the reg tile for a given (bk_tile, k-group) pair.

Returns:

TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
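If the register tile is modeled as a flat array with the [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size layout described at the top of this page, the sub-view for (bk_tile, kg) is a contiguous run of num_mmas fragments, indexed by i as in `mma_tile_at[bk_tile, k][i]`. The Python function `mma_tile_view` is a hypothetical illustration; the real method returns a TileTensor view rather than copying data.

```python
def mma_tile_view(reg_tile, bk_tile, kg, num_k_mmas, num_mmas, frag_size):
    """Return the num_mmas fragments for (bk_tile, kg) from a flat register tile."""
    start = (bk_tile * num_k_mmas + kg) * num_mmas * frag_size
    group = reg_tile[start:start + num_mmas * frag_size]
    # Split the k-group into per-MMA fragments so the result is indexed by i.
    return [group[i * frag_size:(i + 1) * frag_size] for i in range(num_mmas)]


reg_tile = list(range(2 * 2 * 4 * 8))  # num_k_tiles=2, num_k_mmas=2, num_mmas=4, frag=8
view = mma_tile_view(reg_tile, bk_tile=1, kg=1, num_k_mmas=2, num_mmas=4, frag_size=8)
print(view[0])  # [96, 97, ..., 103]
```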

load_v_bf16

def load_v_bf16[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED])

Load the bk_tile-th BK strip of BF16 V fragments from SMEM.

V SMEM is blocked (num_repeats × BN × BK, row-major within each block). For each (k, i) ∈ [num_k_mmas] × [num_mmas], the method builds an MMA sub-tile view with the correct (BK, 1) row stride (not (MMA_M, 1); see the smem_mma_subtile header), calls TiledMmaLoader.load_b_tr, and writes the fragment into the register slot mma_tile_at[bk_tile, k][i].

Only valid when transpose_b == False and in_type == bfloat16.
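The row-stride point can be checked with plain offset arithmetic. In the hypothetical Python model below, a block is BN x BK and row-major, so moving one row down inside an MMA sub-tile advances BK elements. A stride of MMA_M would only be right if the sub-tile were stored densely on its own. The sizes (BN=64, BK=64, MMA_M=16) are illustrative, not taken from the kernel.

```python
def smem_offset(block, row, col, BN, BK):
    """Offset of (row, col) in blocked V SMEM: [num_repeats][BN][BK], row-major per block."""
    return block * BN * BK + row * BK + col


def subtile_elem(base, r, c, row_stride):
    """Offset of element (r, c) of a sub-tile view starting at `base`."""
    return base + r * row_stride + c


BN, BK, MMA_M = 64, 64, 16  # illustrative sizes only
base = smem_offset(block=0, row=16, col=16, BN=BN, BK=BK)  # sub-tile origin (16, 16)
good = subtile_elem(base, r=1, c=0, row_stride=BK)
bad = subtile_elem(base, r=1, c=0, row_stride=MMA_M)
print(good == smem_offset(0, 17, 16, BN, BK))  # True: stride BK lands on (row 17, col 16)
print(bad == smem_offset(0, 17, 16, BN, BK))   # False: stride MMA_M lands on (row 16, col 32)
```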

load_v_fp8_strip

def load_v_fp8_strip[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int)

Load the bk_tile-th BK strip of FP8 V fragments from SMEM.

Iterates dt over the depth direction and calls TiledMmaLoader.load_v_fp8_strip per (bk_tile, dt) pair, writing the joined 32-element SIMD into mma_tile_at[bk_tile, 0][dt].

Only valid when transpose_b == False and in_type is FP8. Caller precomputes the lane-only coords (rel_key, hw_key_shift, depth_base) once before a multi-bk loop β€” they don't depend on bk_tile or dt.
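The hoisting pattern described above looks roughly like this in Python. `load_v_fp8_all`, `compute_lane_coords`, and `load_strip` are hypothetical stand-ins, and how the real kernel derives `rel_key`, `hw_key_shift`, and `depth_base` from the lane is not documented here. The sketch shows only that the lane-only coordinates are computed once and reused for every (bk_tile, dt) pair, with each fragment landing in k-group 0.

```python
def load_v_fp8_all(num_k_tiles, num_dt, lane_id, compute_lane_coords, load_strip):
    """Calling pattern: compute lane-only coords once, reuse them for every strip."""
    # Hoisted: these depend only on the lane, not on bk_tile or dt.
    rel_key, hw_key_shift, depth_base = compute_lane_coords(lane_id)
    reg = {}
    for bk_tile in range(num_k_tiles):
        for dt in range(num_dt):
            # Key (bk_tile, 0, dt) mirrors the slot mma_tile_at[bk_tile, 0][dt].
            reg[(bk_tile, 0, dt)] = load_strip(bk_tile, dt, rel_key, hw_key_shift, depth_base)
    return reg


# Usage with fake callbacks that record how often the coordinates are computed.
calls = []


def fake_coords(lane):
    calls.append(lane)
    return lane, 0, 0


def fake_load(bk_tile, dt, rel_key, hw_key_shift, depth_base):
    return [bk_tile, dt, rel_key]


reg = load_v_fp8_all(2, 3, lane_id=5, compute_lane_coords=fake_coords, load_strip=fake_load)
print(calls)     # [5]: coordinates computed once
print(len(reg))  # 6
```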

mma

def mma[swap_a_b: Bool = False](self, a: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], c: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size], bk_tile: Int, kg: Int)

Compute C += A * B using this op's register tile as the B operand; bk_tile and kg select which (bk_tile, k-group) sub-view of the register tile supplies B.
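Because the register tile is split along K into bk_tile strips and kg groups, repeated `mma` calls are expected to build up the full product by accumulating C += A_chunk @ B_chunk over K chunks. The Python below is a hypothetical model of that accumulation only. It does not model `swap_a_b`, lane-level fragments, or the hardware MMA.

```python
def matmul_acc(c, a, b):
    """c += a @ b in place (matrices as lists of rows)."""
    for i, row in enumerate(a):
        for j in range(len(b[0])):
            c[i][j] += sum(row[k] * b[k][j] for k in range(len(b)))
    return c


def chunked_gemm(a, b, chunk):
    """Accumulate C += A[:, ks] @ B[ks, :] over K chunks of width `chunk`."""
    m, n, kdim = len(a), len(b[0]), len(b)
    c = [[0] * n for _ in range(m)]
    for k0 in range(0, kdim, chunk):
        a_chunk = [row[k0:k0 + chunk] for row in a]
        b_chunk = b[k0:k0 + chunk]
        matmul_acc(c, a_chunk, b_chunk)  # one accumulation step per K chunk
    return c


a = [[1, 2, 3, 4], [0, 1, 0, 1]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]
print(chunked_gemm(a, b, chunk=2))  # [[4, 6], [0, 2]]
print(chunked_gemm(a, b, chunk=4))  # [[4, 6], [0, 2]]: same as a single full-K GEMM
```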