Mojo struct

MhaMmaOp

struct MhaMmaOp[T: DType, config: HKMhaConfig]

Namespace-style struct holding the shape constants, register-tile layouts, and SMEM→register loaders for HKMhaPrefill. All call sites go through static methods on this struct.

Specialized for v_mfma_f32_32x32x16_bf16. The MFMA shape, SMEM sub-block geometry, and per-lane fragment decomposition are hard-coded inside this struct, because the rest of the kernel (swizzle, sub-tile layouts, ds_read variants) is wired to this exact MFMA.

Parameters

  • T (DType): Element data type (BF16).
  • config (HKMhaConfig): Shape configuration.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

ATT_BF16_FULL_LAYOUT

comptime ATT_BF16_FULL_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 16), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 8]()

Full att BF16 tile pre-cast from FP32 (indexed by subtile_idx to feed mma_PV strip-by-strip).

ATT_BF16_SUB_LAYOUT

comptime ATT_BF16_SUB_LAYOUT = row_major[1, (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 8]()

One PV-A subtile (16-row strip of att, BF16).

ATT_LAYOUT

comptime ATT_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 32), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 16]()

Attention block (QK output, col_l rt_32x32 FP32).

DEPTH

comptime DEPTH = config.depth

FRAG_ELTS

comptime FRAG_ELTS = 8

BF16 elements per lane per MFMA base tile: MMA_M * MMA_K / 64 = 32 * 16 / 64 = 8, where 64 is the number of lanes in a wavefront.

K_LAYOUT

comptime K_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 32), (MhaMmaOp[T, config].DEPTH // 16), 8]()

K register tile (whole K, pre-loaded across cluster boundaries).

K_SUB_COLS

comptime K_SUB_COLS = 32

K SMEM sub-block cols.

K_SUB_ROWS

comptime K_SUB_ROWS = 32

K SMEM sub-block rows (two-XOR swizzle).

KV_BLOCK

comptime KV_BLOCK = config.kv_block

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = 32

MMA_N

comptime MMA_N = 32

O_LAYOUT

comptime O_LAYOUT = row_major[(MhaMmaOp[T, config].DEPTH // 32), (MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), 16]()

Output accumulator (col_l rt_32x32 FP32).

O_T_LAYOUT

comptime O_T_LAYOUT = row_major[(MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), (MhaMmaOp[T, config].DEPTH // 32), 16]()

Output transpose (row_l view of the same storage as O_LAYOUT).

Q_BLOCK_SIZE

comptime Q_BLOCK_SIZE = config.q_block_size

Q_LAYOUT

comptime Q_LAYOUT = row_major[(MhaMmaOp[T, config].Q_BLOCK_SIZE // 32), (MhaMmaOp[T, config].DEPTH // 16), 8]()

Q register tile.

ROWL_HALF_LANES

comptime ROWL_HALF_LANES = 32

Lanes per half-warp in the row_l rt_32x16 decomposition. Lanes [0, 32) own one row, [32, 64) own a shifted col block.

ROWL_STRIDE

comptime ROWL_STRIDE = 8

Per-lane fragment width in the row_l rt_32x16 decomposition (MMA_K // 2): half-warp partitioning packs two strips per base tile.

V_LAYOUT

comptime V_LAYOUT = row_major[(MhaMmaOp[T, config].KV_BLOCK // 16), (MhaMmaOp[T, config].DEPTH // 32), 8]()

V register tile.

V_SUB_COLS

comptime V_SUB_COLS = 32

V SMEM sub-block cols.

V_SUB_ROWS

comptime V_SUB_ROWS = 8

V SMEM sub-block rows (identity swizzle).
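The layout shapes listed above follow mechanically from the three block sizes. The sketch below is a plain-Python model of that arithmetic, not Mojo kernel code. The configuration values (q_block_size=32, kv_block=64, depth=128) are hypothetical and chosen only for illustration, and the names WAVE_LANES, ACC_ELTS, and layout_shapes are introduced here. The observation that the trailing 16 of the FP32 layouts equals 32 * 32 / 64 is an inference from the shapes, not a statement from the source.

```python
# Plain-Python model of the comptime layout shapes; not Mojo kernel code.
MMA_M, MMA_N, MMA_K = 32, 32, 16
WAVE_LANES = 64  # lanes per wavefront, the divisor in the FRAG_ELTS formula


def layout_shapes(q_block_size, kv_block, depth):
    """Return the row_major shape of each register-tile layout."""
    return {
        "Q_LAYOUT": (q_block_size // 32, depth // 16, 8),
        "K_LAYOUT": (kv_block // 32, depth // 16, 8),
        "V_LAYOUT": (kv_block // 16, depth // 32, 8),
        "ATT_LAYOUT": (kv_block // 32, q_block_size // 32, 16),
        "ATT_BF16_FULL_LAYOUT": (kv_block // 16, q_block_size // 32, 8),
        "ATT_BF16_SUB_LAYOUT": (1, q_block_size // 32, 8),
        "O_LAYOUT": (depth // 32, q_block_size // 32, 16),
        "O_T_LAYOUT": (q_block_size // 32, depth // 32, 16),
    }


# BF16 operand fragment: one 32x16 base tile spread over 64 lanes.
FRAG_ELTS = MMA_M * MMA_K // WAVE_LANES
# FP32 accumulator fragment (inferred): one 32x32 base tile over 64 lanes.
ACC_ELTS = MMA_M * MMA_N // WAVE_LANES

# Hypothetical configuration, for illustration only.
shapes = layout_shapes(q_block_size=32, kv_block=64, depth=128)
for name, shape in shapes.items():
    print(name, shape)
```

The first two dimensions count base tiles and the last dimension is the per-lane fragment size, so each lane's share of a register tile is the product of all three.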

Methods

load_K

static load_K[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutExternalOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])

Loads the whole (KV_BLOCK, DEPTH) K tile from SMEM into the row_l rt_32x16 register tile, unswizzling on the way.

Caller must declare K SMEM with shape row_major[KV_BLOCK * (DEPTH / K_SUB_COLS), K_SUB_COLS] so the sub-block id linearizes via .tile[K_SUB_ROWS, K_SUB_COLS](id, 0).
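To make the required K SMEM shape concrete, here is a small plain-Python sketch of the index arithmetic, not Mojo code. It assumes that .tile[K_SUB_ROWS, K_SUB_COLS](id, 0) selects the id-th group of K_SUB_ROWS consecutive rows, which is the usual meaning of a tile index. The helper names and the example sizes (kv_block=64, depth=128) are hypothetical. How a (row block, depth block) pair maps to a sub-block id is not stated in the source and is not modeled here.

```python
K_SUB_ROWS, K_SUB_COLS = 32, 32


def k_smem_shape(kv_block, depth):
    """Shape the caller must declare: all sub-blocks stacked vertically."""
    return (kv_block * (depth // K_SUB_COLS), K_SUB_COLS)


def k_sub_block_count(kv_block, depth):
    """Number of K_SUB_ROWS x K_SUB_COLS sub-blocks in the K tile."""
    rows, _ = k_smem_shape(kv_block, depth)
    return rows // K_SUB_ROWS


def k_sub_block_rows(sub_id):
    """Row range [start, stop) assumed to be covered by tile (sub_id, 0)."""
    start = sub_id * K_SUB_ROWS
    return (start, start + K_SUB_ROWS)


# Hypothetical sizes, for illustration only.
print(k_smem_shape(64, 128))       # declared SMEM shape
print(k_sub_block_count(64, 128))  # sub-blocks addressable by id
print(k_sub_block_rows(3))         # rows of sub-block 3
```

Because the declared shape is exactly K_SUB_COLS wide, the column index passed to .tile is always 0 and the sub-block id alone selects the block.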

load_V

static load_V[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutExternalOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])

Loads the whole V tile from SMEM into the col_l rt_16x32 register tile via ds_read_tr16_b64_warp transpose-loads.

Each output base tile spans two V_SUB_ROWS-tall SMEM sub-blocks (top + bot); they are joined into one 8-element MMA fragment. Caller must declare V SMEM with shape row_major[KV_BLOCK * (DEPTH / V_SUB_COLS), V_SUB_COLS].
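The top + bot join can be checked with per-lane arithmetic. The plain-Python sketch below is not Mojo code and does not model the transpose-load itself. It assumes all 64 lanes share each sub-block evenly, and the top-then-bot element order in join_fragment is an assumption made for illustration. All names other than V_SUB_ROWS and V_SUB_COLS are hypothetical.

```python
WAVE_LANES = 64                # lanes per wavefront
V_SUB_ROWS, V_SUB_COLS = 8, 32
BASE_ROWS, BASE_COLS = 16, 32  # one col_l rt_16x32 base tile

# A 16-row base tile is built from two 8-row SMEM sub-blocks.
SUB_BLOCKS_PER_BASE_TILE = BASE_ROWS // V_SUB_ROWS
# Each 8x32 sub-block contributes this many BF16 elements to every lane
# (4 x 16 bits = 64 bits, consistent with the b64 suffix of the load).
ELTS_PER_LANE_PER_SUB_BLOCK = V_SUB_ROWS * V_SUB_COLS // WAVE_LANES
FRAGMENT_ELTS = SUB_BLOCKS_PER_BASE_TILE * ELTS_PER_LANE_PER_SUB_BLOCK


def join_fragment(top, bot):
    """Join one lane's top and bot halves (order is an assumption)."""
    if len(top) != ELTS_PER_LANE_PER_SUB_BLOCK:
        raise ValueError("top half has the wrong length")
    if len(bot) != ELTS_PER_LANE_PER_SUB_BLOCK:
        raise ValueError("bot half has the wrong length")
    return list(top) + list(bot)


fragment = join_fragment([0, 1, 2, 3], [4, 5, 6, 7])
print(SUB_BLOCKS_PER_BASE_TILE, ELTS_PER_LANE_PER_SUB_BLOCK, len(fragment))
```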

mma_QK

static mma_QK[T_att: DType, layout_att: TensorLayout, layout_k: TensorLayout, layout_q: TensorLayout, //](mut att: TileTensor[T_att, layout_att, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut k: TileTensor[T, layout_k, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut q: TileTensor[T, layout_q, MutExternalOrigin, address_space=AddressSpace.LOCAL])

QK MFMA: att += k @ q^T. K is A (M-outer), Q is B (N-outer).

For each output base tile (n, m): att[n, m] += sum_k k[n, k] * q[m, k].
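The accumulation formula can be written as a scalar reference. This plain-Python sketch is not the MFMA code path: it indexes individual elements rather than base tiles, uses nested lists in place of register tiles, and the name mma_qk_reference is hypothetical. It only illustrates the operand roles in att += k @ q^T.

```python
def mma_qk_reference(att, k, q):
    """Accumulate att[n][m] += sum_d k[n][d] * q[m][d] (att += k @ q^T).

    k has shape (kv, depth), q has shape (q_rows, depth), and att has
    shape (kv, q_rows). att is updated in place and returned.
    """
    for n in range(len(k)):
        for m in range(len(q)):
            acc = 0.0
            for d in range(len(k[n])):
                acc += k[n][d] * q[m][d]
            att[n][m] += acc
    return att


# Toy sizes: 2 KV rows, 3 Q rows, depth 2.
k = [[1.0, 2.0], [3.0, 4.0]]
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
att = [[0.0] * len(q) for _ in range(len(k))]
mma_qk_reference(att, k, q)
print(att)
```

Note that att is indexed (KV row, Q row), so it holds the transpose of the usual Q @ K^T score matrix.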

mma_PV

static mma_PV[T_o: DType, layout_o: TensorLayout, layout_v: TensorLayout, layout_p: TensorLayout, //](mut o: TileTensor[T_o, layout_o, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut v: TileTensor[T, layout_v, MutExternalOrigin, address_space=AddressSpace.LOCAL], mut p: TileTensor[T, layout_p, MutExternalOrigin, address_space=AddressSpace.LOCAL])

PV MFMA: o += v^T @ p. V is A (K-outer), P is B (K-outer, JIT-cast BF16 from att_block).

For each output base tile (n, m): o[n, m] += sum_k v[k, n] * p[k, m].
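As with mma_QK, the formula can be checked against a scalar reference. The plain-Python sketch below is not the MFMA code path: it works on individual elements rather than base tiles, omits the BF16 cast of p, and the name mma_pv_reference is hypothetical.

```python
def mma_pv_reference(o, v, p):
    """Accumulate o[n][m] += sum_k v[k][n] * p[k][m] (o += v^T @ p).

    v has shape (kv, depth), p has shape (kv, q_rows), and o has shape
    (depth, q_rows). o is updated in place and returned.
    """
    kv = len(v)
    for n in range(len(v[0])):
        for m in range(len(p[0])):
            acc = 0.0
            for kk in range(kv):
                acc += v[kk][n] * p[kk][m]
            o[n][m] += acc
    return o


# Toy sizes: 2 KV rows, depth 2, 3 Q rows.
v = [[1.0, 2.0], [3.0, 4.0]]
p = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
o = [[0.0] * len(p[0]) for _ in range(len(v[0]))]
mma_pv_reference(o, v, p)
print(o)
```

Both operands are indexed by the KV dimension first, which matches the "K-outer" description of V and P, and the result o is indexed (depth, Q row).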

exp2_inplace_range

static exp2_inplace_range[T_att: DType, layout: TensorLayout, //, start: Int, end: Int](mut tile: TileTensor[T_att, layout, MutExternalOrigin, address_space=AddressSpace.LOCAL]) where T_att.is_floating_point()

In-place exp2 over a base-tile-aligned per-lane slice tile[start:end]. start and end must be multiples of the fragment width so the slice maps to whole base tiles. Used to split the softmax exp2 into a first and a second half that are spread across PV MFMAs.
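A minimal plain-Python sketch of the slicing rule, not Mojo code. In the real method start and end are compile-time parameters, so the checks below stand in for what would be compile-time constraints. The fragment width of 16 is an assumption (it matches the trailing dimension of ATT_LAYOUT), and the 32-element per-lane list is a hypothetical stand-in for a register tile.

```python
FRAG_WIDTH = 16  # assumed per-lane elements in one FP32 base tile


def exp2_inplace_range(tile, start, end, frag_width=FRAG_WIDTH):
    """Replace tile[i] with 2**tile[i] for i in [start, end)."""
    if start % frag_width != 0 or end % frag_width != 0:
        raise ValueError("start and end must be multiples of frag_width")
    if not 0 <= start <= end <= len(tile):
        raise ValueError("range is outside the tile")
    for i in range(start, end):
        tile[i] = 2.0 ** tile[i]


# One lane's slice of a two-base-tile att block (hypothetical values).
lane = [float(-(i % 4)) for i in range(32)]

exp2_inplace_range(lane, 0, 16)   # first half
# ... a PV MFMA on the first half could be issued here ...
exp2_inplace_range(lane, 16, 32)  # second half
print(lane[:4])
```

Splitting the range this way leaves the untouched half in its pre-exp2 state until its own call, which is what allows the two halves to be scheduled around separate MFMAs.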