
Mojo module

mha_decode

Gfx950 MHA decode kernel, built on KVBuffer.

Ported from amd/mha.mojo::mha_decoding. K uses a full-depth KVBuffer, which can optionally be double-buffered via double_buffer_k_only. V uses two views that share one SMEM region:

  • v_dma_buffer (the full BN x output_depth tile) is used for the cooperative DRAM→LDS copy.
  • v_buffer (a per-warp BN x depth_per_warp tile) is used for LDS→REG loads and the MMA.
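The two V views are two index maps over the same storage. The page does not describe the real KVBuffer layout or swizzle. This minimal Python sketch therefore assumes a row-major BN x output_depth tile in which warp w owns the contiguous column slice [w * depth_per_warp, (w + 1) * depth_per_warp). The helpers dma_view_offset, warp_view_offset and check_views are hypothetical names. The sketch checks that every element written through the full-tile view is read by exactly one per-warp view.

```python
# Hypothetical model: one flat "SMEM" region viewed two ways.
# Assumption (not from the docs): row-major BN x output_depth tile, and
# warp w owns columns [w * depth_per_warp, (w + 1) * depth_per_warp).

def dma_view_offset(row, col, output_depth):
    """Offset in the full-tile view (cooperative DRAM -> LDS copy)."""
    return row * output_depth + col


def warp_view_offset(warp_id, row, col, output_depth, depth_per_warp):
    """Offset in one warp's BN x depth_per_warp view (LDS -> REG, MMA)."""
    assert 0 <= col < depth_per_warp
    return dma_view_offset(row, warp_id * depth_per_warp + col, output_depth)


def check_views(BN, output_depth, num_warps):
    """True if the per-warp views tile the shared region exactly once."""
    assert output_depth % num_warps == 0
    depth_per_warp = output_depth // num_warps
    smem = list(range(BN * output_depth))  # stand-in: element i holds value i
    seen = []
    for w in range(num_warps):
        for r in range(BN):
            for c in range(depth_per_warp):
                off = warp_view_offset(w, r, c, output_depth, depth_per_warp)
                seen.append(smem[off])
    return sorted(seen) == smem


print(check_views(BN=4, output_depth=8, num_warps=2))  # True
```

Both views index the same buffer, so the DMA fills the region once and each warp reads only its slice. No second copy of V is needed.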

Recipe (see process_tile):

  • Wait for K (and for V, unless shared_kv).
  • Optionally skip the tile if its mask status is FULL_MASK.
  • Load K from SMEM, then mma_qk.
  • Softmax, write P to SMEM, then barrier.
  • If shared_kv, issue the V DMA now.
  • Load V from SMEM.
  • Prefetch the next K+V.
  • Update the output, then mma_pv.
  • Drain at the end of the iteration.
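The arithmetic in the recipe (mma_qk, then softmax, then update output, then mma_pv) follows the usual online-softmax pattern for one query row processed in BN-key tiles. This minimal pure-Python sketch makes several assumptions: a single head and query, no masking, and no modeling of SMEM, DMA or barriers. decode_attention and reference_attention are hypothetical names, not this module's API. Each tile rescales the running accumulator by exp(m_old - m_new) before adding that tile's P @ V contribution.

```python
import math


def decode_attention(q, K, V, BN, scale):
    """Online-softmax attention for one query row, in tiles of BN keys."""
    depth_v = len(V[0])
    m = -math.inf          # running row max
    l = 0.0                # running softmax denominator
    o = [0.0] * depth_v    # unnormalized output accumulator
    for start in range(0, len(K), BN):
        k_tile = K[start:start + BN]
        v_tile = V[start:start + BN]
        # "mma_qk": scaled scores for this tile
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        # "softmax": new running max and unnormalized probabilities P
        m_new = max(m, max(s))
        p = [math.exp(x - m_new) for x in s]
        # "update output": rescale old state (exp(-inf) == 0.0 on tile 0)
        alpha = math.exp(m - m_new)
        l = l * alpha + sum(p)
        o = [oi * alpha for oi in o]
        # "mma_pv": accumulate P @ V for this tile
        for pj, vj in zip(p, v_tile):
            for d in range(depth_v):
                o[d] += pj * vj[d]
        m = m_new
    return [oi / l for oi in o]


def reference_attention(q, K, V, scale):
    """Plain softmax(q K^T * scale) V, for comparison."""
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [sum(wj * v[d] for wj, v in zip(w, V)) / z for d in range(len(V[0]))]


q = [1.0, 0.5]
K = [[0.1, 0.2], [0.3, -0.1], [0.0, 1.0], [-0.5, 0.4], [0.2, 0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
scale = 1.0 / math.sqrt(len(q))
tiled = decode_attention(q, K, V, BN=2, scale=scale)
ref = reference_attention(q, K, V, scale)
print(max(abs(a - b) for a, b in zip(tiled, ref)) < 1e-9)  # True
```

The tiled result matches the unblocked softmax for any BN, including a short last tile. This is why the kernel can stream K and V one tile at a time.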

shared_kv (used when depth > 256) overlays V onto K's SMEM region and defers the V DMA until K has been consumed. double_buffer_k_only alternates K between slots 0 and 1 on successive tiles, while V always lives in slot 0.
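The buffering modes can be summarized as a per-tile slot schedule. This hypothetical Python model (smem_schedule is an illustrative name) records which slot K and V occupy on each tile. It also records whether V overlays K's region, which means its DMA must wait until after mma_qk. The page does not say how shared_kv and double_buffer_k_only interact, so that combination is deliberately not modeled.

```python
def smem_schedule(num_tiles, depth, double_buffer_k_only=False):
    """Hypothetical model of SMEM slot use per tile iteration."""
    shared_kv = depth > 256  # shared_kv mode applies when depth > 256
    if shared_kv and double_buffer_k_only:
        # Interaction not described in the docs; left out of this sketch.
        raise NotImplementedError("shared_kv + double_buffer_k_only not modeled")
    schedule = []
    for i in range(num_tiles):
        # double_buffer_k_only: K alternates 0, 1, 0, 1, ...; otherwise slot 0.
        k_slot = i % 2 if double_buffer_k_only else 0
        schedule.append({
            "tile": i,
            "k_slot": k_slot,
            "v_slot": 0,                # V always lives in slot 0
            "v_overlays_k": shared_kv,  # if True, V DMA waits until K is consumed
        })
    return schedule


for entry in smem_schedule(4, depth=128, double_buffer_k_only=True):
    print(entry["tile"], entry["k_slot"], entry["v_slot"])
```

Alternating K slots lets the prefetch of tile i + 1 land in the other slot while tile i is still being read. Keeping V in one slot saves SMEM.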
