Mojo struct
KVMmaOp
struct KVMmaOp[in_type: DType, mma_shape: IndexList[Int(3)], num_mmas: Int, num_k_mmas: Int, num_k_tiles: Int, BN: Int, BK: Int, transpose_b: Bool = True, swizzle: Optional[Swizzle] = None, out_type: DType = get_accum_type[in_type]()]
Owns the K or V operand register tile and its SMEM→reg load logic.
Attention has two sequential GEMMs (P = Q @ K^T, O += P @ V). Instantiate
one KVMmaOp per operand role. This keeps KVBuffer focused on SMEM
storage + DMA and moves MMA-side concerns (reg layout, frag size,
fragment loads) here.
The register layout is organized as
[num_k_tiles][num_k_mmas][num_mmas] x input_frag_size: each BK strip
holds num_k_mmas * num_mmas fragments back-to-back.
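The indexing implied by this layout can be modeled in a few lines of Python. This is only a sketch of the arithmetic, assuming the three leading dimensions are row-major in the order written; `fragment_offset` and the sizes below are hypothetical and not part of the Mojo API.

```python
def fragment_offset(bk_tile, k_mma, mma, num_k_mmas, num_mmas, input_frag_size):
    """Flat element offset of fragment (bk_tile, k_mma, mma), assuming a
    row-major [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size tile."""
    frag_index = (bk_tile * num_k_mmas + k_mma) * num_mmas + mma
    return frag_index * input_frag_size


# Hypothetical sizes, chosen only for illustration.
NUM_K_TILES, NUM_K_MMAS, NUM_MMAS, FRAG = 2, 2, 4, 8

# One BK strip holds num_k_mmas * num_mmas fragments back-to-back.
strip_elems = NUM_K_MMAS * NUM_MMAS * FRAG

offsets = [
    fragment_offset(t, k, m, NUM_K_MMAS, NUM_MMAS, FRAG)
    for t in range(NUM_K_TILES)
    for k in range(NUM_K_MMAS)
    for m in range(NUM_MMAS)
]
```

Under this assumption, walking the indices in nested order visits fragments contiguously, and the second BK strip starts exactly `strip_elems` elements after the first.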
Parameters
- in_type (DType): Operand element type (bfloat16 or float8_e4m3fn).
- mma_shape (IndexList[Int(3)]): MMA instruction shape [M, N, K].
- num_mmas (Int): MMA tiles along the warp's M or N axis (WN/MMA_M for K).
- num_k_mmas (Int): MMA tiles along K within a single BK strip.
- num_k_tiles (Int): Number of BK strips across the full depth.
- BN (Int): KV block height (needed by V load methods for SMEM offset math).
- BK (Int): KV block width (needed by V load methods for SMEM offset math).
- transpose_b (Bool): True for K (transposed load), False for V.
- swizzle (Optional[Swizzle]): Optional SMEM swizzle: vector-space for prefill, element-space for decode.
- out_type (DType): Accumulator data type (defaults to accum(in_type)).
Fields
- reg_tile (TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
Implemented traits
comptime members
input_frag_size
comptime input_frag_size = (Int((mul mma_shape[Int(2)], mma_shape[Int(1)])) // _resolve_warp_size())
MMA_K
comptime MMA_K = mma_shape[Int(2)]
MMA_M
comptime MMA_M = mma_shape[Int(0)]
MMA_N
comptime MMA_N = mma_shape[Int(1)]
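As a worked example of the input_frag_size expression, the following Python sketch evaluates MMA_K * MMA_N // warp_size for the 16x16x128 FP8 shape mentioned under load_prefill_split. The warp size of 64 is an assumption made for illustration; in the struct it comes from _resolve_warp_size(), whose value depends on the target.

```python
def input_frag_size(mma_shape, warp_size):
    """Per-lane fragment size: (MMA_K * MMA_N) // warp_size."""
    mma_m, mma_n, mma_k = mma_shape  # [M, N, K], as in the mma_shape parameter
    return (mma_k * mma_n) // warp_size


# Assumed warp size of 64 lanes (hypothetical; target-dependent in practice).
ASSUMED_WARP_SIZE = 64

# 16x16x128 shape: 128 * 16 = 2048 operand elements spread over 64 lanes.
frag = input_frag_size((16, 16, 128), ASSUMED_WARP_SIZE)
print(frag)  # 32
```

Under that assumption each lane holds 32 elements, which is consistent with the 32-element SIMD described for load_v_fp8_strip.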
Methods
__init__
def __init__(out self)
load_prefill
def load_prefill[bk_tile: Int](self, warp_smem: TileTensor[in_type, Storage=warp_smem.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem.linear_idx_type, element_size=warp_smem.element_size])
Load the bk_tile-th BK strip of K fragments from SMEM.
Delegates to TiledMmaLoader.load_b (M-outer iteration,
vector-space swizzle). Handles both BF16 (single load per MMA
tile) and FP8 (two-half-K load + join) via the num_packs
branch inside load_b.
load_prefill_split
def load_prefill_split[bk_tile: Int, has_hi: Bool](self, warp_smem_lo: TileTensor[in_type, Storage=warp_smem_lo.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem_lo.linear_idx_type, element_size=warp_smem_lo.element_size], warp_smem_hi: TileTensor[in_type, Storage=warp_smem_hi.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem_hi.linear_idx_type, element_size=warp_smem_hi.element_size])
Load the bk_tile-th MMA K=128 strip composed of two adjacent WN × (MMA_K/2) SMEM blocks.
Used by FP8 16x16x128 MLA decode when the SMEM block width
(bk_smem) is MMA_K/2 = 64 instead of BK = 128. Each MMA
strip's two K-halves live in two separate BN × 64 SMEM blocks
(matching the no-pad K layout at depth=576).
When has_hi == False the hi half is register-zero; this
handles the partial K-tile at the rope tail (strip 4 for
depth=576: lo holds depth 512..575 from block 8, hi has no
backing block, MMA upper-half lane registers get 0).
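The strip-to-block pairing for depth=576 can be checked with a small Python model. It assumes strip s pairs SMEM blocks 2s (lo) and 2s+1 (hi), which matches the "two adjacent blocks" description and the strip-4 example above; `split_strip_blocks` is a hypothetical helper, not part of the Mojo API.

```python
def split_strip_blocks(depth, mma_k):
    """Return (strip, lo_block, hi_block_or_None, has_hi) for each MMA K strip,
    assuming strip s uses blocks 2*s and 2*s + 1 of width mma_k // 2."""
    half = mma_k // 2
    num_blocks = -(-depth // half)   # ceiling division: blocks of width MMA_K/2
    num_strips = -(-depth // mma_k)  # ceiling division: strips of width MMA_K
    plan = []
    for strip in range(num_strips):
        lo = 2 * strip
        hi = lo + 1
        has_hi = hi < num_blocks     # the last strip may have no hi block
        plan.append((strip, lo, hi if has_hi else None, has_hi))
    return plan


plan = split_strip_blocks(depth=576, mma_k=128)
for entry in plan:
    print(entry)
```

With depth=576 and MMA_K=128 there are nine 64-wide blocks (0 through 8) and five strips. Strips 0 through 3 have both halves; strip 4 has only block 8 as its lo half, so has_hi is False there.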
mma_tile_at
def mma_tile_at[bk_tile: Int, kg: Int](self) -> TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Sub-view of the reg tile for a given (bk_tile, k-group) pair.
Returns:
TileTensor[in_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
load_v_bf16
def load_v_bf16[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED])
Load the bk_tile-th BK strip of BF16 V fragments from SMEM.
V SMEM is blocked (num_repeats × BN × BK, row-major within each
block). For each (k, i) ∈ [num_k_mmas] × [num_mmas], build an
MMA sub-tile view with the correct (BK, 1) row stride (not
(MMA_M, 1); see smem_mma_subtile header), call
TiledMmaLoader.load_b_tr, and write the fragment into the
reg slot mma_tile_at[bk_tile, k][i].
Only valid when transpose_b == False and in_type == bfloat16.
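The point about the (BK, 1) row stride can be illustrated with a Python model of the blocked layout. This sketch only encodes "num_repeats × BN × BK, row-major within each block"; how load_v_bf16 derives the sub-tile origin from (bk_tile, k, i) is not specified here, so `row0`, `col0`, and the sizes below are hypothetical.

```python
def v_smem_offset(block, row, col, BN, BK):
    """Element offset in blocked SMEM: blocks of BN x BK, row-major inside."""
    return block * BN * BK + row * BK + col


def subtile_offsets(block, row0, col0, rows, cols, BN, BK):
    """Offsets of a rows x cols sub-tile view whose strides are (BK, 1)."""
    base = v_smem_offset(block, row0, col0, BN, BK)
    return [[base + r * BK + c for c in range(cols)] for r in range(rows)]


# Hypothetical small sizes for illustration only.
BN, BK = 8, 4
tile = subtile_offsets(block=1, row0=2, col0=2, rows=2, cols=2, BN=BN, BK=BK)
print(tile)  # [[42, 43], [46, 47]]
```

Consecutive rows of the sub-tile are BK elements apart because the view inherits the stride of the enclosing block, not the sub-tile's own height or width.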
load_v_fp8_strip
def load_v_fp8_strip[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int)
Load the bk_tile-th BK strip of FP8 V fragments from SMEM.
Iterates dt over the depth direction and calls
TiledMmaLoader.load_v_fp8_strip per (bk_tile, dt) pair,
writing the joined 32-element SIMD into
mma_tile_at[bk_tile, 0][dt].
Only valid when transpose_b == False and in_type is FP8.
Caller precomputes the lane-only coords (rel_key,
hw_key_shift, depth_base) once before a multi-bk loop;
they don't depend on bk_tile or dt.
mma
def mma[swap_a_b: Bool = False](self, a: TileTensor[Storage=a.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], c: TileTensor[Storage=c.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size], bk_tile: Int, kg: Int)
Compute C += A * B using this op's reg tile as B operand.
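Because mma takes bk_tile and kg arguments, a full GEMM is presumably built from repeated calls that each accumulate one K slice. The Python sketch below models that accumulation pattern on plain lists and shows that summing per-slice products matches a single full product. It is a numerical illustration only: the function names are hypothetical, and it ignores fragment layout, swap_a_b, and the transposition conventions of the real operands.

```python
def matmul_acc(c, a, b):
    """In-place C += A @ B for matrices stored as lists of rows."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            c[i][j] += sum(a[i][k] * b[k][j] for k in range(len(b)))


def tiled_matmul_acc(c, a, b, tile_k):
    """Accumulate C += A @ B one K slice at a time (one call per slice)."""
    for k0 in range(0, len(b), tile_k):
        a_slice = [row[k0:k0 + tile_k] for row in a]  # columns k0..k0+tile_k of A
        b_slice = b[k0:k0 + tile_k]                   # rows k0..k0+tile_k of B
        matmul_acc(c, a_slice, b_slice)


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [1, 1], [2, 0]]

c_full = [[0, 0], [0, 0]]
matmul_acc(c_full, A, B)

c_tiled = [[0, 0], [0, 0]]
tiled_matmul_acc(c_tiled, A, B, tile_k=2)
print(c_tiled)  # [[12, 5], [28, 13]]
```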