Mojo struct
MlaMmaOp
struct MlaMmaOp[T: DType, config: MhaConfigV2]
MLA MMA op for MlaPrefillV2 / MlaPrefillV2Core.
Extends MhaMmaOp (re-exporting its shape constants / register-tile
layouts and delegating _swizzle_K_sub / mma_QK / mma_PV /
exp2_inplace_range) with the MLA FP8 32x32x64 fragment loaders
(load_K_frag, precompute_v_lane_base, load_V_from_lane_base,
load_V_frag), which stream K/V through a rotating in-register band
(the reference never materializes the whole V tile).
The split isolates these FP8 extensions from MHA: the
MhaPrefillV2 BF16 prefill path uses MhaMmaOp directly, so an edit
to MLA's FP8 fragment loaders can no longer perturb the MHA kernel's
BF16 load_K / load_V schedule (interleaving FP8 branches into the
shared whole-tile loaders opens a K/V DMA-vs-ds_read vmcnt race
in MhaPrefillV2).
MLA never calls the whole-tile load_K / load_V; it consumes
K/V fragment-at-a-time via the *_frag loaders. The whole-tile
loaders therefore stay on MhaMmaOp only.
Parameters
- T (DType): Element data type (FP8 e4m3 for the MLA prefill path).
- config (MhaConfigV2): Shape configuration (MhaConfigV2, via MlaConfigV2.mha()).
Implemented traits
comptime members
ATT_BF16_FULL_LAYOUT
comptime ATT_BF16_FULL_LAYOUT = MhaMmaOp[T, config].ATT_BF16_FULL_LAYOUT
ATT_BF16_SUB_LAYOUT
comptime ATT_BF16_SUB_LAYOUT = MhaMmaOp[T, config].ATT_BF16_SUB_LAYOUT
ATT_LAYOUT
comptime ATT_LAYOUT = MhaMmaOp[T, config].ATT_LAYOUT
DEPTH
comptime DEPTH = MhaMmaOp[T, config].DEPTH
FP8_MMA_K_128
comptime FP8_MMA_K_128 = MhaMmaOp[T, config].FP8_MMA_K_128
FRAG_ELTS
comptime FRAG_ELTS = MhaMmaOp[T, config].FRAG_ELTS
K_LAYOUT
comptime K_LAYOUT = MhaMmaOp[T, config].K_LAYOUT
K_SUB_COLS
comptime K_SUB_COLS = MhaMmaOp[T, config].K_SUB_COLS
K_SUB_ROWS
comptime K_SUB_ROWS = MhaMmaOp[T, config].K_SUB_ROWS
KV_BLOCK
comptime KV_BLOCK = MhaMmaOp[T, config].KV_BLOCK
MMA_K
comptime MMA_K = MhaMmaOp[T, config].MMA_K
MMA_M
comptime MMA_M = MhaMmaOp[T, config].MMA_M
MMA_N
comptime MMA_N = MhaMmaOp[T, config].MMA_N
O_LAYOUT
comptime O_LAYOUT = MhaMmaOp[T, config].O_LAYOUT
O_T_LAYOUT
comptime O_T_LAYOUT = MhaMmaOp[T, config].O_T_LAYOUT
PV_A_FRAG_ELTS
comptime PV_A_FRAG_ELTS = MhaMmaOp[T, config].PV_A_FRAG_ELTS
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = MhaMmaOp[T, config].Q_BLOCK_SIZE
Q_LAYOUT
comptime Q_LAYOUT = MhaMmaOp[T, config].Q_LAYOUT
ROWL_HALF_LANES
comptime ROWL_HALF_LANES = MhaMmaOp[T, config].ROWL_HALF_LANES
ROWL_STRIDE
comptime ROWL_STRIDE = MhaMmaOp[T, config].ROWL_STRIDE
V_LAYOUT
comptime V_LAYOUT = MhaMmaOp[T, config].V_LAYOUT
V_SUB_COLS
comptime V_SUB_COLS = MhaMmaOp[T, config].V_SUB_COLS
V_SUB_ROWS
comptime V_SUB_ROWS = MhaMmaOp[T, config].V_SUB_ROWS
Methods
mma_QK
static def mma_QK[T_att: DType, layout_att: TensorLayout, layout_k: TensorLayout, layout_q: TensorLayout, //](mut att: TileTensor[T_att, layout_att, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut k: TileTensor[T, layout_k, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut q: TileTensor[T, layout_q, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
QK MFMA: delegates to MhaMmaOp.mma_QK (body identical).
mma_PV
static def mma_PV[T_o: DType, layout_o: TensorLayout, layout_v: TensorLayout, layout_p: TensorLayout, //](mut o: TileTensor[T_o, layout_o, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut v: TileTensor[T, layout_v, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut p: TileTensor[T, layout_p, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
PV MFMA: delegates to MhaMmaOp.mma_PV (body identical).
exp2_inplace_range
static def exp2_inplace_range[T_att: DType, layout: TensorLayout, //, start: Int, end: Int](mut tile: TileTensor[T_att, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])
exp2 over a slice: delegates to MhaMmaOp.exp2_inplace_range.
load_K_frag
static def load_K_frag[sub_id: Int](src: TileTensor[T, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> SIMD[T, Self.FRAG_ELTS]
Loads ONE K MFMA fragment (sub_id) from the K SMEM sub-view src and returns it as a SIMD value: the single-fragment factoring of the FP8 32x32x64 K-load inner loop.
Returning a SIMD value (rather than writing a register tile)
lets the caller stream fragments through a rotating in-register
band where each slot is a plain SSA value. This sidesteps the
strided-register-sub-view write that lands on the wrong VGPRs
when a sub-tile of a larger reg_alloc is the destination at a
non-zero offset. The fragment feeds gpu_mma directly (the QK
A-operand).
sub_id is the K SMEM sub-block index; it enters only as a
comptime ds_read offset: immediate. The two per-lane swizzled
bases are loop-invariant in src.ptr, so across a caller's
unrolled per-fragment stream LLVM CSEs them to a single base
register pair (the reference's v226 single-base K pattern).
Constraints:
FP8 32x32x64 path only (MMA_K == 64).
Returns:
SIMD[T, Self.FRAG_ELTS]
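The streaming pattern described for load_K_frag can be modeled outside the kernel. The Python sketch below is an illustration only: `stream_through_band`, `load_frag`, `mma`, and the band depth of 2 are hypothetical stand-ins and are not part of MlaMmaOp. It shows fragments passing through a fixed number of slots, each holding a plain value, so the full tile is never resident at once.

```python
def stream_through_band(load_frag, mma, num_frags, band_slots=2):
    """Model of streaming fragments through a rotating band of slots.

    load_frag(i) stands in for load_K_frag[sub_id=i](src) and mma(acc, frag)
    for the MFMA that consumes it; both are hypothetical placeholders.
    Returns (accumulator, peak number of fragments held at once).
    """
    band = [None] * band_slots          # each slot is a plain value, not a tile view
    acc, peak = 0, 0
    # Prologue: fill the band with the first fragments.
    for i in range(min(band_slots, num_frags)):
        band[i % band_slots] = load_frag(i)
    for i in range(num_frags):
        slot = i % band_slots
        peak = max(peak, sum(f is not None for f in band))
        acc = mma(acc, band[slot])      # consume the oldest fragment
        nxt = i + band_slots
        # Refill the slot just freed, or retire it when the stream is done.
        band[slot] = load_frag(nxt) if nxt < num_frags else None
    return acc, peak


frags = [3, 1, 4, 1, 5, 9, 2, 6]        # toy "fragments"
acc, peak = stream_through_band(lambda i: frags[i], lambda a, f: a + f, len(frags))
```

In this model the peak number of live fragments is bounded by the band depth regardless of how many fragments the stream contains, which is the property the fragment-returning loaders are meant to give the caller.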
precompute_v_lane_base
static def precompute_v_lane_base[origin: Origin[mut=origin.mut], //, v_full_v227: Bool = False, v227_layout: Bool = False](v_slot_ptr: UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]) -> UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]
Computes the per-lane V LDS base pointer for the FP8 32x32x64 path ONCE per V SMEM slot (caller passes v_smem_<stage>.ptr). Hoisting this base out of the per-fragment readout collapses the per-call lane offset across all invocations from the same slot into ONE shared materialization, mirroring the reference's v227 base carried across the entire main loop. The caller threads the returned pointer into load_V_frag. Origin is propagated from the input slot so the returned pointer carries the slot's lifetime/mutability annotation.
Parameters:
- v_full_v227 (Bool): Reference v227 V adapter per-lane READ base. Default False: byte-identical. When True, replaces the per-lane base ENTIRELY with the reference's exact v227 formula (the WHOLE read map), a per-tr8-cycle bank-quadrant bijection that eliminates the V-transpose LDS bank conflict. This is HALF of the adapter R: the caller MUST also pass v_full_v227=True to load_V_frag (the faithful readout cell) AND run the matching producer (SubTileLoaderLDS_st_8x32[v_full_v227=True] / MlaPrefillV2Core._dma_v[v_full_v227=True], the W); the three compose to the standard PV fragment (numerically equivalent). FP8 32x32x64 only. The default-on V LDS adapter for MlaPrefillV2; production MLA passes v_full_v227=False. -D v_full_v227.
- v227_layout (Bool): Spell the v_full_v227 per-lane READ base via CuTe Layout Algebra (crd2idx over a per-bit Coord) instead of the hand-rolled runtime bit arithmetic (only consulted when v_full_v227 is True). SAME mapping, different spelling: the v227 base is bit-LINEAR over the bit-decomposed lane, so it is crd2idx(lane, Coord(2,2,2,2,2,2), Coord(8, 0x80, 0x820, 0x1040, 16, 0x410)) (the 2-bit field ((v0>>2)&3) * 0x820 splits to bit2*0x820 + bit3*0x1040). Mirrors the WRITE side (SubTileLoaderLDS_st_8x32[_v227_layout]); this is the no-underscore parameter form of the WRITE side's struct field _v227_layout, both driven by -D v227_layout. Numerically equivalent. -D v227_layout.
Returns:
UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]
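The bit-linear mapping named in the v227_layout description can be checked with a small plain-Python model. This is a sketch, not the kernel's code: `crd2idx_bits`, `two_bit_field`, and `bits_2_and_3` are hypothetical helper names, the assignment of strides to lane bits in order (bit 0 first) is an assumption, and a 64-lane wavefront is assumed. For a shape of all 2s, crd2idx reduces to the inner product of the lane's bits with the strides.

```python
# Per-bit strides as listed in the v227_layout description. Assumption: the
# k-th stride belongs to bit k of the lane index.
STRIDES = (8, 0x80, 0x820, 0x1040, 16, 0x410)


def crd2idx_bits(lane, strides=STRIDES):
    """Bit-linear map: add strides[k] for every set bit k of the lane index."""
    return sum(((lane >> k) & 1) * s for k, s in enumerate(strides))


def two_bit_field(lane):
    """Hand-rolled spelling of lane bits 2..3 as one 2-bit field."""
    return ((lane >> 2) & 3) * 0x820


def bits_2_and_3(lane):
    """The same two bits spelled per bit: bit2*0x820 + bit3*0x1040."""
    return ((lane >> 2) & 1) * 0x820 + ((lane >> 3) & 1) * 0x1040


# Offsets for every lane of an assumed 64-lane wavefront.
offsets = [crd2idx_bits(lane) for lane in range(64)]
```

Because 0x1040 is exactly 2 * 0x820, the 2-bit field and the per-bit form agree for every lane, which is why the two spellings describe the same mapping.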
load_V_from_lane_base
static def load_V_from_lane_base[layout_dst: TensorLayout](mut dst: TileTensor[T, layout_dst, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], v_lane_base: UnsafePointer[Scalar[T], address_space=AddressSpace.SHARED])
FP8 32x32x64 V load consuming a pre-computed per-lane base pointer (hoisted by the caller). Each call site adds ONLY comptime per-cell offsets to v_lane_base: no per-call per-lane base recomputation, no src.ptr indirection.
Collapses the 27 distinct V ds_read base VGPRs at KV=128 (one
per V-load × slot toggle) toward the reference's ~3-base pattern
(one per SMEM slot, carried across all calls from that slot).
Codegen-equivalent to the in-call recomputation
on the K=128 fast path (the AMDGPU backend recovers the same
ds_read offset:imm) but makes the kernel-level hoist explicit at
the source level.
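The hoist described above can be illustrated with a toy address model in Python. Everything here is hypothetical: `lane_offset` is a placeholder formula (the real per-lane base is not given on this page), and the slot pointers and cell offsets are made-up values. The sketch only shows that computing the base once per slot yields the same read addresses as recomputing it on every call, with fewer base evaluations.

```python
def lane_offset(lane):
    """Placeholder per-lane offset; NOT the kernel's actual formula."""
    return lane * 8


def addresses_recomputed(slot_ptrs, lane, cell_offsets):
    """Every read re-derives the per-lane base from the slot pointer."""
    evals, addrs = 0, []
    for ptr in slot_ptrs:
        for off in cell_offsets:
            base = ptr + lane_offset(lane)   # recomputed per call
            evals += 1
            addrs.append(base + off)
    return addrs, evals


def addresses_hoisted(slot_ptrs, lane, cell_offsets):
    """The per-lane base is computed once per slot, then only constants are added."""
    evals, addrs = 0, []
    for ptr in slot_ptrs:
        base = ptr + lane_offset(lane)       # hoisted: once per slot
        evals += 1
        for off in cell_offsets:
            addrs.append(base + off)         # constant per-cell offset only
    return addrs, evals


slots = [0x0000, 0x4000, 0x8000]             # made-up slot pointers
cells = [0x000, 0x100, 0x200, 0x300]         # made-up per-cell offsets
a1, n1 = addresses_recomputed(slots, 5, cells)
a2, n2 = addresses_hoisted(slots, 5, cells)
```

In the model the hoisted form evaluates the base once per slot rather than once per read, which is the source-level counterpart of carrying one base register per SMEM slot.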
load_V_frag
static def load_V_frag[i_strip: Int, j_depth: Int, v_full_v227: Bool = False](v_lane_base: UnsafePointer[Scalar[T], address_space=AddressSpace.SHARED]) -> SIMD[T, Self.FRAG_ELTS]
Loads ONE V MFMA fragment (i_strip, j_depth) from the pre-computed per-lane V LDS base and returns it as a SIMD value: the single-fragment factoring of load_V_from_lane_base's FP8 32x32x64 inner (i, j) cell (the 4 paired-lane ds_read_tr8_b64 joined into one SIMD[FP8, 32]).
Returning a SIMD value (rather than writing a register tile) is
the V counterpart of load_K_frag: it lets the caller stream V
fragments through the SAME rotating in-register band K streamed
through (K and V have disjoint lifetimes: K is consumed in the QK
MFMAs before softmax, V in the PV MFMAs after, so the band's
registers are free for V by PV time). Each slot is then a plain
SSA value, which sidesteps the strided-register-sub-view write that
lands on the wrong VGPRs at a non-zero offset and avoids
materializing the whole 64-VGPR V_LAYOUT tile held live across
the softmax/exp/FP8 clusters (the reference never materializes V;
it transpose-reads V fragment-at-a-time through its reused
v[28:59] band). The fragment feeds mma_PV / gpu_mma
directly (the PV A-operand).
i_strip is the K-direction base tile (i_strip*MMA_K keys per
strip; 0 .. KV_BLOCK/MMA_K); j_depth is the depth tile
(j_depth*MMA_N cols; 0 .. DEPTH/MMA_N). Both enter only as
comptime ds_read_tr8_b64 offset: immediates, so across a
caller's unrolled per-fragment stream LLVM CSEs v_lane_base to
a single base-register set (the reference's v227 single-base V
pattern), the same hoist load_V_from_lane_base already relies
on.
Constraints:
FP8 32x32x64 path only (MMA_K == 64).
Parameters:
- i_strip (Int): K-direction base tile index.
- j_depth (Int): Depth tile index.
- v_full_v227 (Bool): Reference v227 V adapter READ cell. Default False: OUR st_8x32 cell decode (byte-identical). When True, uses the faithful reference readout cell offset i_strip*0x2080 + j_depth*0x20 + r*0x100 (r = subread 0..3) on top of the v227 per-lane base (precompute_v_lane_base[v_full_v227=True]). This is the R of the adapter W/R pair; the producer (SubTileLoaderLDS_st_8x32[v_full_v227=True] / MlaPrefillV2Core._dma_v[v_full_v227=True]) MUST set v_full_v227=True too, or V scrambles. The two compose to the standard PV fragment, bank-conflict-free (v227's per-tr8-cycle bank-quadrant bijection). FP8 32x32x64 only.
Returns:
SIMD[T, Self.FRAG_ELTS]
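The v_full_v227=True readout cell offset given above, i_strip*0x2080 + j_depth*0x20 + r*0x100, can be tabulated with a short Python sketch. The helper name `v_cell_offsets` is hypothetical, and the shape values (KV_BLOCK = 128, DEPTH = 128, MMA_N = 32) are assumptions chosen for illustration; only MMA_K == 64 comes from the stated constraint.

```python
def v_cell_offsets(i_strip, j_depth, subreads=4):
    """Offsets added to the per-lane base for one (i_strip, j_depth) fragment:
    i_strip*0x2080 + j_depth*0x20 + r*0x100 for subread r = 0..3."""
    return [i_strip * 0x2080 + j_depth * 0x20 + r * 0x100 for r in range(subreads)]


# Assumed shape values for illustration only (not read from a real config).
KV_BLOCK, DEPTH, MMA_K, MMA_N = 128, 128, 64, 32

table = {
    (i, j): v_cell_offsets(i, j)
    for i in range(KV_BLOCK // MMA_K)    # i_strip: 0 .. KV_BLOCK/MMA_K
    for j in range(DEPTH // MMA_N)       # j_depth: 0 .. DEPTH/MMA_N
}

for (i, j), offs in table.items():
    print(i, j, [hex(o) for o in offs])
```

Under these assumed shapes the three terms occupy non-overlapping bit ranges, so every (i_strip, j_depth, r) triple maps to a distinct constant offset, consistent with the offsets being usable as compile-time immediates on a shared base.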