Mojo struct
MhaMmaOp
struct MhaMmaOp[T: DType, config: HKMhaConfig]
Namespace-style struct holding the shape constants, register-tile layouts, and SMEM→register loaders for HKMhaPrefill. All call sites go through static methods on this struct.
Specialized for v_mfma_f32_32x32x16_bf16: the MFMA shape, SMEM sub-block geometry, and per-lane fragment decomposition are hard-coded in the struct, and the rest of the kernel (swizzle, sub-tile layouts, ds_read variants) is wired to this exact MFMA.
Parameters
- T (DType): Element data type (BF16).
- config (HKMhaConfig): Shape configuration.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
ATT_BF16_FULL_LAYOUT
comptime ATT_BF16_FULL_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 16), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 8]()
Full att tile in BF16, pre-cast from FP32 (indexed by subtile_idx to feed mma_PV one strip at a time).
ATT_BF16_SUB_LAYOUT
comptime ATT_BF16_SUB_LAYOUT = row_major[1, (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 8]()
One PV-A subtile (16-row strip of att, BF16).
ATT_LAYOUT
comptime ATT_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 32), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 16]()
Attention block (QK output, col_l rt_32x32 FP32).
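The three att layouts are meant to line up: the FP32 QK output is pre-cast to BF16 element-for-element, and the BF16 copy is consumed in 16-row strips. The Python sketch below transcribes the three row_major shapes and checks their per-lane element counts. KV_BLOCK = 64 and Q_BLOCK_SIZE = 128 are hypothetical example values, not taken from HKMhaConfig.

```python
import math

# Hypothetical example config; real values come from HKMhaConfig.
KV_BLOCK, Q_BLOCK_SIZE = 64, 128

ATT_LAYOUT = (KV_BLOCK // 32, Q_BLOCK_SIZE // 32, 16)           # FP32 QK output
ATT_BF16_FULL_LAYOUT = (KV_BLOCK // 16, Q_BLOCK_SIZE // 32, 8)  # BF16 copy
ATT_BF16_SUB_LAYOUT = (1, Q_BLOCK_SIZE // 32, 8)                # one 16-row strip

n_fp32 = math.prod(ATT_LAYOUT)
n_bf16 = math.prod(ATT_BF16_FULL_LAYOUT)
n_strip = math.prod(ATT_BF16_SUB_LAYOUT)

# The FP32 -> BF16 pre-cast is element-for-element per lane.
assert n_fp32 == n_bf16
# The full BF16 tile is KV_BLOCK // 16 strips, one per subtile_idx.
assert n_bf16 == (KV_BLOCK // 16) * n_strip
print(n_fp32, n_bf16, n_strip)  # 128 128 32
```

The equality holds for any KV_BLOCK that is a multiple of 32. Both shapes reduce to KV_BLOCK * (Q_BLOCK_SIZE // 32) / 2 elements per lane.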
DEPTH
comptime DEPTH = config.depth
FRAG_ELTS
comptime FRAG_ELTS = 8
BF16 elements per lane per MFMA base tile (MMA_M * MMA_K / 64).
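The FRAG_ELTS formula divides one MFMA operand evenly across the 64 lanes of a wavefront. The same arithmetic applied to the B operand and the 32x32 FP32 accumulator gives the trailing dimensions used by the other layouts: 8 for the BF16 operand tiles and 16 for ATT_LAYOUT and O_LAYOUT.

```python
# Per-lane sizes for v_mfma_f32_32x32x16_bf16 on a 64-lane wavefront.
MMA_M, MMA_N, MMA_K = 32, 32, 16
WAVE_LANES = 64

a_per_lane = MMA_M * MMA_K // WAVE_LANES  # BF16 A-operand elements (FRAG_ELTS)
b_per_lane = MMA_N * MMA_K // WAVE_LANES  # BF16 B-operand elements
c_per_lane = MMA_M * MMA_N // WAVE_LANES  # FP32 accumulator elements

print(a_per_lane, b_per_lane, c_per_lane)  # 8 8 16
```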
K_LAYOUT
comptime K_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 32), (MhaMmaOp[T, config].DEPTH // 16), 8]()
K register tile (whole K, pre-loaded across cluster boundaries).
K_SUB_COLS
comptime K_SUB_COLS = 32
K SMEM sub-block cols.
K_SUB_ROWS
comptime K_SUB_ROWS = 32
K SMEM sub-block rows (two-XOR swizzle).
KV_BLOCK
comptime KV_BLOCK = config.kv_block
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = 32
MMA_N
comptime MMA_N = 32
O_LAYOUT
comptime O_LAYOUT = row_major[(MhaMmaOp[T, config].DEPTH // 32), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 16]()
Output accumulator (col_l rt_32x32 FP32).
O_T_LAYOUT
comptime O_T_LAYOUT = row_major[(MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), (MhaMmaOp[T, config].DEPTH // 32), 16]()
Output transpose (row_l view of the same storage as O_LAYOUT).
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = config.q_block_size
Q_LAYOUT
comptime Q_LAYOUT = row_major[(MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), (MhaMmaOp[T, config].DEPTH // 16), 8]()
Q register tile.
ROWL_HALF_LANES
comptime ROWL_HALF_LANES = 32
Lanes per half-warp in the row_l rt_32x16 decomposition. Lanes [0, 32) each own one row; lanes [32, 64) own a shifted column block.
ROWL_STRIDE
comptime ROWL_STRIDE = 8
Per-lane fragment width in the row_l rt_32x16 decomposition (MMA_K // 2): half-warp partitioning packs two strips per base tile.
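One way to read ROWL_HALF_LANES and ROWL_STRIDE together is sketched below in Python. It assumes lane l owns row l % 32 at column block l // 32, with each block being ROWL_STRIDE columns wide. This mapping is an interpretation of the two constants, not code taken from the kernel. Under that assumption, the 64 lanes cover every element of a 32x16 base tile exactly once.

```python
ROWL_HALF_LANES = 32
ROWL_STRIDE = 8
TILE_ROWS, TILE_COLS = 32, 16  # row_l rt_32x16 base tile


def rowl_owner(lane):
    """Hypothetical (row, first_col) owned by `lane` in an rt_32x16 tile.

    Assumption: lanes [0, 32) take column block 0 and lanes [32, 64)
    take the next ROWL_STRIDE-wide column block of the same rows.
    """
    row = lane % ROWL_HALF_LANES
    col0 = (lane // ROWL_HALF_LANES) * ROWL_STRIDE
    return row, col0


# Check that 64 lanes x ROWL_STRIDE elements tile the 32x16 block with no overlap.
covered = set()
for lane in range(2 * ROWL_HALF_LANES):
    row, col0 = rowl_owner(lane)
    for c in range(col0, col0 + ROWL_STRIDE):
        assert (row, c) not in covered
        covered.add((row, c))
assert len(covered) == TILE_ROWS * TILE_COLS
```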
V_LAYOUT
comptime V_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 16), (MhaMmaOp[T, config].DEPTH // 32), 8]()
V register tile.
V_SUB_COLS
comptime V_SUB_COLS = 32
V SMEM sub-block cols.
V_SUB_ROWS
comptime V_SUB_ROWS = 8
V SMEM sub-block rows (identity swizzle).
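To show how the register-tile layouts scale with HKMhaConfig, the Python sketch below transcribes the Q, K, V, O, and O_T shapes from the members above and counts per-lane elements for one hypothetical config (q_block_size=32, kv_block=64, depth=128; these values are not from the source). For each layout, the per-lane count equals the logical tile size divided by the 64 lanes of a wavefront.

```python
import math

WAVE_LANES = 64


def register_layouts(q_block_size, kv_block, depth):
    """Per-lane shapes transcribed from the row_major[...] comptime members."""
    return {
        "Q_LAYOUT": (q_block_size // 32, depth // 16, 8),     # BF16
        "K_LAYOUT": (kv_block // 32, depth // 16, 8),         # BF16
        "V_LAYOUT": (kv_block // 16, depth // 32, 8),         # BF16
        "O_LAYOUT": (depth // 32, q_block_size // 32, 16),    # FP32
        "O_T_LAYOUT": (q_block_size // 32, depth // 32, 16),  # FP32
    }


def logical_shapes(q_block_size, kv_block, depth):
    """Logical (rows, cols) of the matrix each register tile holds."""
    return {
        "Q_LAYOUT": (q_block_size, depth),
        "K_LAYOUT": (kv_block, depth),
        "V_LAYOUT": (kv_block, depth),
        "O_LAYOUT": (depth, q_block_size),
        "O_T_LAYOUT": (q_block_size, depth),
    }


cfg = dict(q_block_size=32, kv_block=64, depth=128)  # hypothetical config
per_lane = {name: math.prod(s) for name, s in register_layouts(**cfg).items()}
for name, (rows, cols) in logical_shapes(**cfg).items():
    # Element count per lane equals the whole tile spread over 64 lanes.
    assert per_lane[name] == rows * cols // WAVE_LANES
print(per_lane)
```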
Methods
load_K
static load_K[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutExternalOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])
Loads the whole (KV_BLOCK, DEPTH) K tile from SMEM into the row_l rt_32x16 register tile, unswizzling on the way.
Caller must declare K SMEM with shape row_major[KV_BLOCK * (DEPTH / K_SUB_COLS), K_SUB_COLS] so that sub-block id linearizes via .tile[K_SUB_ROWS, K_SUB_COLS](id, 0).
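The required K SMEM shape is a tall 2-D view in which each K_SUB_ROWS x K_SUB_COLS sub-block occupies its own band of rows. The Python sketch below shows the arithmetic for a hypothetical KV_BLOCK = 64, DEPTH = 128. It covers only the shape and the row band that .tile[...](id, 0) selects. The mapping from id to a logical (row block, column block) of K, and the two-XOR swizzle inside each sub-block, are not specified here and are not modeled.

```python
K_SUB_ROWS, K_SUB_COLS = 32, 32


def k_smem_shape(kv_block, depth):
    """row_major[KV_BLOCK * (DEPTH / K_SUB_COLS), K_SUB_COLS]."""
    return (kv_block * (depth // K_SUB_COLS), K_SUB_COLS)


def sub_block_rows(block_id):
    """Rows of the 2-D SMEM view selected by .tile[K_SUB_ROWS, K_SUB_COLS](block_id, 0)."""
    start = block_id * K_SUB_ROWS
    return range(start, start + K_SUB_ROWS)


rows, cols = k_smem_shape(kv_block=64, depth=128)  # hypothetical sizes
num_blocks = rows // K_SUB_ROWS
assert rows * cols == 64 * 128                  # same footprint as the logical K tile
assert num_blocks == (64 // 32) * (128 // 32)  # one id per 32x32 sub-block
print(rows, cols, num_blocks)  # 256 32 8
```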
load_V
static load_V[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutExternalOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])
Loads the whole V tile from SMEM into the col_l rt_16x32 register tile via ds_read_tr16_b64_warp transpose-loads.
Each output base tile spans two V_SUB_ROWS-tall SMEM sub-blocks (top and bottom), which are joined into one 8-element MMA fragment.
Caller must declare V SMEM with shape row_major[KV_BLOCK * (DEPTH / V_SUB_COLS), V_SUB_COLS].
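The two-sub-block join follows from the sizes. An 8x32 V sub-block holds 256 BF16 values, which is 4 per lane (64 bits, matching the b64 in the transpose-load name). Two such sub-blocks cover the 16 rows of one col_l rt_16x32 base tile and give FRAG_ELTS = 8 values per lane. The Python sketch below checks that arithmetic. The join order, with top in elements [0, 4) and bottom in [4, 8), is an assumption for illustration.

```python
V_SUB_ROWS, V_SUB_COLS = 8, 32
WAVE_LANES = 64
FRAG_ELTS = 8
MMA_K = 16
BF16_BITS = 16

per_lane_sub = V_SUB_ROWS * V_SUB_COLS // WAVE_LANES  # BF16 values per lane per sub-block
assert per_lane_sub * BF16_BITS == 64                 # one 64-bit read per lane
assert 2 * V_SUB_ROWS == MMA_K                        # top + bottom span one base tile


def join_fragment(top, bot):
    """Hypothetical join: top sub-block fills [0, 4), bottom fills [4, 8)."""
    assert len(top) == len(bot) == per_lane_sub
    frag = list(top) + list(bot)
    assert len(frag) == FRAG_ELTS
    return frag


frag = join_fragment([0, 1, 2, 3], [4, 5, 6, 7])
print(per_lane_sub, frag)  # 4 [0, 1, 2, 3, 4, 5, 6, 7]
```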
mma_QK
static mma_QK[T_att: DType, layout_att: TensorLayout, layout_k: TensorLayout, layout_q: TensorLayout, //](mut att: TileTensor[T_att, layout_att, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut k: TileTensor[T, layout_k, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut q: TileTensor[T, layout_q, MutExternalOrigin, address_space=AddressSpace.LOCAL])
QK MFMA: att += k @ q^T. K is A (M-outer), Q is B (N-outer).
For each output base tile (n, m):
att[n, m] += sum_k k[n, k] * q[m, k].
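The Python sketch below is a scalar reference model of the loop nest described above. It models only the tiling arithmetic, not the per-lane register layout. Each (n, m, kk) step stands in for one 32x32x16 MFMA. n walks 32-row blocks of K (the A operand), m walks 32-row blocks of Q (the B operand), and kk steps through DEPTH in MMA_K chunks. The sizes and function names are hypothetical. Integer inputs keep the comparison with a direct k @ q^T exact.

```python
import random

MMA_M, MMA_N, MMA_K = 32, 32, 16


def mma_qk_reference(k, q):
    """Model of att += k @ q^T, one (n, m) output base tile at a time.

    k: KV x DEPTH, q: Q x DEPTH (lists of lists). Returns att: KV x Q.
    """
    kv, depth, qn = len(k), len(k[0]), len(q)
    att = [[0] * qn for _ in range(kv)]
    for n in range(kv // MMA_M):              # K rows -> A operand (M-outer)
        for m in range(qn // MMA_N):          # Q rows -> B operand (N-outer)
            for kk in range(depth // MMA_K):  # reduction over DEPTH
                for i in range(MMA_M):
                    r = n * MMA_M + i
                    for j in range(MMA_N):
                        c = m * MMA_N + j
                        att[r][c] += sum(
                            k[r][kk * MMA_K + t] * q[c][kk * MMA_K + t]
                            for t in range(MMA_K)
                        )
    return att


def dense_k_qt(k, q):
    """Direct k @ q^T for comparison."""
    return [[sum(a * b for a, b in zip(kr, qr)) for qr in q] for kr in k]


random.seed(0)
KV, Q, DEPTH = 64, 32, 32  # hypothetical sizes, multiples of the MFMA shape
k = [[random.randint(-3, 3) for _ in range(DEPTH)] for _ in range(KV)]
q = [[random.randint(-3, 3) for _ in range(DEPTH)] for _ in range(Q)]
att = mma_qk_reference(k, q)
assert att == dense_k_qt(k, q)
```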
mma_PV
static mma_PV[T_o: DType, layout_o: TensorLayout, layout_v: TensorLayout, layout_p: TensorLayout, //](mut o: TileTensor[T_o, layout_o, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut v: TileTensor[T, layout_v, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut p: TileTensor[T, layout_p, MutExternalOrigin, address_space=AddressSpace.LOCAL])
PV MFMA: o += v^T @ p. V is A (K-outer); P is B (K-outer, cast just-in-time to BF16 from att_block).
For each output base tile (n, m):
o[n, m] += sum_k v[k, n] * p[k, m].
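The matching scalar model for the PV step is sketched below in Python. The output o is DEPTH x Q, which is consistent with O_LAYOUT's [DEPTH // 32, Q_BLOCK_SIZE // 32, 16] shape. In this sketch the reduction index kk is the outermost loop, so each step consumes one 16-row strip of p, mirroring the strip-by-strip feed of ATT_BF16_SUB_LAYOUT. That loop order is a choice made for illustration. The result does not depend on it. Sizes and names are hypothetical.

```python
import random

MMA_M, MMA_N, MMA_K = 32, 32, 16


def mma_pv_reference(v, p):
    """Model of o += v^T @ p, one (n, m) output base tile at a time.

    v: KV x DEPTH, p: KV x Q (the att block). Returns o: DEPTH x Q.
    """
    kv, depth, qn = len(v), len(v[0]), len(p[0])
    o = [[0] * qn for _ in range(depth)]
    for kk in range(kv // MMA_K):            # one 16-row strip of p per step
        for n in range(depth // MMA_M):      # DEPTH -> output rows
            for m in range(qn // MMA_N):     # Q -> output cols
                for i in range(MMA_M):
                    r = n * MMA_M + i
                    for j in range(MMA_N):
                        c = m * MMA_N + j
                        o[r][c] += sum(
                            v[kk * MMA_K + t][r] * p[kk * MMA_K + t][c]
                            for t in range(MMA_K)
                        )
    return o


def dense_vt_p(v, p):
    """Direct v^T @ p for comparison."""
    return [
        [sum(v[t][r] * p[t][c] for t in range(len(v))) for c in range(len(p[0]))]
        for r in range(len(v[0]))
    ]


random.seed(1)
KV, DEPTH, Q = 32, 64, 32  # hypothetical sizes
v = [[random.randint(-3, 3) for _ in range(DEPTH)] for _ in range(KV)]
p = [[random.randint(-3, 3) for _ in range(Q)] for _ in range(KV)]
o = mma_pv_reference(v, p)
assert o == dense_vt_p(v, p)
```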
exp2_inplace_range
static exp2_inplace_range[T_att: DType, layout: TensorLayout, //, start: Int, end: Int](mut tile: TileTensor[T_att, layout, MutExternalOrigin, address_space=AddressSpace.LOCAL]) where T_att.is_floating_point()
In-place exp2 over a base-tile-aligned per-lane slice tile[start:end]. start and end must be multiples of the fragment width so that the slice maps to whole base tiles. Used to split the softmax exp2 into a first and a second half, spread across the PV MFMAs.
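The Python sketch below models the range contract on a flat per-lane list: reject any slice that does not cover whole base tiles, then apply 2**x in place. It uses a fragment width of 16, the FP32 accumulator elements per lane for a 32x32 tile. That value is an assumption, since the text above does not give the width for the att tile. The two calls show the first-half / second-half split. The comment between them marks where, in the kernel, PV MFMAs could be issued.

```python
FRAG = 16  # assumed per-lane FP32 elements per 32x32 base tile (32 * 32 / 64)


def exp2_inplace_range(tile, start, end, frag=FRAG):
    """In-place 2**x over tile[start:end]; bounds must be base-tile aligned."""
    if start % frag or end % frag:
        raise ValueError("slice must cover whole base tiles")
    if not 0 <= start <= end <= len(tile):
        raise ValueError("slice out of range")
    for i in range(start, end):
        tile[i] = 2.0 ** tile[i]


tile = [float(i % 4) for i in range(4 * FRAG)]  # 4 base tiles per lane
half = len(tile) // 2
exp2_inplace_range(tile, 0, half)          # first half of the softmax exp2
# ... PV MFMAs could overlap here in the real kernel ...
exp2_inplace_range(tile, half, len(tile))  # second half
```

Using exp2 rather than exp is common in attention kernels because e**x equals 2**(x * log2(e)), so the log2(e) factor can be folded into the softmax scale.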