Mojo module
mla_decode
Gfx950 MLA (Multi-Latent Attention) decode kernel, built on KVBuffer.
Forked from mha_decode.mojo so the MLA-specific pipeline (K==V SMEM
aliasing, output_depth < depth rope tail, BN×64 SMEM block split for
depth=576) can evolve independently from the MHA path.
Key invariants (enforced via mla_decode callsite in mla.mojo):
- mla_kv_alias = True: V reads from K's swizzled SMEM, so no V DMA.
- shared_kv = True (required by mla_kv_alias): V SMEM is a bitcast view of K SMEM (see attention.mojo constructor).
- BN = 128 → double_buffer_k_only = False: single-buffer K loop.
- output_depth ≤ depth: per-warp V buffer slices the V-nope portion out of K's full-depth SMEM.
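
The invariants above can be modeled on the host as a minimal sketch. This is plain Python, not Mojo and not code from mla_decode: the MlaDecodeConfig class and both helper functions are hypothetical names, output_depth = 512 is an assumed value (depth = 576 minus a 64-wide rope tail), and the BN = 128 rule is read as an implication. A memoryview stands in for the bitcast view so that V really does share storage with K.

```python
from array import array
from dataclasses import dataclass


@dataclass(frozen=True)
class MlaDecodeConfig:
    """Hypothetical container for the invariants; not a type from mla_decode."""
    depth: int = 576            # full K depth, from the text
    output_depth: int = 512     # ASSUMED V-nope width (576 minus a 64-wide rope tail)
    bn: int = 128
    mla_kv_alias: bool = True
    shared_kv: bool = True
    double_buffer_k_only: bool = False


def invariant_violations(cfg: MlaDecodeConfig) -> list:
    """Return a list of violated invariants (empty when the config is valid)."""
    problems = []
    if cfg.mla_kv_alias and not cfg.shared_kv:
        problems.append("mla_kv_alias requires shared_kv")
    if cfg.bn == 128 and cfg.double_buffer_k_only:
        problems.append("BN = 128 requires double_buffer_k_only = False")
    if cfg.output_depth > cfg.depth:
        problems.append("output_depth must be <= depth")
    return problems


def v_row_view(k_tile: memoryview, row: int, cfg: MlaDecodeConfig) -> memoryview:
    """View of the first output_depth elements of one K row; no copy is made."""
    start = row * cfg.depth
    return k_tile[start:start + cfg.output_depth]


cfg = MlaDecodeConfig()
k_storage = array("f", [0.0] * (2 * cfg.depth))  # two K rows, row-major
k_tile = memoryview(k_storage)
v0 = v_row_view(k_tile, 0, cfg)
k_storage[3] = 7.0  # a write to K is visible through the V view
print(invariant_violations(cfg), len(v0), v0[3])
```

Because V is only a window onto K's storage, nothing has to be copied or separately loaded for V, which is the point of the "no V DMA" invariant. The swizzled SMEM layout is not modeled here.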
Recipe (see process_tile):
- Wait K → optional FULL_MASK skip → load K from SMEM → mma_qk → softmax → write P to SMEM → barrier → load V (= aliased K) from SMEM → prefetch next K → update output → mma_pv → iter-end drain.
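
The arithmetic behind the recipe can be sketched for a single query vector. This is a plain-Python model under stated assumptions, not the kernel's process_tile: decode_one_query and reference are hypothetical names, the softmax is assumed to be the usual running-max/running-sum (online) form, and the synchronization steps (wait, barrier, prefetch, drain) appear only as comments. At least one tile is assumed to be unmasked; the kernel's empty-partition handling is not modeled.

```python
import math


def decode_one_query(q, k_tiles, output_depth, fully_masked=None):
    """Hypothetical single-query model of the tile loop (not the kernel's code)."""
    fully_masked = fully_masked or [False] * len(k_tiles)
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * output_depth
    for tile, skip in zip(k_tiles, fully_masked):
        # Wait K: the kernel waits for the tile's load here (no-op in this model).
        if skip:  # optional FULL_MASK skip
            continue
        # Load K from SMEM, then mma_qk: one score per K row.
        scores = [sum(qi * ki for qi, ki in zip(q, row)) for row in tile]
        # Softmax (assumed online form): track a running max and running sum.
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(p)
        running_max = new_max
        # Write P to SMEM, barrier, load V: V is the first output_depth
        # columns of the same K rows (a slice copy here, an alias in the kernel).
        v_rows = [row[:output_depth] for row in tile]
        # Prefetch next K: no-op in this model.
        # Update output: rescale what was accumulated under the old max.
        acc = [a * rescale for a in acc]
        # mma_pv: accumulate P times V.
        for weight, v in zip(p, v_rows):
            for j in range(output_depth):
                acc[j] += weight * v[j]
    return [a / running_sum for a in acc]


def reference(q, rows, output_depth):
    """One-pass softmax attention over all rows, used only for comparison."""
    scores = [sum(qi * ki for qi, ki in zip(q, row)) for row in rows]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    total = sum(w)
    return [sum(wi * row[j] for wi, row in zip(w, rows)) / total
            for j in range(output_depth)]


# Toy sizes for readability; the text's real depth is 576.
depth, out_depth = 6, 4
rows = [[((r * 7 + c * 3) % 5) / 5.0 for c in range(depth)] for r in range(4)]
q = [0.1 * (c + 1) for c in range(depth)]
tiled = decode_one_query(q, [rows[:2], rows[2:]], out_depth)
full = reference(q, rows, out_depth)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(tiled, full)))
```

The scores use all depth columns of K (rope tail included), while the output only accumulates the first output_depth columns, which is why output_depth ≤ depth is required.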
NOTE: process_tile, prefetch_next, the empty-partition handling,
the softmax block, and the epilogue are largely line-for-line shared
with mha_decode.mojo; the two files only diverge in the K==V
aliasing path, the single-buffer-K loop, and the bk_smem<BK K SMEM
split. Bug fixes to the shared logic (softmax numerics, barrier
placement, etc.) must be mirrored to both files until the kernels are
re-converged behind a shared helper.