Mojo module
mha_prefill
Unified gfx950 MHA prefill kernel.
Handles BF16 and FP8 inputs, any mask type, head depth ∈ {64, 128, 256, 512}, and attention with or without a sink.
Consolidates the two older kernels (the pipelined mha_prefill and the
non-pipelined mha_prefill_gfx950) into one. It matches amd/mha.mojo
on main but is built on this codebase's own KVBuffer implementation.
Recipe:
- Double-buffered LDS slots (slot ping-pong) to overlap DRAM→LDS with MMA.
- V registers loaded after QK (V-load-later) to reduce peak VGPR use.
- Per-iteration: wait K → mma_qk → prefetch next tile → split softmax → wait V → mma_pv.
- Softmax selects one of three variants (prescaled, deferred-scale _fma, or plain) based on prescale_q, depth, and mask.apply_log2e_after_mask.
- For non-causal masks, each iteration checks whether the tile is FULL_MASK (entirely masked out) and, if so, skips it.
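The per-iteration recipe can be illustrated with a plain-Python, FlashAttention-style online-softmax loop. This is a CPU sketch under stated assumptions, not the kernel. The two-element `slots` list stands in for the double-buffered LDS slots. The prefetch runs sequentially, not asynchronously. Only the prescaled softmax variant is shown: Q is multiplied by scale·log2(e) so exp2 can replace exp. The names `tiled_prefill`, `reference_attention`, and the `mask(i, j)` callback are hypothetical.

```python
import math
import random

LOG2E = 1.4426950408889634  # log2(e)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q, k, v, scale, mask):
    """Untiled softmax(scale * Q K^T) V; masked scores are set to -inf."""
    out = []
    for i, qi in enumerate(q):
        s = [dot(qi, kj) * scale if mask(i, j) else -math.inf for j, kj in enumerate(k)]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(pj * vj[d] for pj, vj in zip(p, v)) / total for d in range(len(v[0]))])
    return out


def tiled_prefill(q, k, v, scale, mask, tile):
    """Hypothetical sketch: ping-pong slots, FULL_MASK skip, online softmax."""
    n_q, n_k, depth = len(q), len(k), len(v[0])
    # Prescaled variant: fold scale * log2(e) into Q so softmax can use exp2.
    qs = [[x * scale * LOG2E for x in row] for row in q]
    rowmax = [-math.inf] * n_q
    rowsum = [0.0] * n_q
    acc = [[0.0] * depth for _ in range(n_q)]
    starts = list(range(0, n_k, tile))
    slots = [None, None]                # stand-in for the two LDS slots
    slots[0] = (k[0:tile], v[0:tile])   # prologue: load the first tile
    for it, start in enumerate(starts):
        cur = it % 2
        k_tile, v_tile = slots[cur]
        cols = range(start, start + len(k_tile))
        # A tile masked for every query row contributes nothing (FULL_MASK).
        full_mask = not any(mask(i, j) for i in range(n_q) for j in cols)
        scores = None
        if not full_mask:
            # "mma_qk": this tile's scores, already in the log2 domain.
            scores = [[dot(qs[i], kj) if mask(i, j) else -math.inf
                       for j, kj in zip(cols, k_tile)] for i in range(n_q)]
        # Prefetch the next tile into the other slot (asynchronous on the GPU).
        if it + 1 < len(starts):
            nxt = starts[it + 1]
            slots[1 - cur] = (k[nxt:nxt + tile], v[nxt:nxt + tile])
        if full_mask:
            continue
        for i in range(n_q):
            new_max = max(rowmax[i], max(scores[i]))
            if new_max == -math.inf:
                continue  # row has no visible keys yet
            correction = 2.0 ** (rowmax[i] - new_max)  # rescale earlier partials
            p = [2.0 ** (s - new_max) for s in scores[i]]
            rowsum[i] = rowsum[i] * correction + sum(p)
            for d in range(depth):
                # "mma_pv": accumulate P @ V for this tile.
                acc[i][d] = acc[i][d] * correction + sum(pj * vj[d] for pj, vj in zip(p, v_tile))
            rowmax[i] = new_max
    return [[a / rowsum[i] for a in acc[i]] for i in range(n_q)]


if __name__ == "__main__":
    rng = random.Random(0)
    q = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
    k = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(10)]
    v = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(10)]
    window = lambda i, j: j < 5  # keys 5..9 are FULL_MASK for every row
    out = tiled_prefill(q, k, v, 1 / math.sqrt(3), window, tile=2)
    ref = reference_attention(q, k, v, 1 / math.sqrt(3), window)
    print(max(abs(a - b) for ro, rr in zip(out, ref) for a, b in zip(ro, rr)))
```

The `correction` factor rescales the running sum and accumulator whenever a new tile raises the row maximum. Because of this, tiles can be consumed one slot at a time without ever materializing the full score row.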
Sink is handled entirely in Attention.__init__ (rowmax/rowsum are
pre-filled from sink_weights); the kernel body needs no sink-specific code.
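The sketch below shows why the loop needs no sink-specific code. It assumes one common sink formulation: the sink is an extra logit that enlarges the softmax denominator but has no value row. Under that assumption, starting the online-softmax state at rowmax = sink and rowsum = exp(sink − sink) = 1 reproduces the sink softmax exactly. This initial state is an assumption for illustration. The exact values Attention.__init__ writes (for example, whether the sink is also scaled by log2(e)) are not stated here. Both function names are hypothetical.

```python
import math


def softmax_with_sink_reference(scores, sink):
    """Softmax over [scores..., sink]; the sink has no value row, so only the
    weights of the real keys are returned (they sum to less than 1)."""
    m = max(max(scores), sink)
    denom = sum(math.exp(s - m) for s in scores) + math.exp(sink - m)
    return [math.exp(s - m) / denom for s in scores]


def online_softmax_prefilled(scores, sink, tile):
    """Online softmax whose state starts from the sink instead of (-inf, 0)."""
    rowmax, rowsum = sink, 1.0  # assumed pre-fill: exp(sink - sink) == 1
    for start in range(0, len(scores), tile):
        block = scores[start:start + tile]
        new_max = max(rowmax, max(block))
        rowsum = rowsum * math.exp(rowmax - new_max) + sum(math.exp(s - new_max) for s in block)
        rowmax = new_max
    # The loop is the ordinary online-softmax update; nothing in it mentions the sink.
    return [math.exp(s - rowmax) / rowsum for s in scores]
```

For any tile size, the two functions agree. The returned weights sum to less than 1 because the sink absorbs part of the probability mass.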
Functions