Mojo module
mla_prefill_v2
MlaPrefillV2: a fresh, from-scratch port of the reference MLA-prefill INTEGRATED inner-loop architecture for AMD MI355X (gfx950).
This is a NEW kernel struct (sibling of MlaPrefillV2Core, not a retrofit)
that lays out the reference mla_pfl_qh192_vh128_m32x8_n128x1 register
structure directly:
- 1 wave / EU (llvm.amdgpu-waves-per-eu = "1,1"): the 256-VGPR budget is what makes the reference footprint fit; the default 2-wave cap (128 VGPR/wave) spills the FP32 score + O accumulators.
- Resident Q: loaded once per work-tile, held in registers across every KV block (24 VGPR for FP8 d_qk=192).
- Single 64-VGPR FP32 score tile = the QK MFMA accumulator; softmax runs in place on it; the FP8 P operand collapses 4:1 IN PLACE into the score tile's own low quarter (v_cvt_pk_fp8_f32 op_sel:[0,0,1]). No separate p_block, no score double-buffer.
- 64-VGPR FP32 O accumulator, eagerly VALU-rescaled per the online softmax recurrence.
- Streamed K band; streamed V band. K is streamed from LDS into a function-local band, consumed by the QK MFMAs, then freed. V is then streamed fragment-at-a-time through the band K vacates (disjoint lifetimes), which is the reference lean layout, rather than materialized as a whole register tile.
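The register counts in the list above can be cross-checked with a little arithmetic. The sketch below is a back-of-envelope check, not part of the kernel: it assumes a 64-lane wave, 32-bit VGPRs, and a per-wave tile of 32 query rows by 128 KV columns, a reading of the m32 and n128 fields in the kernel name that this page does not itself confirm.

```python
# Back-of-envelope VGPR accounting for one wave.
# Assumptions (not confirmed by the page): wave64, 4-byte VGPRs, and a
# per-wave tile of 32 query rows x 128 KV columns.
LANES = 64           # lanes per wave (assumed)
BYTES_PER_VGPR = 4   # one 32-bit register per lane


def vgprs(rows, cols, bytes_per_elem):
    """VGPRs one wave needs to hold a rows x cols tile."""
    return (rows * cols * bytes_per_elem) // (LANES * BYTES_PER_VGPR)


q_fp8 = vgprs(32, 192, 1)       # resident Q, FP8, d_qk = 192
score_fp32 = vgprs(32, 128, 4)  # QK accumulator, FP32
o_fp32 = vgprs(32, 128, 4)      # O accumulator, FP32, d_v = 128
p_fp8 = vgprs(32, 128, 1)       # FP8 P after the 4:1 in-place collapse

resident = q_fp8 + score_fp32 + o_fp32
print(q_fp8, score_fp32, o_fp32, p_fp8, resident)
```

Under these assumptions the three resident tiles alone exceed the 128-VGPR two-wave cap and sit well inside the 256-VGPR single-wave budget, and the FP8 P operand occupies exactly one quarter of the score tile.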
The reference-exact inner loop
_attend_exact lays each KV tile out as 6 barrier-delimited
clusters (7 bare _s_barrier_raw) matching the reference's label_01D6
boundaries:
C_QK -> C_V_PREFETCH -> C_SOFTMAX_MAX -> C_EXP/rescale -> C_FP8_PACK ->
C_PV. Each cluster comment cites the reference asm line it mirrors.
The two warp-groups (waves 0-3 / waves 4-7) run that body phase-shifted
via an asymmetric +4 prologue _s_barrier_raw() stagger, with a
work-split K/V DMA (waves 0-3 produce K, waves 4-7 produce V) into two
disjoint LDS ring regions (K depth-2, V depth-4; the reference V region
is the wider of the two). The shared math is reproduced in-file without
editing any shared file.
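The ring depths above can be turned into a rough LDS budget. The sketch below is illustrative only: the tile shapes (128 KV rows per tile, FP8 elements, and a Q staging region of 8 waves x 32 rows) are assumptions read off the kernel name, and the page does not state that these three regions are the whole LDS allocation.

```python
KIB = 1024


def tile_bytes(rows, cols, bytes_per_elem=1):
    """Size in bytes of one rows x cols tile (FP8 by default)."""
    return rows * cols * bytes_per_elem


# Assumed shapes, not confirmed by the page.
k_ring = 2 * tile_bytes(128, 192)   # K depth-2: 128 KV rows x d_qk = 192
v_ring = 4 * tile_bytes(128, 128)   # V depth-4: 128 KV rows x d_v = 128
q_stage = tile_bytes(8 * 32, 192)   # Q staging: 8 waves x 32 rows x d_qk = 192

total = k_ring + v_ring + q_stage
print(k_ring // KIB, v_ring // KIB, q_stage // KIB, total // KIB)
```

With these shapes the V region comes out wider than the K region, consistent with the description above.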
The prologue stagger (see the prologue keystone + tail-compensation
comments in _attend_exact):
- -D exact_stagger (default = persistent): the EXACT reference two-half-body discipline. The upper half pays +4 at the prologue and the lower half pays a matching +4 at the work-item TAIL (the reference label_06B4/label_1A51), EVERY work-item. Per-work-item barrier totals stay EQUAL (no +4N accumulation, no deadlock) while the +4 phase skew RE-FORMS each work-item, giving a steady skew conserved across the CU's whole work stream. Off (the static-grid default) = the +4 fires only on wi0; work-items 1..N run in lockstep (a no-op at one work-item/CU).
- -D v_qktail (default = NOT persistent): prefetch the first V band fragments into the QK-tail. A win where registers have headroom (static / batch>1); disabled under persistent because the held-across-softmax band spills at the 256-VGPR ceiling there.
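The barrier accounting behind exact_stagger can be illustrated with a toy model. The sketch below is a simplification, not the kernel's control flow: it assumes every wave must arrive at each barrier in the same order, so the k-th barrier of one half always pairs with the k-th barrier of the other. The names barrier_trace, upper, and lower are hypothetical.

```python
def barrier_trace(half, work_items, tiles, body_barriers=7, stagger=4):
    """Ordered barriers one half executes.

    half is "upper" (pays the stagger at the prologue) or "lower"
    (pays it at the work-item tail).
    """
    trace = []
    for wi in range(work_items):
        body = [("body", wi, i) for i in range(tiles * body_barriers)]
        pad = [("pad", wi, i) for i in range(stagger)]
        trace += (pad + body) if half == "upper" else (body + pad)
    return trace


upper = barrier_trace("upper", work_items=3, tiles=2)
lower = barrier_trace("lower", work_items=3, tiles=2)

# Zipping the traces gives the rendezvous points; record how far ahead the
# lower half is whenever both halves are inside a loop body.
skews = [l[2] - u[2] for u, l in zip(upper, lower)
         if u[0] == "body" and l[0] == "body"]
print(len(upper), len(lower), set(skews))
```

In this model both halves execute the same number of barriers per work-item, so nothing accumulates across work-items, and the 4-barrier offset between the two bodies reappears in every work-item.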
The cadence levers that reproduce the reference instruction schedule (all
unconditional; each pinned by a mask-0 schedule_barrier, which fixes a
hand-specified order at codegen; program order alone is re-clustered by
the IGLP solver):
- Non-materialized V band (lean ~210-VGPR layout): V is streamed fragment-at-a-time through the band K vacates (disjoint lifetimes).
- 4-slot rotating V band in C_PV (3 slots in flight): V reads land ahead of their consuming PV MFMA so the per-MFMA drain is soft lgkmcnt(8) (the reference load/MFMA C_PV cadence, ref asm L744-785).
- 4-ahead K ring in C_QK pinned by the mask-0 fence (breaks the K-band WAR toward the reference soft lgkmcnt(4)).
- Next-tile K/V prefetch issued in C_QK (ref asm L356-410).
- Resident Q staged DRAM->LDS->VGPR (the reference Q@0x0 region).
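The rotating-band idea in the list above can be sketched as a small issue-order simulation. This is a generic ring-buffer model under assumed parameters (4 slots, loads issued 3 fragments ahead), not the kernel's actual instruction schedule, and it says nothing about the lgkmcnt values. The names schedule and max_live_slots are hypothetical.

```python
SLOTS = 4      # band slots
IN_FLIGHT = 3  # loads kept ahead of the consuming MFMA


def schedule(fragments):
    """Return (event, fragment, slot) tuples in issue order."""
    events = [("load", f, f % SLOTS) for f in range(min(IN_FLIGHT, fragments))]
    for f in range(fragments):
        nxt = f + IN_FLIGHT
        if nxt < fragments:
            events.append(("load", nxt, nxt % SLOTS))  # refill ahead
        events.append(("mfma", f, f % SLOTS))          # consume fragment f
    return events


def max_live_slots(events):
    """Replay the events; fail on a write-after-read hazard."""
    live, peak = {}, 0
    for kind, frag, slot in events:
        if kind == "load":
            assert slot not in live, "slot overwritten before it was consumed"
            live[slot] = frag
        else:
            assert live.pop(slot) == frag, "MFMA read the wrong fragment"
        peak = max(peak, len(live))
    return peak


events = schedule(8)
print(max_live_slots(events))
```

Because the load for fragment f+3 targets the slot that fragment f-1 has already released, a load never lands on data an MFMA still needs, which is the property the extra slot buys.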
Correctness strategy: reuse the verified MLA-prefill math
The MLA-prefill MATH (QK with nope d=128 + rope d=64; FlashAttention-2
online softmax with running max/sum + cross-tile rescale; in-place FP8
P collapse; PV accumulate; normalize + store; causal / null mask) lives
in mla_components.mojo's MlaPrefillV2Core[config] (FP32-scores
path), trimmed to exactly the closure this kernel consumes. Rather than
re-deriving it, _attend_exact reuses those BARRIER-FREE numeric
primitives (the OnlineSoftmax recurrence, MhaMmaOp MFMA/exp/cast
helpers, _qk_collapse_inplace, the K/V LDS loaders, _store_o_to_gmem)
and _MlaKDmaPair for the K DMA, but emits the cluster cadence + the
QK/PV MFMA streams in-file, so the reference's bare s_barrier boundaries
are NOT fragmented by the delegated helpers' own lgkmcnt(0) drains / IGLP
fences. This file owns (a) the waves_per_eu=1,1 kernel entry, (b) the
single reference-faithful inner loop, and (c) the host launcher.
Because the _FP32_SOFTMAX_SCORES gate (FP8 + KV>=128 + 32x32x64) is the
default-True path for the FP8 KV=128 target shape, every reused
primitive exercises the exact codegen this kernel ships.
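For readers unfamiliar with the online softmax recurrence named above, here is a minimal pure-Python sketch for a single query row. It shows only the running max/sum and cross-tile rescale in full precision; it is not the OnlineSoftmax implementation from mla_components.mojo, and it omits the FP8 collapse, masking, and MFMA layout. The function names and sample data are hypothetical.

```python
import math


def attend_reference(scores, values):
    """Direct softmax(scores) @ values for one query row."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    dim = len(values[0])
    return [sum(p[j] * values[j][c] for j in range(len(p))) / total
            for c in range(dim)]


def attend_online(scores, values, tile):
    """Same result, one KV tile at a time, with eager rescale of O."""
    dim = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)  # rescales state from earlier tiles
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * alpha + sum(p)
        o = [o[c] * alpha + sum(p[j] * v_tile[j][c] for j in range(len(p)))
             for c in range(dim)]
        m = m_new
    return [x / l for x in o]  # final normalize


scores = [0.5, -1.0, 2.0, 0.1, 3.0, -0.7]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0],
          [0.5, 0.5], [-1.0, 3.0], [4.0, 4.0]]
ref = attend_reference(scores, values)
out = attend_online(scores, values, tile=2)
print(max(abs(a - b) for a, b in zip(ref, out)))
```

The tiled result matches the direct computation up to floating-point rounding for any tile size, which is why the kernel can process KV in blocks while keeping a single O accumulator.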
Structs
- MlaPrefillV2: Fresh single-schedule port of the reference integrated MLA-prefill inner loop for gfx950. 1 wave / EU; resident Q; single FP32 score tile with in-place FP8 P collapse; shared K/V band; 64-VGPR FP32 O accumulator with eager rescale; reference work-split K/V DMA + deep even wave-spec stagger over a 160 KB LDS.
Functions
- mla_prefill_v2_ragged: Host launcher for ragged MLA prefill.