IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

mha_prefill_v2

MhaPrefillV2 — long-context BF16 MHA prefill for AMD MI355X (gfx950).

run interleaves QK MFMA, PV MFMA, softmax + rescale, and K/V DMA across an explicit cluster schedule so each work class overlaps the others' latency. 8 warps per block; each warp owns a Q_BLOCK_SIZE-row stripe of Q.

Cluster schedule

The main loop runs 8 clusters per iteration and advances j by 2, so each iteration processes two K/V tiles. Each cluster ends in a bare s_barrier.

| Cluster | Work |
|---------|------|
| C0 | QK[j-2] + tail softmax of tile (j-3) |
| C1 | DMA K[j] + LDS→register V[j-3] + mask (j-2)¹ |
| C2 | PV[j-3] strip-interleaved with partial softmax(j-2) |
| C3 | DMA V[j-1] + LDS→register K[j-1] |
| C4 | QK[j-1] + tail softmax of tile (j-2) |
| C5 | DMA K[j+1] + LDS→register V[j-2] + mask (j-1) |
| C6 | PV[j-2] strip-interleaved with partial softmax(j-1) |
| C7 | DMA V[j] + LDS→register K[j] |

¹ Non-causal masks only. CausalMask comptime-elides the C1 mask call because the max_num_tiles cap leaves tile (j-2) fully unmasked.
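To make the tile indexing concrete, the following Python sketch models which tile each cluster touches for a given j. It is a host-side illustration of the table above, not kernel code: the names cluster_schedule and tiles_for, the work labels, and the j values in the usage note are assumptions made for this example.

```python
def cluster_schedule(j):
    """Return [(cluster, [(work, tile), ...])] for one main-loop iteration at index j.

    Hypothetical model of the schedule table; the labels are illustrative only.
    """
    return [
        ("C0", [("QK", j - 2), ("softmax_tail", j - 3)]),
        # The C1 mask applies to non-causal masks only (see the footnote).
        ("C1", [("DMA_K", j), ("LDS_to_reg_V", j - 3), ("mask", j - 2)]),
        ("C2", [("PV", j - 3), ("softmax_partial", j - 2)]),
        ("C3", [("DMA_V", j - 1), ("LDS_to_reg_K", j - 1)]),
        ("C4", [("QK", j - 1), ("softmax_tail", j - 2)]),
        ("C5", [("DMA_K", j + 1), ("LDS_to_reg_V", j - 2), ("mask", j - 1)]),
        ("C6", [("PV", j - 2), ("softmax_partial", j - 1)]),
        ("C7", [("DMA_V", j), ("LDS_to_reg_K", j)]),
    ]


def tiles_for(work, j_values):
    """Collect, in issue order, the tiles that one work class touches."""
    return [
        tile
        for j in j_values
        for _, items in cluster_schedule(j)
        for name, tile in items
        if name == work
    ]
```

With the illustrative sequence j = 4, 6, 8, tiles_for("QK", [4, 6, 8]) returns tiles 2 through 7 in order, tiles_for("PV", [4, 6, 8]) returns tiles 1 through 6, and tiles_for("DMA_K", [4, 6, 8]) returns tiles 4 through 9. In this model the K DMA runs two tiles ahead of QK and PV trails QK by one tile, which is the overlap the schedule is built around.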

Whole-tile K is pre-loaded one iteration ahead into the persistent k_reg, so the QK clusters (C0/C4) contain MFMAs + VALU only — no in-cluster ds_read. The prologue primes the pipeline and runs QK[0] + partial softmax; the 13-cluster epilogue drains the final four tiles N-4..N-1, with whole-V PV (no strip split) and an unconditional normalizer rescale before the o / norm_vec divide.

Key design choices

  • Whole-tile K pre-load with consumer-side waitcnt drains. Each MFMA-consumer helper opens with s_waitcnt[lgkmcnt=0]() so SIInsertWaitcnts treats the cluster as a bracket reset and the per-consumer lgkmcnt staircase collapses.

  • Kernel-scope BF16 P-cache. Each softmax bulk-casts FP32 att to one persistent att_block_bf16 register tile reused by the subsequent PV (avoids LLVM rematerializing the cast per use site).

  • Lazy rescale (RESCALE_THRESHOLD=8). In C2/C6, when the running max grows by more than 8 log2 units, o_reg *= scale_vec fires between PV strip 0 and strips 1-3, so strips 1-3 contribute at the old scale into an already-rescaled accumulator. The 8 log2 cap bounds the inconsistency. When rv_all_below reports that no lane exceeded the threshold, the rescale is skipped and scale_vec is reset to 1. The epilogue's tail softmax applies norm_vec *= scale_vec unconditionally; because scale_vec is initialized to 1 and reset to 1 on every skip, that multiply is the identity unless a rescale fired in the last C2/C6.

  • Mask placement. Masks are applied to tile 0 (prologue), tile (j - 2) in each main-loop iteration (C1, non-causal masks only), tile (j - 1) (C5, all masks), and tiles N - 3, N - 2, N - 1 (epilogue). For CausalMask the max_num_tiles cap guarantees that odd-numbered K tiles in the main-loop range are naturally fully unmasked, so the C1 mask call between C0 and C2 is elided at comptime. Non-causal masks (SlidingWindow/Chunked) cannot rely on that cap, so the C1 site applies the mask to att_block_1 (= QK[j-2]) before C2's partial softmax reads it.

  • Output transpose. col_l → row_l is a zero-cost re-tag of the same register storage — no cross-lane permute, no data motion.

  • GQA-aware head remap. head_idx is (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP) — the transpose over the (NUM_KV_HEADS, GROUP) rectangle — so adjacent blocks visit different KV heads across CUs/XCDs. Bijective for any NUM_HEADS == GROUP * NUM_KV_HEADS; reduces to identity at MHA (GROUP=1) and MQA (NUM_KV_HEADS=1).
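The head remap in the last bullet can be checked with a few lines of host-side Python. This is a sketch of the stated formula only: the function names remap_head and remap_all are hypothetical, and the division in the formula is assumed to be integer (floor) division.

```python
def remap_head(block_x, group, num_kv_heads):
    """Map a block index to a head index by transposing the
    (num_kv_heads, group) rectangle, per the formula in the text.
    Assumes the '/' in the formula is integer division."""
    return (block_x % group) * num_kv_heads + (block_x // group)


def remap_all(group, num_kv_heads):
    """Head index visited by each block, for block_x = 0 .. NUM_HEADS - 1."""
    num_heads = group * num_kv_heads
    return [remap_head(b, group, num_kv_heads) for b in range(num_heads)]


def is_bijective(group, num_kv_heads):
    """True when every head index is produced exactly once."""
    mapping = remap_all(group, num_kv_heads)
    return sorted(mapping) == list(range(group * num_kv_heads))
```

For example, remap_all(4, 2) returns [0, 2, 4, 6, 1, 3, 5, 7], a permutation of the 8 head indices, while remap_all(1, 8) and remap_all(8, 1) both return the identity ordering, matching the MHA and MQA cases in the text.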

The cluster decomposition and overlap pattern are inspired by the reference attention kernel.
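The lazy-rescale rule under Key design choices can be illustrated with a scalar model of thresholded online softmax for a single query row. This is a simplified sketch under stated assumptions, not the kernel's implementation: it assumes that a skipped rescale keeps the old reference max, it uses one scalar value per key instead of a V tile, and it ignores the strip interleaving (and the bounded inconsistency that comes with it). All function names are hypothetical; only RESCALE_THRESHOLD = 8 comes from the text.

```python
RESCALE_THRESHOLD = 8.0  # log2 units, as stated in the text


def lazy_online_softmax(score_tiles, value_tiles, threshold=RESCALE_THRESHOLD):
    """Softmax-weighted sum for one row, with scores already in log2 units.

    Assumption: when the tile max exceeds the reference max by no more than
    `threshold`, the reference is left unchanged and scale stays 1.
    Returns (output, number_of_rescales).
    """
    ref_max = max(score_tiles[0])
    out, norm, rescales = 0.0, 0.0, 0
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(scores)
        scale = 1.0  # identity when the rescale is skipped
        if tile_max - ref_max > threshold:
            scale = 2.0 ** (ref_max - tile_max)
            ref_max = tile_max
            rescales += 1
        out *= scale
        norm *= scale
        probs = [2.0 ** (s - ref_max) for s in scores]
        norm += sum(probs)
        out += sum(p * v for p, v in zip(probs, values))
    return out / norm, rescales


def exact_softmax(score_tiles, value_tiles):
    """Reference: base-2 softmax-weighted sum over all tiles at once."""
    scores = [s for tile in score_tiles for s in tile]
    values = [v for tile in value_tiles for v in tile]
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)
```

In this model the final ratio out / norm does not depend on which reference max is used, as long as the accumulator and normalizer share it, so skipping small rescales changes only the intermediate magnitudes. The threshold keeps the unrescaled exponentials at or below 2^8 relative to the reference.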

Structs

Functions