
Mojo struct

KVMmaOp

struct KVMmaOp[in_type: DType, mma_shape: IndexList[3], num_mmas: Int, num_k_mmas: Int, num_k_tiles: Int, BN: Int, BK: Int, transpose_b: Bool = True, swizzle: Optional[Swizzle] = None, out_type: DType = get_accum_type[in_type]()]

Owns the K or V operand register tile and its SMEM→reg load logic.

Attention performs two sequential GEMMs: P = Q @ K^T, then O += P @ V. Instantiate one KVMmaOp per operand role (K or V). This keeps KVBuffer focused on SMEM storage and DMA, and moves the MMA-side concerns (register layout, fragment size, fragment loads) into this struct.

The register layout is organized as [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size: each BK strip holds num_k_mmas * num_mmas fragments back-to-back.
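As a rough illustration of that ordering, the plain-Python sketch below computes the element offset of fragment [bk_tile][k][i] in a flat register tile. The function name and the sizes used are hypothetical, chosen only to show that each BK strip occupies num_k_mmas * num_mmas * input_frag_size contiguous elements.

```python
def reg_slot_offset(bk_tile, k, i, num_k_mmas, num_mmas, frag_size):
    """Element offset of fragment [bk_tile][k][i] in a flat register tile.

    Hypothetical helper: row-major over [num_k_tiles][num_k_mmas][num_mmas],
    with each slot holding frag_size (input_frag_size) elements.
    """
    return ((bk_tile * num_k_mmas + k) * num_mmas + i) * frag_size


# Illustrative sizes only (assumed, not taken from a real kernel configuration).
NUM_K_TILES, NUM_K_MMAS, NUM_MMAS, FRAG = 2, 2, 4, 8
strip = NUM_K_MMAS * NUM_MMAS * FRAG          # elements per BK strip
total = NUM_K_TILES * strip                   # elements in the whole tile

# Walking the slots in [bk_tile][k][i] order visits fragments back-to-back.
offsets = [
    reg_slot_offset(b, k, i, NUM_K_MMAS, NUM_MMAS, FRAG)
    for b in range(NUM_K_TILES)
    for k in range(NUM_K_MMAS)
    for i in range(NUM_MMAS)
]
```

Because the innermost index is i, consecutive MMA tiles of the same k-group sit next to each other, which is what lets a single (bk_tile, kg) pair map to one contiguous run of num_mmas fragments.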

Parameters

  • in_type (DType): Operand element type (bfloat16 or float8_e4m3fn).
  • mma_shape (IndexList): MMA instruction shape [M, N, K].
  • num_mmas (Int): MMA tiles along the warp's M or N axis (WN/MMA_M for K).
  • num_k_mmas (Int): MMA tiles along K within a single BK strip.
  • num_k_tiles (Int): Number of BK strips across the full depth.
  • BN (Int): KV block height (needed by V load methods for SMEM offset math).
  • BK (Int): KV block width (needed by V load methods for SMEM offset math).
  • transpose_b (Bool): True for K (transposed load), False for V.
  • swizzle (Optional): Optional SMEM swizzle — vector-space for prefill, element-space for decode.
  • out_type (DType): Accumulator data type (defaults to get_accum_type[in_type]()).

Fields

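  • reg_tile (TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]): Register tile holding this operand's MMA fragments, laid out as [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size.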

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

input_frag_size

comptime input_frag_size = (Self.MMA_K * Self.MMA_N) // WARP_SIZE

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]
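The members above reduce to simple arithmetic on mma_shape. The Python sketch below mirrors it: MMA_M, MMA_N and MMA_K are read from the [M, N, K] triple, and input_frag_size splits the MMA_K x MMA_N B-operand tile evenly across the lanes of a warp. The 64-lane warp and the example shapes are assumptions for illustration; WARP_SIZE depends on the target GPU.

```python
def mma_dims(mma_shape):
    """Return (MMA_M, MMA_N, MMA_K) from an [M, N, K] mma_shape."""
    return mma_shape[0], mma_shape[1], mma_shape[2]


def input_frag_size(mma_shape, warp_size):
    """Elements of the B operand each lane holds for one MMA tile."""
    _, mma_n, mma_k = mma_dims(mma_shape)
    # The MMA_K x MMA_N tile is distributed evenly over the warp's lanes.
    return (mma_k * mma_n) // warp_size


# Example shapes and warp size are illustrative assumptions.
bf16_frag = input_frag_size((32, 32, 16), 64)
fp8_frag = input_frag_size((32, 32, 64), 64)
```

With these assumed values, bf16_frag is 8 and fp8_frag is 32 elements per lane.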

Methods

__init__

__init__(out self)

load_prefill

load_prefill[bk_tile: Int](self, warp_smem: TileTensor[in_type, warp_smem.LayoutType, warp_smem.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_smem.linear_idx_type, element_size=warp_smem.element_size])

Load the bk_tile-th BK strip of K fragments from SMEM.

Delegates to TiledMmaLoader.load_b (M-outer iteration, vector-space swizzle). Handles both BF16 (a single load per MMA tile) and FP8 (two half-K loads joined into one fragment) via the num_packs branch inside load_b.
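To make the data movement concrete, here is a plain-Python model of filling one BK strip. It is not the Mojo implementation: fetch is a hypothetical stand-in for the SMEM read, "M-outer" is assumed to mean the MMA-tile index i is the outer loop, and num_packs=2 is assumed to model the FP8 case where two half-K pieces are concatenated into one fragment.

```python
def load_strip(reg, bk_tile, num_k_mmas, num_mmas, frag_size, fetch, num_packs=1):
    """Fill strip bk_tile of a flat register list (model only).

    fetch(k, i, part, n) is hypothetical: it returns n elements of K-part
    `part` for MMA tile (k, i). num_packs=1 models BF16, 2 models FP8.
    """
    part_len = frag_size // num_packs
    for i in range(num_mmas):              # assumed M-outer loop nesting
        for k in range(num_k_mmas):
            frag = []
            for p in range(num_packs):     # FP8 model: two half-K loads, joined
                frag += fetch(k, i, p, part_len)
            base = ((bk_tile * num_k_mmas + k) * num_mmas + i) * frag_size
            reg[base:base + frag_size] = frag   # slot [bk_tile][k][i]
    return reg


# Tag each element with (k, i, part) so the placement is easy to inspect.
tag_fetch = lambda k, i, p, n: [(k, i, p)] * n
reg = load_strip([None] * 48, 1, 2, 3, 4, tag_fetch, num_packs=2)
```

Only strip 1 (elements 24 to 47) is written; strip 0 is untouched, and each fragment holds its first half-K part followed by its second.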

mma_tile_at

mma_tile_at[bk_tile: Int, kg: Int](self) -> TileTensor[in_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Returns the sub-view of the register tile for a given (bk_tile, kg) pair, where kg is the k-group index within the BK strip.

Returns:

TileTensor
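Under the [num_k_tiles][num_k_mmas][num_mmas] x input_frag_size layout, a (bk_tile, kg) pair selects num_mmas consecutive fragments. The Python sketch below models that selection on a flat list. Note that Python slices are copies, whereas mma_tile_at returns a view into the register tile; the function name here is reused only for readability.

```python
def mma_tile_at(reg, bk_tile, kg, num_k_mmas, num_mmas, frag_size):
    """Model of the (bk_tile, kg) sub-tile: num_mmas fragments of frag_size.

    Returns copies (Python slices), unlike the real method, which returns a view.
    """
    base = (bk_tile * num_k_mmas + kg) * num_mmas * frag_size
    return [reg[base + i * frag_size: base + (i + 1) * frag_size]
            for i in range(num_mmas)]


# Illustrative sizes: 2 strips, 2 k-groups, 3 MMA tiles, 4 elements per fragment.
reg = list(range(2 * 2 * 3 * 4))
tiles = mma_tile_at(reg, 1, 1, 2, 3, 4)
```

Here tiles[i] corresponds to mma_tile_at[bk_tile, kg][i], the slot that the V load methods write into.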

load_v_bf16

load_v_bf16[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED])

Load the bk_tile-th BK strip of BF16 V fragments from SMEM.

V SMEM is blocked (num_repeats × BN × BK, row-major within each block). For each (k, i) ∈ [num_k_mmas] × [num_mmas], this method builds an MMA sub-tile view with the correct strides (BK, 1), not (MMA_M, 1) (see the smem_mma_subtile header), calls TiledMmaLoader.load_b_tr, and writes the fragment into the register slot mma_tile_at[bk_tile, k][i].

Only valid when transpose_b == False and in_type == bfloat16.
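Why the stride matters can be seen from the offset arithmetic alone. The Python sketch below assumes the depth axis is the one split into num_repeats blocks of BK columns, each block holding BN rows stored row-major; both that assumption and the helper names are illustrative. Moving down one row inside an MMA sub-tile then advances BK elements, so a view built with an MMA_M row stride would read the wrong addresses.

```python
def v_smem_offset(row, col, BN, BK):
    """Element offset in a blocked V buffer (num_repeats x BN x BK).

    Assumption: columns (depth) are split into BK-wide blocks; within a
    block, rows are stored row-major with a row stride of BK.
    """
    block, c = divmod(col, BK)
    return block * BN * BK + row * BK + c


def subtile_offset(row0, col0, r, c, BN, BK):
    """Offset of element (r, c) of a sub-tile whose origin is (row0, col0).

    The sub-tile is assumed to lie entirely inside one BK-wide block.
    """
    return v_smem_offset(row0 + r, col0 + c, BN, BK)


# Illustrative sizes (assumed): BN=64, BK=32, MMA_M=16.
BN, BK, MMA_M = 64, 32, 16
row_step = subtile_offset(0, 16, 1, 0, BN, BK) - subtile_offset(0, 16, 0, 0, BN, BK)
```

With these values row_step is 32 (BK), not 16 (MMA_M).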

load_v_fp8_strip

load_v_fp8_strip[bk_tile: Int](self, smem_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int)

Load the bk_tile-th BK strip of FP8 V fragments from SMEM.

Iterates dt over the depth direction and calls TiledMmaLoader.load_v_fp8_strip per (bk_tile, dt) pair, writing the joined 32-element SIMD into mma_tile_at[bk_tile, 0][dt].

Only valid when transpose_b == False and in_type is FP8. The caller precomputes the lane-only coordinates (rel_key, hw_key_shift, depth_base) once, before looping over multiple BK strips, because they depend on neither bk_tile nor dt.
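The calling pattern can be sketched in plain Python as follows. Everything here is a model under stated assumptions: fetch is a hypothetical stand-in for TiledMmaLoader.load_v_fp8_strip, dt is assumed to range over num_mmas, and frag_size is a parameter even though the real method writes a 32-element vector. The point is that the lane coordinates are computed once and reused for every (bk_tile, dt) pair, and each result lands in slot [bk_tile][0][dt].

```python
def load_v_fp8_strips(reg, bk_tiles, num_k_mmas, num_mmas, frag_size,
                      lane_coords, fetch):
    """Hypothetical driver loop; lane_coords is computed once by the caller."""
    rel_key, hw_key_shift, depth_base = lane_coords   # independent of bk_tile and dt
    for bk_tile in bk_tiles:
        for dt in range(num_mmas):                    # depth direction (assumed range)
            frag = fetch(bk_tile, dt, rel_key, hw_key_shift, depth_base)
            assert len(frag) == frag_size
            base = ((bk_tile * num_k_mmas + 0) * num_mmas + dt) * frag_size
            reg[base:base + frag_size] = frag          # slot [bk_tile][0][dt]
    return reg


calls = []

def tag_fetch(bk_tile, dt, *coords):
    # Records each call and tags the fragment with its (bk_tile, dt) pair.
    calls.append((bk_tile, dt, coords))
    return [(bk_tile, dt)] * 4


reg = load_v_fp8_strips([None] * 16, range(2), 1, 2, 4, (7, 1, 0), tag_fetch)
```

Every call receives the same precomputed coordinates, so hoisting them out of the loop changes nothing about the result while avoiding repeated lane-index arithmetic.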

mma

mma[swap_a_b: Bool = False](self, a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size], bk_tile: Int, kg: Int)

Compute C += A * B, using this op's register tile as the B operand.
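For checking results, a scalar reference of the arithmetic can help. The Python sketch below is only a model of the math, not of fragment distribution or of swap_a_b: with transpose_b=True (the K role) it accumulates A @ B^T, as in P = Q @ K^T, and with transpose_b=False (the V role) it accumulates A @ B, as in O += P @ V.

```python
def mma_reference(a, b, c, transpose_b):
    """Reference C += A @ B (transpose_b=False) or C += A @ B^T (transpose_b=True).

    a is M x K and c is M x N (lists of lists); b is N x K when transpose_b
    is True, else K x N. Accumulation is done in Python floats.
    """
    m, k = len(a), len(a[0])
    n = len(c[0])
    for r in range(m):
        for col in range(n):
            acc = 0.0
            for x in range(k):
                bv = b[col][x] if transpose_b else b[x][col]
                acc += a[r][x] * bv
            c[r][col] += acc
    return c


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
c_v = mma_reference(a, b, [[0.0, 0.0], [0.0, 0.0]], transpose_b=False)
c_k = mma_reference(a, b, [[0.0, 0.0], [0.0, 0.0]], transpose_b=True)
```

Here c_v is [[19, 22], [43, 50]] and c_k is [[17, 23], [39, 53]]; calling the function again on the same c adds the product a second time, matching the C += semantics.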
