Mojo module

hk_mha_prefill

HKMhaPrefill — long-context BF16 MHA prefill for AMD MI355X (gfx950).

run interleaves QK MFMA, PV MFMA, softmax + rescale, and K/V DMA across an explicit cluster schedule so that each work class overlaps the latency of the others. Each block runs 8 warps, and each warp owns a Q_BLOCK_SIZE-row stripe of Q.
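A minimal Python sketch of the stripe ownership, assuming the eight Q_BLOCK_SIZE-row stripes are contiguous and assigned in warp order (the text does not say how stripes are ordered). warp_q_rows is a hypothetical helper, not part of the module.

```python
WARPS_PER_BLOCK = 8  # from the text: 8 warps per block


def warp_q_rows(warp_id: int, q_block_size: int) -> range:
    """Block-local Q rows owned by one warp (hypothetical helper).

    ASSUMPTION: stripes are contiguous and laid out in warp order.
    """
    if not 0 <= warp_id < WARPS_PER_BLOCK:
        raise ValueError("warp_id must be in [0, 8)")
    start = warp_id * q_block_size
    return range(start, start + q_block_size)
```

With an illustrative Q_BLOCK_SIZE of 32, warp 3 owns rows 96..127, and the eight stripes together cover rows 0..255 without overlap, so one block handles 8 * Q_BLOCK_SIZE query rows.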

Cluster schedule

The main loop runs 8 clusters per iteration and advances j by 2, so each iteration processes two K/V tiles. Each cluster ends in a bare s_barrier.

Cluster  Work
C0       QK[j-2] + tail softmax of tile (j-3)
C1       DMA K[j] + LDS→register V[j-3]
C2       PV[j-3] strip-interleaved with partial softmax(j-2)
C3       DMA V[j-1] + LDS→register K[j-1]
C4       QK[j-1] + tail softmax of tile (j-2)
C5       DMA K[j+1] + LDS→register V[j-2] + causal mask of tile (j-1)
C6       PV[j-2] strip-interleaved with partial softmax(j-1)
C7       DMA V[j] + LDS→register K[j]
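One way to sanity-check the table is to replay it. The Python sketch below is a hypothetical model, not kernel code: it transcribes each cluster as (operation, tile offset from j) pairs, replays a window of main-loop iterations, and asserts that every tile whose whole lifetime fits in the window receives each operation exactly once and in data-flow order. The operation labels and the dependency list are assumptions read off the table; the C5 causal mask is omitted because only some tiles receive it.

```python
from collections import defaultdict

# One main-loop iteration, transcribed from the cluster table:
# (cluster, [(op label, tile offset relative to j), ...]). Labels are illustrative.
SCHEDULE = [
    ("C0", [("QK", -2), ("softmax_tail", -3)]),
    ("C1", [("dma_K", 0), ("lds_V", -3)]),
    ("C2", [("PV", -3), ("softmax_partial", -2)]),
    ("C3", [("dma_V", -1), ("lds_K", -1)]),
    ("C4", [("QK", -1), ("softmax_tail", -2)]),
    ("C5", [("dma_K", 1), ("lds_V", -2)]),
    ("C6", [("PV", -2), ("softmax_partial", -1)]),
    ("C7", [("dma_V", 0), ("lds_K", 0)]),
]
OPS = {op for _, ops in SCHEDULE for op, _ in ops}

# ASSUMED per-tile data flow: (producer, consumer) pairs that must run in order.
DEPS = [
    ("dma_K", "lds_K"), ("lds_K", "QK"), ("QK", "softmax_partial"),
    ("softmax_partial", "softmax_tail"), ("softmax_tail", "PV"),
    ("dma_V", "lds_V"), ("lds_V", "PV"),
]


def simulate(j_values):
    """Map (op, tile) -> list of (iteration index, cluster index) timestamps."""
    when = defaultdict(list)
    for it, j in enumerate(j_values):
        for c, (_, ops) in enumerate(SCHEDULE):
            for op, offset in ops:
                when[(op, j + offset)].append((it, c))
    return when


def steady_state_tiles(j_values):
    """Check every tile whose full lifetime fits in the window; return those tiles."""
    when = simulate(j_values)
    tiles = sorted({t for _, t in when})
    full = [t for t in tiles if all((op, t) in when for op in OPS)]
    for t in full:
        for op in OPS:
            assert len(when[(op, t)]) == 1, f"{op} ran more than once on tile {t}"
        for producer, consumer in DEPS:
            # Tuple comparison orders by iteration first, then by cluster.
            assert when[(producer, t)][0] < when[(consumer, t)][0], (producer, consumer, t)
    return full
```

Replaying j = 3, 5, …, 21 (an illustrative window) checks tiles 3 through 19. A tile whose K DMA is issued in C1 reaches its PV in C6 of the next iteration, while one issued in C5 reaches its PV in C2 two iterations later.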

Whole-tile K is DMA'd one iteration ahead and read into the persistent k_reg by the cluster just before each QK cluster (C7 of the previous iteration for C0, C3 for C4), so the QK clusters (C0/C4) contain only MFMAs and VALU ops, with no ds_read inside the cluster. The prologue primes the pipeline and runs QK[0] plus a partial softmax; the 13-cluster epilogue drains the final four tiles N-4..N-1 using whole-V PV (no strip split) and an unconditional normalizer rescale before the o / norm_vec divide.
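The prologue, main-loop and epilogue split can be checked with a small Python sketch. It assumes the main loop runs j = 3, 5, …, N-3 for an even tile count N ≥ 6 and that the epilogue issues QK for tiles N-3..N-1. Both are inferences from the table (the first C0 finishes the softmax of tile 0 that the prologue started, and the last C6 runs PV[N-5], just before the epilogue's N-4..N-1), not statements from the source. Under those assumptions, QK and PV each visit every tile exactly once, in order.

```python
def main_loop_js(num_tiles: int) -> range:
    # ASSUMPTION (inferred, not stated in the source): j runs 3, 5, ..., N-3.
    if num_tiles < 6 or num_tiles % 2:
        raise ValueError("sketch assumes an even tile count N >= 6")
    return range(3, num_tiles - 2, 2)


def qk_order(num_tiles: int) -> list:
    """Tiles in the order their QK runs."""
    order = [0]                                     # prologue: QK[0]
    for j in main_loop_js(num_tiles):
        order += [j - 2, j - 1]                     # C0: QK[j-2], C4: QK[j-1]
    order += list(range(num_tiles - 3, num_tiles))  # epilogue (inferred): QK[N-3..N-1]
    return order


def pv_order(num_tiles: int) -> list:
    """Tiles in the order their PV runs."""
    order = []
    for j in main_loop_js(num_tiles):
        order += [j - 3, j - 2]                     # C2: PV[j-3], C6: PV[j-2]
    order += list(range(num_tiles - 4, num_tiles))  # epilogue: PV[N-4..N-1]
    return order
```

For N = 8, for example, the main loop runs j = 3 and 5, and pv_order returns tiles 0 through 7 in order, with the last four coming from the epilogue.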

Key design choices

  • Whole-tile K pre-load with consumer-side waitcnt drains. Each MFMA-consumer helper opens with s_waitcnt[lgkmcnt=0](), so LLVM's SIInsertWaitcnts pass treats the cluster as a bracket reset and the per-consumer lgkmcnt staircase collapses into a single drain (see KB patterns/amd-explicit-lgkmcnt-drain-consumer-cluster).

  • Kernel-scope BF16 P-cache. Each softmax bulk-casts the FP32 att tile into one persistent att_block_bf16 register tile, which the subsequent PV reuses (see KB known-limitations/llvm-amdgpu-cast-rematerialization).

  • Lazy rescale (RESCALE_THRESHOLD=8). In C2/C6, when the running max grows by more than 8 log2 units, o_reg *= scale_vec fires between PV strip 0 and strips 1-3, so strips 1-3 contribute at the old scale into an already-rescaled accumulator. The 8-log2-unit threshold bounds this inconsistency. When rv_all_below reports that no lane exceeded the threshold, the rescale is skipped entirely. The epilogue rescales unconditionally, with scale_vec initialized to ones, so the multiply is the identity when no rescale ever fired.

  • Causal mask placement. The mask is applied to tile 0 (prologue), to tile (j - 1) in C5 of each main-loop iteration, and to tiles N - 3, N - 2, and N - 1 (epilogue). Tiles 1, 3, … in the main-loop range are left unmasked by design: the max_num_tiles cap guarantees that those positions are fully visible, so they need no masking.

  • Output transpose. The col_l → row_l conversion is a zero-cost re-tag of the same register storage: no cross-lane permute and no data movement.

  • GQA-aware head remap. head_idx is (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP), with integer division. This is the transpose over the (NUM_KV_HEADS, GROUP) rectangle, so adjacent blocks visit different KV heads across CUs/XCDs. The map is bijective for any NUM_HEADS == GROUP * NUM_KV_HEADS and reduces to the identity for MHA (GROUP=1) and MQA (NUM_KV_HEADS=1).
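The lazy-rescale rule can be illustrated with a scalar Python model of streaming softmax(S)·V using base-2 exponentials. It is a sketch, not the kernel's code: it assumes scores are already in log2 units, makes one block-wide rescale decision per tile (mirroring rv_all_below), and applies the rescale before the tile's PV rather than between PV strips, so it does not model the strip 1-3 inconsistency. Because the final o / l divide cancels whichever reference max was used, skipping a rescale changes only the magnitude of intermediate values (each exp2 term stays at most 2^8 when the rescale is skipped), not the result, up to floating-point rounding.

```python
import math

RESCALE_THRESHOLD = 8.0  # log2 units, as in RESCALE_THRESHOLD=8


def lazy_rescale_attention(score_tiles, value_tiles):
    """Streaming softmax(S) @ V over key tiles with a lazy, block-wide rescale.

    score_tiles[t][r][k]: base-2 scores (ASSUMED already scaled into log2 units).
    value_tiles[t][k][c]: value vectors for the keys of tile t.
    Returns (output rows, number of rescales after the first tile).
    """
    rows = len(score_tiles[0])
    cols = len(value_tiles[0][0])
    m = [-math.inf] * rows                   # reference max used for exp2
    l = [0.0] * rows                         # running normalizer
    o = [[0.0] * cols for _ in range(rows)]  # unnormalized accumulator
    rescales = 0
    for t, (s, v) in enumerate(zip(score_tiles, value_tiles)):
        m_new = [max(m[r], max(s[r])) for r in range(rows)]
        # One decision for the whole block: rescale only if some row's max grew
        # past the threshold (always true on the first tile, where m is -inf).
        if any(m_new[r] - m[r] > RESCALE_THRESHOLD for r in range(rows)):
            if t > 0:
                rescales += 1
            for r in range(rows):
                scale = 2.0 ** (m[r] - m_new[r])  # 0.0 on the first tile
                l[r] *= scale
                o[r] = [x * scale for x in o[r]]
                m[r] = m_new[r]
        for r in range(rows):
            # With a stale m, each term is at most 2**RESCALE_THRESHOLD.
            p = [2.0 ** (x - m[r]) for x in s[r]]
            l[r] += sum(p)
            for pk, vk in zip(p, v):
                for c in range(cols):
                    o[r][c] += pk * vk[c]
    # Final divide: any consistently used reference max cancels here.
    return [[x / l[r] for x in o[r]] for r in range(rows)], rescales


def reference_attention(score_tiles, value_tiles):
    """Exact softmax(S) @ V with base-2 exponentials, for comparison."""
    rows = len(score_tiles[0])
    s_all = [[x for s in score_tiles for x in s[r]] for r in range(rows)]
    v_all = [vk for v in value_tiles for vk in v]
    cols = len(v_all[0])
    out = []
    for r in range(rows):
        mx = max(s_all[r])
        w = [2.0 ** (x - mx) for x in s_all[r]]
        total = sum(w)
        out.append([sum(wk * vk[c] for wk, vk in zip(w, v_all)) / total for c in range(cols)])
    return out
```

For example, with one query row whose tile maxima are 1, 5 and 20, the second tile (growth 4) reuses the stale max of 1, while the third exceeds that stale max by 19 and fires exactly one rescale; the output still matches the exact reference.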
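A small Python sketch enumerating which tiles the causal mask touches. It assumes, by inference from the cluster table rather than from any statement in the source, that the main loop runs j = 3, 5, …, N-3 for an even tile count N ≥ 6; masked_tiles and unmasked_tiles are hypothetical helpers.

```python
def main_loop_js(num_tiles: int) -> range:
    # ASSUMPTION (inferred, not stated in the source): j runs 3, 5, ..., N-3.
    if num_tiles < 6 or num_tiles % 2:
        raise ValueError("sketch assumes an even tile count N >= 6")
    return range(3, num_tiles - 2, 2)


def masked_tiles(num_tiles: int) -> list:
    """Tiles that receive the causal mask: prologue, each C5, and the epilogue."""
    masked = {0}                                           # prologue
    masked.update(j - 1 for j in main_loop_js(num_tiles))  # C5 of each iteration
    masked.update(range(num_tiles - 3, num_tiles))         # epilogue: N-3..N-1
    return sorted(masked)


def unmasked_tiles(num_tiles: int) -> list:
    """Tiles the schedule never masks."""
    masked = set(masked_tiles(num_tiles))
    return [t for t in range(num_tiles) if t not in masked]
```

For N = 10 the masked tiles are 0, 2, 4, 6, 7, 8 and 9, leaving tiles 1, 3 and 5 unmasked; in general the unmasked set is the odd tiles 1 through N-5.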
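The head remap and its inverse, as a Python sketch. head_remap follows the formula in the bullet with integer division; inverse_head_remap, kv_head_of and is_bijective are hypothetical helpers, and kv_head_of assumes that query head h reads KV head h // GROUP (contiguous grouping), which the source does not state.

```python
def head_remap(block_x: int, group: int, num_kv_heads: int) -> int:
    """head_idx = (block_x % GROUP) * NUM_KV_HEADS + (block_x / GROUP), integer division."""
    return (block_x % group) * num_kv_heads + block_x // group


def inverse_head_remap(head_idx: int, group: int, num_kv_heads: int) -> int:
    """Hypothetical inverse: recover block_x from head_idx."""
    return (head_idx % num_kv_heads) * group + head_idx // num_kv_heads


def kv_head_of(head_idx: int, group: int) -> int:
    # ASSUMPTION: query head h reads KV head h // GROUP (contiguous grouping).
    return head_idx // group


def is_bijective(group: int, num_kv_heads: int) -> bool:
    """True if the remap permutes 0..NUM_HEADS-1 with NUM_HEADS = GROUP * NUM_KV_HEADS."""
    num_heads = group * num_kv_heads
    heads = [head_remap(b, group, num_kv_heads) for b in range(num_heads)]
    return sorted(heads) == list(range(num_heads))
```

Under the contiguous-grouping assumption, GROUP=2 and NUM_KV_HEADS=4 give KV heads 0, 2, 0, 2, 1, 3, 1, 3 for blocks 0..7, versus 0, 0, 1, 1, 2, 2, 3, 3 without the remap.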

The cluster decomposition and overlap pattern are inspired by the HipKittens project's attention kernel (https://github.com/HazyResearch/HipKittens).

Structs

Functions