Mojo module
mha_prefill_v2
MhaPrefillV2: long-context BF16 MHA prefill for AMD MI355X (gfx950).
`run` interleaves QK MFMA, PV MFMA, softmax + rescale, and
K/V DMA across an explicit cluster schedule so that each work class
overlaps the others' latency. Each block has 8 warps; each warp owns a
`Q_BLOCK_SIZE`-row stripe of Q.
Cluster schedule
The main loop runs 8 clusters per iteration and advances j by 2,
so each iter processes two K/V tiles. Each cluster ends in a bare
s_barrier.
| Cluster | Work |
|---|---|
| C0 | QK[j-2] + tail softmax of tile (j-3) |
| C1 | DMA K[j] + LDS→register V[j-3] + mask (j-2)¹ |
| C2 | PV[j-3] strip-interleaved with partial softmax(j-2) |
| C3 | DMA V[j-1] + LDS→register K[j-1] |
| C4 | QK[j-1] + tail softmax of tile (j-2) |
| C5 | DMA K[j+1] + LDS→register V[j-2] + mask (j-1) |
| C6 | PV[j-2] strip-interleaved with partial softmax(j-1) |
| C7 | DMA V[j] + LDS→register K[j] |
¹ Non-Causal masks only. CausalMask comptime-elides the C1 mask
call because the max_num_tiles cap leaves tile (j-2) naturally
fully unmasked.
Whole-tile K is pre-loaded one iteration ahead into the persistent
k_reg, so the QK clusters (C0/C4) contain MFMAs + VALU only — no
in-cluster ds_read. The prologue primes the pipeline and runs
QK[0] + partial softmax; the 13-cluster epilogue drains the final
four tiles N-4..N-1, with whole-V PV (no strip split) and an
unconditional normalizer rescale before the o / norm_vec divide.
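The table above can be read as a per-tile pipeline. The following is a minimal Python sketch, not kernel code, that transcribes the eight clusters and replays them for successive values of `j` to recover the order in which a single tile passes through each stage. The stage names, the `SCHEDULE` structure, and the starting value of `j` are assumptions made for illustration; the prologue and epilogue are ignored.

```python
from collections import defaultdict

# Cluster -> [(stage, tile offset relative to j)], transcribed from the table.
# Stage names are illustrative labels, not identifiers from the kernel.
SCHEDULE = [
    ("C0", [("QK", -2), ("tail_softmax", -3)]),
    ("C1", [("dma_K", 0), ("lds_to_reg_V", -3), ("mask", -2)]),
    ("C2", [("PV", -3), ("partial_softmax", -2)]),
    ("C3", [("dma_V", -1), ("lds_to_reg_K", -1)]),
    ("C4", [("QK", -1), ("tail_softmax", -2)]),
    ("C5", [("dma_K", 1), ("lds_to_reg_V", -2), ("mask", -1)]),
    ("C6", [("PV", -2), ("partial_softmax", -1)]),
    ("C7", [("dma_V", 0), ("lds_to_reg_K", 0)]),
]


def tile_timeline(j_values):
    """Return {tile: [stage, ...]} in issue order over the given iterations."""
    timeline = defaultdict(list)
    for j in j_values:
        for _cluster, work in SCHEDULE:
            for stage, offset in work:
                timeline[j + offset].append(stage)
    return dict(timeline)


# Main loop advances j by 2; the start value 2 is an assumption.
timeline = tile_timeline(range(2, 12, 2))
print(timeline[4])
```

Under these assumptions, every tile that is fully covered by the modeled iterations follows the same order: K DMA, V DMA, K to registers, QK, mask, partial softmax, tail softmax, V to registers, PV. An even tile completes this within two iterations, an odd tile within three.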
Key design choices
- Whole-tile K pre-load with consumer-side waitcnt drains. Each MFMA-consumer helper opens with `s_waitcnt[lgkmcnt=0]()` so SIInsertWaitcnts treats the cluster as a bracket reset and the per-consumer `lgkmcnt` staircase collapses.
- Kernel-scope BF16 P-cache. Each softmax bulk-casts FP32 att to one persistent `att_block_bf16` register tile reused by the subsequent PV (avoids LLVM rematerializing the cast per use site).
- Lazy rescale (`RESCALE_THRESHOLD=8`). In C2/C6, when the running max grows by more than 8 log2 units, `o_reg *= scale_vec` fires between PV strip 0 and strips 1-3; strips 1-3 then contribute at the old scale into an already-rescaled accumulator. The 8 log2 cap bounds the inconsistency. When `rv_all_below` reports no lane exceeded the threshold, the rescale is skipped and `scale_vec` is reset to 1, so the epilogue's unconditional multiply stays identity. The epilogue's tail softmax applies `norm_vec *= scale_vec` unconditionally; the initialized-to-1 + reset-to-1-on-skip invariant guarantees this is identity unless a rescale fired in the last C2/C6.
- Mask placement. The mask is applied at tiles `0` (prologue), `(j - 2)` for each main-loop iter (C1, non-causal masks only), `(j - 1)` (C5, all masks), and `N - 3, N - 2, N - 1` (epilogue). For `CausalMask` the `max_num_tiles` cap guarantees odd-numbered K tiles in the main-loop range are naturally fully unmasked, so the C0/C2 path skips the mask call. Non-causal masks (`SlidingWindow` / `Chunked`) cannot rely on that cap, so the C1 site applies the mask to `att_block_1` (= QK[j-2]) before C2's partial softmax reads it.
- Output transpose. `col_l → row_l` is a zero-cost re-tag of the same register storage: no cross-lane permute, no data motion.
- GQA-aware head remap. `head_idx` is `(block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP)`, the transpose over the `(NUM_KV_HEADS, GROUP)` rectangle, so adjacent blocks visit different KV heads across CUs/XCDs. Bijective for any `NUM_HEADS == GROUP * NUM_KV_HEADS`; reduces to identity at MHA (`GROUP=1`) and MQA (`NUM_KV_HEADS=1`).
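The head remap formula in the last item can be checked numerically. This is a small Python sketch of the arithmetic only; the function names are hypothetical, and it assumes the `/` in the formula is integer division.

```python
def remap_head(block_x: int, group: int, num_kv_heads: int) -> int:
    """head_idx = (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP)."""
    # Integer division is assumed for the second term.
    return (block_x % group) * num_kv_heads + (block_x // group)


def is_bijective(group: int, num_kv_heads: int) -> bool:
    """True if the remap is a permutation of range(GROUP * NUM_KV_HEADS)."""
    num_heads = group * num_kv_heads
    image = sorted(remap_head(b, group, num_kv_heads) for b in range(num_heads))
    return image == list(range(num_heads))


# Example: GROUP=4, NUM_KV_HEADS=2 gives 8 heads.
print([remap_head(b, 4, 2) for b in range(8)])  # [0, 2, 4, 6, 1, 3, 5, 7]
```

Consecutive `block_x` values land `NUM_KV_HEADS` apart in head index until the group wraps, which is the transpose described above.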
The cluster decomposition and overlap pattern are inspired by the reference attention kernel.
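The idea behind the lazy rescale item can be illustrated with a generic, single-row model of thresholded rescaling in a base-2 online softmax. This is a sketch under stated assumptions, not the kernel's code: the names `m_ref`, `norm`, and `out` are hypothetical, values are scalars rather than register tiles, and the strip-level interleaving with PV (and the bounded inconsistency it causes) is not modeled. In this simplified form the accumulator and the normalizer always share one reference max, so the final ratio matches the exact result.

```python
import math

RESCALE_THRESHOLD = 8.0  # log2 units, matching the value quoted in the text


def lazy_online_softmax(score_tiles, value_tiles):
    """Softmax-weighted sum over tiles with thresholded rescale.

    Returns (output, number_of_rescales_fired).
    """
    m_ref = -math.inf  # reference max the exponentials are taken against
    norm = 0.0         # running normalizer
    out = 0.0          # running weighted sum
    rescales = 0
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(scores)
        if tile_max - m_ref > RESCALE_THRESHOLD:
            # Max grew by more than the threshold: rescale both accumulators.
            scale = 2.0 ** (m_ref - tile_max)
            out *= scale
            norm *= scale
            m_ref = tile_max
            rescales += 1
        # Otherwise the rescale is skipped (scale is effectively 1) and the
        # exponentials below may exceed 1, by at most 2**RESCALE_THRESHOLD.
        for s, v in zip(scores, values):
            p = 2.0 ** (s - m_ref)
            norm += p
            out += p * v
    return out / norm, rescales


def exact_softmax_sum(score_tiles, value_tiles):
    """Reference result computed against the global max."""
    scores = [s for tile in score_tiles for s in tile]
    values = [v for tile in value_tiles for v in tile]
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [[0.0, 1.0], [2.0, 3.0], [20.0, 5.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result, fired = lazy_online_softmax(scores, values)
print(result, fired)
```

With these inputs the rescale fires on the first and third tiles and is skipped on the second, where the max grows by only 2 log2 units.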
Structs
- `MhaPrefillV2`: 8-warp MHA forward kernel parameterized by `MhaConfigV2`.
Functions
- `mha_prefill_v2`: Host launcher for `MhaPrefillV2`.
- `mha_prefill_v2_ragged`: Host launcher for ragged `MhaPrefillV2` prefill.