Mojo module
mha_prefill
RDNA Wave32 MHA prefill kernel.
Recipe (per KV tile):
- K loaded to LDS strip-by-strip; QK MMA fragments emitted per strip.
- V is prefetched as a side DMA during the second-to-last K strip so it overlaps the QK compute.
- Masking, online softmax, and synchronization barriers run between the QK and PV stages.
- P (post-softmax scores) cast and staged in SMEM, then PV MMA reads P from SMEM as the A operand and V from LDS as B.
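
The per-tile arithmetic in the recipe above (QK scores, mask, online softmax, PV accumulation) can be sketched on the CPU as follows. This is a minimal illustration in plain Python, not the kernel's implementation: it does not model LDS staging, the V prefetch, barriers, or MMA fragments. The function names, the `kv_tile` parameter, the `1/sqrt(d)` scaling, and the causal mask are all assumptions made for the example.

```python
import math


def naive_attention(q, k, v, causal=True):
    """Reference: softmax(q k^T / sqrt(d)) v, computed one query row at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if causal and j > i:
                scores.append(-math.inf)  # masked position
            else:
                scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        denom = sum(weights)
        out.append([
            sum(weights[j] * v[j][c] for j in range(len(v))) / denom
            for c in range(len(v[0]))
        ])
    return out


def tiled_attention(q, k, v, kv_tile=2, causal=True):
    """Same result, but K/V are visited one tile at a time (hypothetical sketch)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i, qi in enumerate(q):
        m = -math.inf     # running row maximum
        l = 0.0           # running softmax denominator
        acc = [0.0] * dv  # unnormalized output accumulator
        for start in range(0, len(k), kv_tile):
            stop = min(start + kv_tile, len(k))
            # QK stage: scores of this query row against the current K tile.
            s = [scale * sum(a * b for a, b in zip(qi, k[j]))
                 for j in range(start, stop)]
            # Mask stage (assumed causal): hide keys after the query position.
            if causal:
                s = [-math.inf if start + t > i else x for t, x in enumerate(s)]
            tile_max = max(s)
            if tile_max == -math.inf:
                continue  # tile fully masked, nothing to accumulate
            # Online softmax: rescale earlier partial sums to the new maximum.
            new_m = max(m, tile_max)
            correction = math.exp(m - new_m)  # 0.0 on the first live tile
            p = [math.exp(x - new_m) for x in s]
            l = l * correction + sum(p)
            # PV stage: P is the left operand, the V tile is the right operand.
            acc = [
                a * correction + sum(p[t] * v[start + t][c]
                                     for t in range(stop - start))
                for c, a in enumerate(acc)
            ]
            m = new_m
        out.append([a / l for a in acc])
    return out
```

Because each tile only rescales the running maximum, denominator, and accumulator, the full score matrix never has to be materialized, which is what lets a kernel keep one K tile and one V tile resident at a time.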