
Mojo module

mha_decode

Gfx950 MHA decode kernel, built on KVBuffer.

Ported from amd/mha.mojo::mha_decoding. K uses a full-depth KVBuffer, optionally double-buffered via double_buffer_k_only. V uses two views that share one SMEM region: v_dma_buffer (full BN x depth) for the cooperative DRAM→LDS copy, and v_buffer (per-warp BN x depth_per_warp) for LDS→REG loads and MMA.

MLA decode (K==V aliasing, output_depth < depth, K-tail rope handling) lives in mla_decode.mojo; this file is MHA-only.

Recipe (see process_tile):

  • Wait K (and V unless shared_kv) → optional FULL_MASK skip → load K from SMEM → mma_qk → softmax → write P to SMEM → barrier → (DMA V now if shared_kv) → load V from SMEM → prefetch next K+V → update output → mma_pv → iter-end drain.
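The ordering in the recipe can be sketched as a small host-side model. This is not the kernel code: the function name `process_tile_steps` and the `fully_masked` flag are hypothetical, and the behaviour of the FULL_MASK skip (returning right after the wait) is an assumption, since the recipe does not say what the skip bypasses. The sketch only shows how shared_kv changes where the V DMA sits in the sequence.

```python
def process_tile_steps(shared_kv, fully_masked=False):
    """Return the ordered step names for one tile (illustrative model only)."""
    # With shared_kv, V has not been copied yet, so only K is awaited here.
    steps = ["wait K" if shared_kv else "wait K+V"]

    if fully_masked:
        # Assumption: a fully masked tile skips the remaining work.
        return steps + ["FULL_MASK skip"]

    steps += ["load K from SMEM", "mma_qk", "softmax",
              "write P to SMEM", "barrier"]

    if shared_kv:
        # V overlays K's SMEM, so its DMA is deferred until K is consumed.
        steps.append("DMA V")

    steps += ["load V from SMEM", "prefetch next K+V",
              "update output", "mma_pv", "iter-end drain"]
    return steps


if __name__ == "__main__":
    for flag in (False, True):
        print(f"shared_kv={flag}: " + " -> ".join(process_tile_steps(flag)))
```

In this model the only differences between the two modes are the first wait and the extra "DMA V" step after the barrier.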

shared_kv (depth>256) overlays V onto K's SMEM and defers the V DMA until K has been consumed. double_buffer_k_only alternates K between slots 0 and 1, while V always lives in slot 0.
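The slot selection described above can be modelled in a few lines. The helper names (`uses_shared_kv`, `k_slot`, `v_slot`) are hypothetical, and the choice that tile `i` lands in slot `i % 2` starting from slot 0 is an assumption; the text only states that K alternates between the two slots and that V stays in slot 0.

```python
def uses_shared_kv(depth):
    """shared_kv is described as applying when depth > 256."""
    return depth > 256


def k_slot(tile_idx, double_buffer_k_only):
    """SMEM slot holding K for a given tile (assumed to start at slot 0)."""
    # Without double buffering there is a single K slot.
    return tile_idx % 2 if double_buffer_k_only else 0


def v_slot(tile_idx):
    """V always lives in slot 0, regardless of the tile index."""
    return 0


if __name__ == "__main__":
    for i in range(4):
        print(i, k_slot(i, True), v_slot(i))
```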