Mojo module
mha_decode
Gfx950 MHA decode kernel, built on KVBuffer.
Ported from amd/mha.mojo::mha_decoding. Uses a full-depth KVBuffer
for K (optionally double-buffered via double_buffer_k_only) and a
two-view V setup sharing one SMEM region: v_dma_buffer (full
BN x output_depth) for cooperative DRAM→LDS, v_buffer (per-warp
BN x depth_per_warp) for LDS→REG and MMA.
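The two-view arrangement can be pictured with a small host-side Python sketch. This is not the kernel's code: the sizes, the row-major layout, the even split of the depth dimension across warps, and the helper names (`dma_view_offset`, `warp_view_offset`, `read_warp_tile`) are all assumptions made for illustration. It only shows how a full-tile view and a per-warp view might address the same backing region.

```python
# Hypothetical sizes; the real kernel takes these as compile-time parameters.
BN = 4              # keys per tile (rows of the V tile)
OUTPUT_DEPTH = 8    # full V depth
NUM_WARPS = 2       # assumption: warps split the depth dimension evenly
DEPTH_PER_WARP = OUTPUT_DEPTH // NUM_WARPS

# One flat region stands in for SMEM and backs both views.
# Row-major layout is an assumption; a real kernel may swizzle or block it.
smem = [0.0] * (BN * OUTPUT_DEPTH)


def dma_view_offset(row, col):
    """Offset in the full BN x OUTPUT_DEPTH view (cooperative load)."""
    return row * OUTPUT_DEPTH + col


def warp_view_offset(warp, row, col):
    """Offset in one warp's BN x DEPTH_PER_WARP view over the same region."""
    return row * OUTPUT_DEPTH + warp * DEPTH_PER_WARP + col


# Stand-in for the cooperative DRAM->LDS copy: fill the whole tile once.
for r in range(BN):
    for c in range(OUTPUT_DEPTH):
        smem[dma_view_offset(r, c)] = float(r * 100 + c)


def read_warp_tile(warp):
    """Stand-in for LDS->REG: a warp reads only its own depth slice."""
    return [
        [smem[warp_view_offset(warp, r, c)] for c in range(DEPTH_PER_WARP)]
        for r in range(BN)
    ]
```

Under these assumptions both views alias the same memory, so no extra copy is needed between the load and the per-warp reads.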
Recipe (see process_tile):
- Wait for K (and for V unless shared_kv).
- Optionally skip the tile on FULL_MASK.
- Load K from SMEM → mma_qk.
- Softmax → write P to SMEM → barrier.
- If shared_kv, DMA V now.
- Load V from SMEM.
- Prefetch the next K+V.
- Update output → mma_pv.
- Iteration-end drain.
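The arithmetic behind the per-tile steps (mma_qk, softmax, update output, mma_pv) can be sketched as a plain-Python reference for a single query row. This is a hedged illustration using the standard online-softmax formulation, not the kernel's implementation. The function name `decode_attention`, the `masked_tiles` argument (a stand-in for the FULL_MASK skip), and the 1/sqrt(depth) scaling are assumptions; all SMEM traffic, barriers, and prefetching are omitted.

```python
import math


def decode_attention(q, keys, values, bn, masked_tiles=frozenset()):
    """Single-query attention computed tile by tile with an online softmax."""
    depth = len(values[0])
    scale = 1.0 / math.sqrt(len(q))  # assumed scaling convention
    running_max = -math.inf
    running_sum = 0.0
    out = [0.0] * depth

    for tile_idx, start in enumerate(range(0, len(keys), bn)):
        if tile_idx in masked_tiles:
            continue  # stand-in for skipping a fully masked tile

        k_tile = keys[start:start + bn]
        v_tile = values[start:start + bn]

        # "mma_qk": scores for this tile.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]

        # "softmax": update the running max and compute unnormalized P.
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]

        # "update output": rescale earlier partial results to the new max.
        running_sum = running_sum * correction + sum(p)
        out = [o * correction for o in out]

        # "mma_pv": accumulate P times V.
        for weight, v in zip(p, v_tile):
            for d in range(depth):
                out[d] += weight * v[d]

        running_max = new_max

    if running_sum == 0.0:
        return out  # every tile was skipped; nothing to normalize
    return [o / running_sum for o in out]
```

Rescaling the accumulator before adding the new tile's contribution lets each tile be processed once, without a second pass over earlier scores.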
shared_kv (depth>256) overlays V onto K's SMEM and defers the V DMA
until K has been consumed. double_buffer_k_only ping-pongs K between
slots 0 and 1, while V always lives in slot 0.
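As a rough illustration of the slot behaviour described above, the following Python sketch maps a tile index to SMEM slots. The helper name `smem_slots` and the choice to alternate K by tile-index parity are assumptions; the source only states that K alternates between slots 0 and 1 and that V stays in slot 0.

```python
def smem_slots(tile_idx, double_buffer_k_only):
    """Return (k_slot, v_slot) for a tile under the assumed parity scheme."""
    # Assumption: K alternates by tile parity when double-buffered.
    k_slot = tile_idx % 2 if double_buffer_k_only else 0
    # V always lives in slot 0, per the description above.
    v_slot = 0
    return k_slot, v_slot


# Slots for the first four tiles with K double-buffering enabled.
schedule = [smem_slots(i, double_buffer_k_only=True) for i in range(4)]
```

With K double-buffered, the next tile's K can land in the idle slot while the current tile's K is still being read.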