
Mojo module

mla_prefill_v2

MlaPrefillV2: a fresh, from-scratch port of the reference MLA-prefill INTEGRATED inner-loop architecture for AMD MI355X (gfx950).

This is a NEW kernel struct (sibling of MlaPrefillV2Core, not a retrofit) that lays out the reference mla_pfl_qh192_vh128_m32x8_n128x1 register structure directly:

  • 1 wave / EU (llvm.amdgpu-waves-per-eu = "1,1"): the 256-VGPR budget is what makes the reference footprint fit; the default 2-wave cap (128 VGPR/wave) spills the FP32 score + O accumulators.
  • Resident Q β€” loaded once per work-tile, held in registers across every KV block (24 VGPR for FP8 d_qk=192).
  • Single 64-VGPR FP32 score tile that doubles as the QK MFMA accumulator. Softmax runs in place on it, and the FP8 P operand collapses 4:1 IN PLACE into the score tile's own low quarter (v_cvt_pk_fp8_f32 op_sel:[0,0,1]). There is no separate p_block and no score double-buffer.
  • 64-VGPR FP32 O accumulator, eagerly VALU-rescaled per the online softmax recurrence.
  • Streamed K band; streamed V band. K is streamed from LDS into a function-local band, consumed by the QK MFMAs, then freed. V is then streamed fragment-at-a-time through the band K vacates (disjoint lifetimes), following the reference lean layout, rather than being materialized as a whole register tile.
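The register counts in the list above follow from the tile shape. The sketch below redoes that arithmetic under stated assumptions: wave64, and a per-wave tile of 32 query rows by a 128-key KV block with d_qk = 192 and d_v = 128, read off the kernel name mla_pfl_qh192_vh128_m32x8_n128x1 rather than confirmed by the source. The helper vgprs_per_lane is hypothetical.

```python
# Sketch: per-lane VGPR cost of the register-resident tiles.
# ASSUMED shape (read off the kernel name, not confirmed by the source):
# 32 query rows per wave, 128-key KV block, d_qk = 192, d_v = 128, wave64.
WAVE_LANES = 64      # lanes per wave (wave64)
BYTES_PER_VGPR = 4   # each VGPR holds 32 bits per lane


def vgprs_per_lane(rows: int, cols: int, bytes_per_elem: int) -> int:
    """VGPRs each lane needs when a rows x cols tile is split evenly over one wave."""
    total_bytes = rows * cols * bytes_per_elem
    per_wave_vgpr_bytes = WAVE_LANES * BYTES_PER_VGPR
    if total_bytes % per_wave_vgpr_bytes:
        raise ValueError("tile does not divide evenly into whole VGPRs")
    return total_bytes // per_wave_vgpr_bytes


M, N_KV, D_QK, D_V = 32, 128, 192, 128

q_fp8 = vgprs_per_lane(M, D_QK, 1)   # resident Q (FP8)
s_fp32 = vgprs_per_lane(M, N_KV, 4)  # FP32 score tile = QK MFMA accumulator
p_fp8 = vgprs_per_lane(M, N_KV, 1)   # FP8 P after the 4:1 in-place collapse
o_fp32 = vgprs_per_lane(M, D_V, 4)   # FP32 O accumulator

print(q_fp8, s_fp32, p_fp8, o_fp32)  # 24 64 16 64
# P fits in exactly one quarter of the score tile's registers.
print(p_fp8 * 4 == s_fp32)           # True
# Score + O alone consume the whole 128-VGPR budget of the 2-wave cap,
# before Q, the K/V band, and addressing registers are counted.
print(s_fp32 + o_fp32, s_fp32 + o_fp32 + q_fp8)  # 128 152
```

Under these assumptions the numbers line up with the list (24 VGPRs for Q, 64 for the score tile, 64 for O), and the FP32 score and O accumulators alone already fill the 128 VGPRs that a 2-wave cap allows.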

The reference-exact inner loop

_attend_exact lays each KV tile out as 6 barrier-delimited clusters (7 bare _s_barrier_raw calls) matching the reference's label_01D6 boundaries: C_QK -> C_V_PREFETCH -> C_SOFTMAX_MAX -> C_EXP/rescale -> C_FP8_PACK -> C_PV. Each cluster comment cites the reference asm line it mirrors. The two warp-groups (waves 0-3 and waves 4-7) run that body phase-shifted via an asymmetric +4 prologue _s_barrier_raw() stagger, with a work-split K/V DMA (waves 0-3 produce K, waves 4-7 produce V) into two disjoint LDS ring regions (K depth-2, V depth-4; the reference V region is the wider of the two). The shared math is reproduced in-file without editing any shared file.
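The work split and the two ring depths can be pictured as modular slot indexing. The sketch below assumes tile t lives in slot t % depth of its ring; the helper names are hypothetical, not the kernel's API. One plausible reading of why V gets the deeper ring is that V tiles stay live longer (they are consumed in C_PV at the end of the tile, by two phase-shifted warp-groups); that is an inference, not something the text states.

```python
# Sketch of the K/V work split and LDS ring indexing (hypothetical helpers).
K_RING_DEPTH = 2  # K region: depth-2 ring
V_RING_DEPTH = 4  # V region: depth-4 ring, the wider of the two


def producer_role(wave_id: int) -> str:
    """Which DMA a wave issues: waves 0-3 produce K, waves 4-7 produce V."""
    if not 0 <= wave_id < 8:
        raise ValueError("expected wave IDs 0-7")
    return "K" if wave_id < 4 else "V"


def ring_slot(tile: int, depth: int) -> int:
    """LDS slot holding KV tile `tile` in a ring of `depth` slots (assumed modular)."""
    return tile % depth


def live_slots(first_tile: int, count: int, depth: int) -> list[int]:
    """Slots occupied by `count` consecutive tiles starting at `first_tile`."""
    return [ring_slot(t, depth) for t in range(first_tile, first_tile + count)]


# The next-tile prefetch writes tile t + 1 while tile t is still being read;
# in a depth-2 ring those always land in different slots.
for t in range(16):
    assert ring_slot(t + 1, K_RING_DEPTH) != ring_slot(t, K_RING_DEPTH)

# A depth-4 ring keeps four consecutive V tiles resident without aliasing.
print(live_slots(5, 4, V_RING_DEPTH))               # [1, 2, 3, 0]
print("".join(producer_role(w) for w in range(8)))  # KKKKVVVV
```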

The prologue stagger and the QK-tail V prefetch are controlled by two -D flags (see the prologue keystone and tail-compensation comments in _attend_exact):

  • -D exact_stagger (defaults to on for persistent launches): the EXACT reference two-half-body discipline. The upper half pays +4 barriers at the prologue and the lower half pays a matching +4 at the work-item TAIL (the reference label_06B4 / label_1A51), on EVERY work-item. Per-work-item barrier totals stay EQUAL (no +4N accumulation, no deadlock), while the +4 phase skew RE-FORMS on each work-item, giving a steady skew conserved across the CU's whole work stream. When off (the static-grid default), the +4 fires only on wi0 and work-items 1..N run in lockstep (a no-op at one work-item per CU).
  • -D v_qktail (defaults to on for non-persistent launches): prefetches the first V band fragments into the QK tail. This wins where registers have headroom (static grid, or batch > 1); it is disabled under persistent launches because the V band, held live across softmax, spills at the 256-VGPR ceiling there.
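A bare s_barrier releases only once every wave in the workgroup has arrived, so barriers pair up by count and a phase skew is simply a difference in how many barriers each half has issued. The toy model below (hypothetical, not kernel code) tracks that difference per work-item. The 7 barriers per tile and the +4 stagger come from the text; the tile count is arbitrary, and the assumption that the non-exact mode also compensates once, at wi0's tail, is ours, inferred from the statement that work-items 1..N run in lockstep.

```python
# Toy model of the prologue stagger / tail compensation (not kernel code).
BARRIERS_PER_TILE = 7  # bare _s_barrier_raw per KV tile
STAGGER = 4            # asymmetric prologue stagger


def stagger_skews(work_items: int, tiles: int, mode: str) -> tuple[list[int], int]:
    """Return (skew at the start of each work-item body, final count mismatch).

    Skew = barriers issued so far by the upper half minus the lower half.
      "exact": upper pays +4 at every prologue, lower pays +4 at every tail.
      "wi0":   the same pair, but only on work-item 0 (ASSUMED compensation).
      "naive": upper pays +4 at every prologue and nobody compensates.
    """
    if mode not in ("exact", "wi0", "naive"):
        raise ValueError(f"unknown mode {mode!r}")
    body = tiles * BARRIERS_PER_TILE
    upper = lower = 0
    skews = []
    for wi in range(work_items):
        active = mode != "wi0" or wi == 0
        if active:
            upper += STAGGER  # prologue stagger (upper half)
        skews.append(upper - lower)
        upper += body  # both halves run the same body
        lower += body
        if active and mode != "naive":
            lower += STAGGER  # tail compensation (lower half)
    return skews, upper - lower


print(stagger_skews(3, 2, "exact"))  # ([4, 4, 4], 0)
print(stagger_skews(3, 2, "wi0"))    # ([4, 0, 0], 0)
print(stagger_skews(3, 2, "naive"))  # ([4, 8, 12], 12)
```

In this model "exact" re-forms a constant skew of 4 at every work-item body with matched totals, "wi0" collapses to lockstep after the first work-item, and the uncompensated variant shows the +4N accumulation the tail compensation exists to prevent.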

The cadence levers that reproduce the reference instruction schedule are all unconditional. Each is pinned by a mask-0 schedule_barrier, which fixes a hand-specified order at codegen, because the IGLP solver would otherwise re-cluster plain program order:

  • Non-materialized V band (lean ~210-VGPR layout): V is streamed fragment-at-a-time through the band K vacates (disjoint lifetimes).
  • 4-slot rotating V band in C_PV (3 slots in flight): V reads land ahead of their consuming PV MFMA so the per-MFMA drain is soft lgkmcnt(8) (the reference load/MFMA C_PV cadence, ref asm L744-785).
  • 4-ahead K ring in C_QK, pinned by the mask-0 fence; it breaks the K-band write-after-read (WAR) dependency so the drain approaches the reference soft lgkmcnt(4).
  • Next-tile K/V prefetch issued in C_QK (ref asm L356-410).
  • Resident Q staged DRAM->LDS->VGPR (the reference Q@0x0 region).
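The soft lgkmcnt(8) drain in C_PV can be reasoned about as a count of reads that are allowed to stay outstanding. The toy model below rests on two assumptions: LDS reads issued by a wave complete in issue order (which is what makes a partial lgkmcnt wait sufficient for the oldest fragment), and each V fragment costs 4 reads, a hypothetical figure picked because 2 newer fragments times 4 reads reproduces the quoted lgkmcnt(8). The function name pv_waits is illustrative.

```python
# Toy model of the rotating V band in C_PV (not kernel code).
SLOTS = 4           # rotating V band slots
AHEAD = 3           # fragments in flight
READS_PER_FRAG = 4  # ASSUMPTION: LDS reads per V fragment


def pv_waits(num_frags: int) -> list[int]:
    """lgkmcnt threshold placed before each PV MFMA.

    Per fragment i, in program order:
      1. wait until at most `threshold` LDS reads are outstanding
      2. MFMA consumes the fragment in slot i % SLOTS
      3. issue the reads for fragment i + AHEAD into slot (i + AHEAD) % SLOTS
    With in-order completion, leaving only newer reads outstanding
    guarantees fragment i has landed.
    """
    issued = min(AHEAD, num_frags)  # prologue: first AHEAD fragments in flight
    waits = []
    for i in range(num_frags):
        newer_frags = issued - (i + 1)  # fragments issued after fragment i
        waits.append(newer_frags * READS_PER_FRAG)
        nxt = i + AHEAD
        if nxt < num_frags:
            # The refilled slot is the one MFMA i - 1 read, never the one
            # MFMA i is reading: the fourth slot provides that slack.
            assert nxt % SLOTS == (i - 1) % SLOTS != i % SLOTS
            issued += 1
    return waits


print(pv_waits(8))  # [8, 8, 8, 8, 8, 8, 4, 0]
```

In this model the fourth slot is what lets each refill target the slot the previous MFMA read rather than the one the current MFMA is still reading; the threshold only tightens in the last two iterations, when fewer newer fragments remain.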

Correctness strategy: reuse the verified MLA-prefill math

The MLA-prefill MATH (QK with nope d=128 + rope d=64; FlashAttention-2 online softmax with running max/sum and cross-tile rescale; in-place FP8 P collapse; PV accumulate; normalize + store; causal / null mask) lives in mla_components.mojo's MlaPrefillV2Core[config] (FP32-scores path), trimmed to exactly the closure this kernel consumes. Rather than re-deriving it, _attend_exact reuses those BARRIER-FREE numeric primitives (the OnlineSoftmax recurrence, MhaMmaOp MFMA/exp/cast helpers, _qk_collapse_inplace, the K/V LDS loaders, _store_o_to_gmem) and _MlaKDmaPair for the K DMA. It does, however, emit the cluster cadence and the QK/PV MFMA streams in-file, so the reference's bare s_barrier boundaries are NOT fragmented by the delegated helpers' own lgkmcnt(0) drains or IGLP fences. This file owns (a) the waves_per_eu=1,1 kernel entry, (b) the single reference-faithful inner loop, and (c) the host launcher.
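As a numeric reference for the recurrence listed above, here is a NumPy sketch of the FlashAttention-2 online softmax over 128-key blocks, with the score computed as the split nope + rope dot product and O rescaled eagerly on every block. It is a float64 sketch of the math only: no FP8 P quantization, no causal or null mask, no wave tiling. The function name and the 1/sqrt(d_qk) scale are assumptions, not taken from mla_components.mojo.

```python
import numpy as np


def online_softmax_attention(q_nope, q_rope, k_nope, k_rope, v, block=128):
    """Single-head attention, one KV block at a time, FA-2 style (float64 math).

    q_*: (M, d), k_*: (N, d), v: (N, d_v). The split nope + rope dot product
    equals one dot product over the concatenated d_qk = d_nope + d_rope.
    """
    m = q_nope.shape[0]
    scale = 1.0 / np.sqrt(q_nope.shape[1] + q_rope.shape[1])  # ASSUMED 1/sqrt(d_qk)
    run_max = np.full(m, -np.inf)        # running row max
    run_sum = np.zeros(m)                # running row sum of exp
    o_acc = np.zeros((m, v.shape[1]))    # unnormalized O accumulator
    for start in range(0, k_nope.shape[0], block):
        stop = start + block
        s = (q_nope @ k_nope[start:stop].T + q_rope @ k_rope[start:stop].T) * scale
        new_max = np.maximum(run_max, s.max(axis=1))
        alpha = np.exp(run_max - new_max)  # cross-tile rescale (0 on the first block)
        p = np.exp(s - new_max[:, None])   # unnormalized probabilities
        run_sum = alpha * run_sum + p.sum(axis=1)
        o_acc = alpha[:, None] * o_acc + p @ v[start:stop]  # eager rescale + PV
        run_max = new_max
    return o_acc / run_sum[:, None]        # final normalize


rng = np.random.default_rng(0)
M, N = 32, 384
q_nope, q_rope = rng.standard_normal((M, 128)), rng.standard_normal((M, 64))
k_nope, k_rope = rng.standard_normal((N, 128)), rng.standard_normal((N, 64))
v = rng.standard_normal((N, 128))

out = online_softmax_attention(q_nope, q_rope, k_nope, k_rope, v)

# Dense reference over the full score matrix.
q = np.concatenate([q_nope, q_rope], axis=1)
k = np.concatenate([k_nope, k_rope], axis=1)
s = q @ k.T / np.sqrt(192)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v
print(np.allclose(out, ref))  # True
```

Because alpha is exp(-inf) = 0 on the first block, the recurrence needs no special case to start, and the block size changes only the grouping, not the result.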

Because the _FP32_SOFTMAX_SCORES gate (FP8, KV >= 128, and the 32x32x64 MFMA shape) is True by default for the FP8 KV=128 target shape, every reused primitive exercises the exact codegen this kernel ships.

Structs

  • MlaPrefillV2: Fresh single-schedule port of the reference integrated MLA-prefill inner loop for gfx950. 1 wave / EU; resident Q; single FP32 score tile with in-place FP8 P collapse; shared K/V band; 64-VGPR FP32 O accumulator with eager rescale; reference work-split K/V DMA and deep even wave-spec stagger over a 160 KB LDS.

Functions