Mojo module
hk_mha_prefill
HKMhaPrefill — long-context BF16 MHA prefill for AMD MI355X (gfx950).
run interleaves QK MFMA, PV MFMA, softmax + rescale, and
K/V DMA across an explicit cluster schedule so each work class
overlaps the others' latency. 8 warps per block; each warp owns a
Q_BLOCK_SIZE-row stripe of Q.
Cluster schedule
The main loop runs 8 clusters per iteration and advances j by 2,
so each iter processes two K/V tiles. Each cluster ends in a bare
s_barrier.
| Cluster | Work |
|---|---|
| C0 | QK[j-2] + tail softmax of tile (j-3) |
| C1 | DMA K[j] + LDS→register V[j-3] |
| C2 | PV[j-3] strip-interleaved with partial softmax(j-2) |
| C3 | DMA V[j-1] + LDS→register K[j-1] |
| C4 | QK[j-1] + tail softmax of tile (j-2) |
| C5 | DMA K[j+1] + LDS→register V[j-2] + causal mask |
| C6 | PV[j-2] strip-interleaved with partial softmax(j-1) |
| C7 | DMA V[j] + LDS→register K[j] |
Whole-tile K is pre-loaded one iteration ahead into the persistent
k_reg, so the QK clusters (C0/C4) contain MFMAs + VALU only — no
in-cluster ds_read. The prologue primes the pipeline and runs
QK[0] + partial softmax; the 13-cluster epilogue drains the final
four tiles N-4..N-1, with whole-V PV (no strip split) and an
unconditional normalizer rescale before the o / norm_vec divide.
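The tile indexing in the table above can be traced with a small host-side sketch. This is only an illustration of the schedule arithmetic, not kernel code: `cluster_schedule` and `qk_tiles` are hypothetical helper names, and the starting value `j = 3` is an assumption inferred from the prologue handling QK[0].

```python
def cluster_schedule(j):
    """Return (cluster, work) pairs for one main-loop iteration at index j."""
    return [
        ("C0", f"QK[{j - 2}] + tail softmax of tile {j - 3}"),
        ("C1", f"DMA K[{j}] + LDS->register V[{j - 3}]"),
        ("C2", f"PV[{j - 3}] interleaved with partial softmax({j - 2})"),
        ("C3", f"DMA V[{j - 1}] + LDS->register K[{j - 1}]"),
        ("C4", f"QK[{j - 1}] + tail softmax of tile {j - 2}"),
        ("C5", f"DMA K[{j + 1}] + LDS->register V[{j - 2}] + causal mask"),
        ("C6", f"PV[{j - 2}] interleaved with partial softmax({j - 1})"),
        ("C7", f"DMA V[{j}] + LDS->register K[{j}]"),
    ]


def qk_tiles(j_values):
    """List the tiles whose QK product is computed (in C0 and C4) for each j."""
    tiles = []
    for j in j_values:
        tiles += [j - 2, j - 1]
    return tiles


if __name__ == "__main__":
    # Assumed starting index j = 3; j advances by 2 per iteration.
    for name, work in cluster_schedule(3):
        print(name, work)
    print(qk_tiles(range(3, 9, 2)))
```

Because j advances by 2 and each iteration issues QK for tiles j-2 and j-1, consecutive iterations cover consecutive tile pairs with no gap or overlap.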
Key design choices
- Whole-tile K pre-load with consumer-side waitcnt drains. Each MFMA-consumer helper opens with s_waitcnt[lgkmcnt=0]() so SIInsertWaitcnts treats the cluster as a bracket reset and the per-consumer lgkmcnt staircase collapses (see KB patterns/amd-explicit-lgkmcnt-drain-consumer-cluster).
- Kernel-scope BF16 P-cache. Each softmax bulk-casts FP32 att to one persistent att_block_bf16 register tile reused by the subsequent PV (see KB known-limitations/llvm-amdgpu-cast-rematerialization).
- Lazy rescale (RESCALE_THRESHOLD=8). In C2/C6, when the running max grows by more than 8 log2 units, o_reg *= scale_vec fires between PV strip 0 and strips 1-3; strips 1-3 then contribute at the old scale into an already-rescaled accumulator. The 8 log2 cap bounds the inconsistency. When rv_all_below reports no lane exceeded the threshold, the rescale is skipped entirely. The epilogue rescales unconditionally with scale_vec initialized to ones, so the multiply is identity when no rescale ever fired.
- Causal mask placement. Tiles 0 (prologue), (j - 1) for each main-loop iter (C5), and N - 3, N - 2, N - 1 (epilogue). Tiles 1, 3, … in the main-loop range are unmasked by design; the max_num_tiles cap guarantees those positions are naturally fully unmasked.
- Output transpose. col_l → row_l is a zero-cost re-tag of the same register storage: no cross-lane permute, no data motion.
- GQA-aware head remap. head_idx is (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP), the transpose over the (NUM_KV_HEADS, GROUP) rectangle, so adjacent blocks visit different KV heads across CUs/XCDs. Bijective for any NUM_HEADS == GROUP * NUM_KV_HEADS; reduces to identity at MHA (GROUP=1) and MQA (NUM_KV_HEADS=1).
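The head remap formula can be checked on the host with a few lines of Python. This sketch only evaluates the formula quoted above with integer division; `remap_head` is a hypothetical helper name, and the example sizes (GROUP=4, NUM_KV_HEADS=2) are arbitrary.

```python
def remap_head(block_x, group, num_kv_heads):
    """head_idx = (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP)."""
    return (block_x % group) * num_kv_heads + (block_x // group)


def remap_all(group, num_kv_heads):
    """Apply the remap to every block index in [0, NUM_HEADS)."""
    num_heads = group * num_kv_heads
    return [remap_head(b, group, num_kv_heads) for b in range(num_heads)]


if __name__ == "__main__":
    # Arbitrary example sizes: GROUP=4, NUM_KV_HEADS=2, so NUM_HEADS=8.
    mapping = remap_all(group=4, num_kv_heads=2)
    print(mapping)
    # Bijective: every head index appears exactly once.
    print(sorted(mapping) == list(range(8)))
    # Identity at MHA (GROUP=1) and MQA (NUM_KV_HEADS=1).
    print(remap_all(1, 8) == list(range(8)))
    print(remap_all(8, 1) == list(range(8)))
```

The remap is the index transpose of a GROUP x NUM_KV_HEADS grid, which is why it stays a permutation for any sizes satisfying NUM_HEADS == GROUP * NUM_KV_HEADS.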
The cluster decomposition and overlap pattern are inspired by the HipKittens project's attention kernel (https://github.com/HazyResearch/HipKittens).
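The lazy rescale idea listed under the key design choices can be illustrated in scalar form. This is a generic sketch of threshold-gated rescaling in an online softmax, not the kernel's code: the function names are hypothetical, scores are assumed to already be in the log2 domain, and the sketch rescales before accumulating, so it does not model the strip 0 versus strips 1-3 ordering described above.

```python
RESCALE_THRESHOLD = 8.0  # log2 units, the value quoted in the text


def lazy_weighted_softmax(scores, values, threshold=RESCALE_THRESHOLD):
    """Softmax-weighted average with a reference max that is updated lazily.

    The accumulator is rescaled only when a score exceeds the reference
    max by more than `threshold`; otherwise the stale reference is kept
    and the weight 2**(s - ref_max) is allowed to exceed 1, bounded by
    2**threshold.
    """
    ref_max = scores[0]
    acc = 0.0       # running sum of weight * value
    norm = 0.0      # running sum of weights
    rescales = 0
    for s, v in zip(scores, values):
        if s - ref_max > threshold:
            scale = 2.0 ** (ref_max - s)
            acc *= scale
            norm *= scale
            ref_max = s
            rescales += 1
        p = 2.0 ** (s - ref_max)
        acc += p * v
        norm += p
    return acc / norm, rescales


def eager_weighted_softmax(scores, values):
    """Reference result computed against the true global max."""
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


if __name__ == "__main__":
    scores = [0.0, 3.0, 5.0, 20.0, 21.0]
    values = [1.0, 2.0, 3.0, 4.0, 5.0]
    lazy, count = lazy_weighted_softmax(scores, values)
    print(lazy, eager_weighted_softmax(scores, values), count)
```

Since the numerator and the normalizer share the same reference max, the final divide cancels it; skipping small updates costs only dynamic range, which the threshold caps.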
Structs
- HKMhaPrefill: 8-warp MHA forward kernel parameterized by HKMhaConfig.
Functions
- hk_mha_prefill: Host launcher for HKMhaPrefill.