Mojo module

mha_prefill

RDNA Wave32 MHA prefill kernel.

Recipe (per KV tile):

  • K is loaded into LDS strip by strip; QK MMA fragments are emitted for each strip.
  • V is prefetched as a side DMA during the second-to-last K strip so it overlaps the QK compute.
  • Masking and online softmax are applied between the QK and PV phases, with barriers separating them.
  • P (the post-softmax scores) is cast and staged in SMEM; the PV MMA then reads P from SMEM as the A operand and V from LDS as the B operand.
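The per-tile recipe above can be illustrated with a minimal NumPy sketch of the arithmetic only. It is not the Mojo kernel: LDS, DMA prefetch, barriers, and MMA fragments are not modeled. It assumes a single head, a causal mask with query and key positions aligned (plain prefill, no KV cache), and float16 as the staging type for P. The function names, `kv_tile`, `strip`, and `p_dtype` are hypothetical illustration parameters, not part of the module's API.

```python
import numpy as np

def mha_prefill_sketch(q, k, v, kv_tile=64, strip=16, causal=True, p_dtype=np.float16):
    """Hypothetical single-head prefill attention, processed one KV tile at a time."""
    seq_q, d = q.shape
    seq_k = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    m = np.full(seq_q, -np.inf)          # running row max of the scores
    l = np.zeros(seq_q)                  # running row sum of exp(scores - m)
    acc = np.zeros((seq_q, v.shape[1]))  # unnormalized output accumulator
    q_pos = np.arange(seq_q)
    for t0 in range(0, seq_k, kv_tile):
        t1 = min(t0 + kv_tile, seq_k)
        # QK phase: build this tile's scores strip by strip (stands in for per-strip K loads).
        s = np.empty((seq_q, t1 - t0))
        for s0 in range(t0, t1, strip):
            s1 = min(s0 + strip, t1)
            s[:, s0 - t0:s1 - t0] = (q @ k[s0:s1].T) * scale
        # Mask: a query may not attend to keys at later positions.
        if causal:
            k_pos = np.arange(t0, t1)
            s = np.where(k_pos[None, :] > q_pos[:, None], -np.inf, s)
        # Online softmax: update the running max and rescale previous partial results.
        m_new = np.maximum(m, s.max(axis=1))
        m_safe = np.where(np.isneginf(m_new), 0.0, m_new)  # avoid (-inf) - (-inf)
        alpha = np.exp(m - m_safe)
        p = np.exp(s - m_safe[:, None])
        l = l * alpha + p.sum(axis=1)
        # Stage P in a lower-precision type (assumed float16), then run PV.
        p_staged = p.astype(p_dtype)
        acc = acc * alpha[:, None] + p_staged.astype(acc.dtype) @ v[t0:t1]
        m = m_new
    return acc / l[:, None]

def reference_attention(q, k, v, causal=True):
    """Plain softmax attention used to check the tiled sketch."""
    s = (q @ k.T) / np.sqrt(q.shape[1])
    if causal:
        s = np.where(np.triu(np.ones_like(s, dtype=bool), k=1), -np.inf, s)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

# Usage: a sequence length of 100 exercises a partial last tile and a partial last strip.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((100, 32)) for _ in range(3))
out = mha_prefill_sketch(q, k, v)
ref = reference_attention(q, k, v)
out_nc = mha_prefill_sketch(q, k, v, causal=False)
ref_nc = reference_attention(q, k, v, causal=False)
print(np.max(np.abs(out - ref)))
```

Rescaling `acc` and `l` by `alpha` whenever the running max grows is what lets each KV tile be consumed once, without a second pass over K. The small mismatch against the reference comes only from rounding P to the staging type.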