IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

mla_decode

Gfx950 MLA (Multi-Latent Attention) decode kernel, built on KVBuffer.

Forked from mha_decode.mojo so that the MLA-specific pipeline (K==V SMEM aliasing, the rope tail where output_depth < depth, and the BN×64 SMEM block split for depth=576) can evolve independently of the MHA path.
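The arithmetic behind the depth=576 split is worked through below. This is a minimal Python sketch, not the kernel's code. It assumes two things this page does not state: that "BN×64" means the BN×depth K tile is carved into column blocks 64 wide, and that output_depth = 512 (a DeepSeek-style latent width). Under those assumptions the tile breaks into 9 blocks. The first 8 hold the V-nope columns and the last one is the rope tail. The helper `split_blocks` is hypothetical.

```python
# Hypothetical illustration of the BN x 64 SMEM block split for depth = 576.
# Assumptions (not taken from the kernel source): blocks are 64 columns wide
# and output_depth = 512, so the last depth - output_depth = 64 columns form
# the rope tail that V never reads.

BN = 128
DEPTH = 576
OUTPUT_DEPTH = 512  # assumed latent (V-nope) width
BLOCK_COLS = 64


def split_blocks(depth: int, block_cols: int, output_depth: int):
    """Return (start_col, end_col, role) for each BN x block_cols block."""
    assert depth % block_cols == 0, "depth must be a multiple of the block width"
    assert output_depth % block_cols == 0, "keep blocks from straddling the V/rope boundary"
    blocks = []
    for start in range(0, depth, block_cols):
        end = start + block_cols
        # A block is V-nope if it lies entirely inside [0, output_depth).
        role = "v_nope" if end <= output_depth else "rope_tail"
        blocks.append((start, end, role))
    return blocks


blocks = split_blocks(DEPTH, BLOCK_COLS, OUTPUT_DEPTH)
print(len(blocks))                                  # 9
print(sum(1 for b in blocks if b[2] == "v_nope"))   # 8
print(blocks[-1])                                   # (512, 576, 'rope_tail')
print(BN * BLOCK_COLS)                              # elements per block: 8192
```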

Key invariants (enforced at the mla_decode call site in mla.mojo):

  • mla_kv_alias = True: V reads from K's swizzled SMEM, so no V DMA.
  • shared_kv = True (required by mla_kv_alias): V SMEM is a bitcast view of K SMEM (see attention.mojo constructor).
  • BN = 128 ⇒ double_buffer_k_only = False: single-buffer K loop.
  • output_depth ≤ depth: per-warp V buffer slices the V-nope portion out of K's full-depth SMEM.
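One way to picture the aliasing invariants is to treat a flat, row-major buffer as K's SMEM tile. V is then not a copy. It is a per-row window onto the first output_depth columns, so any write to K is immediately visible through V. The minimal Python sketch below uses `memoryview` as a stand-in for the SMEM bitcast view and ignores the swizzle. The helper names `check_invariants` and `v_row_view` are hypothetical and do not appear in the kernel.

```python
from array import array


def check_invariants(mla_kv_alias, shared_kv, bn, double_buffer_k_only, depth, output_depth):
    # Hypothetical mirror of the key invariants in the module docs.
    if mla_kv_alias:
        assert shared_kv, "mla_kv_alias requires shared_kv"
    if bn == 128:
        assert not double_buffer_k_only, "BN = 128 implies a single-buffer K loop"
    assert output_depth <= depth, "V-nope slice must fit inside K's depth"


def v_row_view(k_smem: memoryview, row: int, depth: int, output_depth: int) -> memoryview:
    # V row `row` aliases the first output_depth elements of K row `row`:
    # slicing a memoryview shares memory instead of copying it.
    start = row * depth
    return k_smem[start : start + output_depth]


check_invariants(True, True, 128, False, 576, 512)

BN, DEPTH, OUTPUT_DEPTH = 4, 6, 4  # toy sizes for the view demo
k_buf = array("f", [float(i) for i in range(BN * DEPTH)])
k_smem = memoryview(k_buf)
v1 = v_row_view(k_smem, 1, DEPTH, OUTPUT_DEPTH)
print(v1.tolist())  # [6.0, 7.0, 8.0, 9.0]
k_buf[6] = 100.0    # a write through K ...
print(v1[0])        # ... is visible through V: 100.0
```

Because V owns no storage of its own in this model, there is nothing to DMA for V. That is the point of mla_kv_alias.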

Recipe (see process_tile):

  • Wait K → optional FULL_MASK skip → load K from SMEM → mma_qk → softmax → write P to SMEM → barrier → load V (= aliased K) from SMEM → prefetch next K → update output → mma_pv → iter-end drain.
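The recipe can be modeled as a scalar, single-query Python sketch, under stated assumptions. Each BN-row tile goes through an online-softmax update: a running max `m`, a running denominator `l`, and a rescale factor `alpha` applied to the accumulator. V is read from the same single SMEM buffer as K. Because there is only one buffer, the next K tile is prefetched only after V has been read out of it.

This is not the kernel's code. The following are all simplifications or hypothetical names:

* the function names `decode_tiles` and `reference`;
* modeling FULL_MASK as "tile starts at or beyond `valid_len`";
* dropping the P-to-SMEM round trip and the barriers.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def decode_tiles(q, keys, output_depth, bn, valid_len):
    """Hypothetical scalar model of the process_tile loop for one query row.

    keys holds full-depth K rows. V is never stored separately: it is the
    first output_depth columns of the K tile held in the single SMEM buffer.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf               # running max of scaled scores
    l = 0.0                     # running softmax denominator
    acc = [0.0] * output_depth  # unnormalized output accumulator

    smem = [list(r) for r in keys[0:bn]]  # initial K DMA into the one buffer
    for start in range(0, len(keys), bn):
        active = start < valid_len        # False models the FULL_MASK skip
        if active:
            k_regs = [list(r) for r in smem]               # load K from SMEM
            s = [scale * dot(q, k) if start + j < valid_len else -math.inf
                 for j, k in enumerate(k_regs)]            # mma_qk + tail mask
            m_new = max(m, max(s))
            p = [math.exp(x - m_new) for x in s]           # masked entries -> 0
            alpha = math.exp(m - m_new)                    # rescale old state
            v_regs = [row[:output_depth] for row in smem]  # load V = aliased K
        # Prefetch overwrites the single buffer in place, so it must come
        # after V was read; reading V later would see the next tile's rows.
        nxt = start + bn
        smem[:] = [list(r) for r in keys[nxt:nxt + bn]]
        if active:
            l = alpha * l + sum(p)                         # update output state
            acc = [alpha * a + sum(pj * v[d] for pj, v in zip(p, v_regs))
                   for d, a in enumerate(acc)]             # mma_pv
            m = m_new
    return [a / l for a in acc]                            # epilogue normalize


def reference(q, keys, output_depth, valid_len):
    # Naive softmax attention over the unmasked rows, with V = K[:, :output_depth].
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * dot(q, k) for k in keys[:valid_len]]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(wj * k[d] for wj, k in zip(w, keys)) / z for d in range(output_depth)]


random.seed(0)
DEPTH, OUTPUT_DEPTH, BN, N = 12, 8, 4, 14
q = [random.uniform(-1, 1) for _ in range(DEPTH)]
keys = [[random.uniform(-1, 1) for _ in range(DEPTH)] for _ in range(N)]
out = decode_tiles(q, keys, OUTPUT_DEPTH, BN, valid_len=10)
ref = reference(q, keys, OUTPUT_DEPTH, valid_len=10)
print(max(abs(a - b) for a, b in zip(out, ref)))  # round-off only
```

With `valid_len=10` and `BN=4`, the tile starting at row 8 is partially masked and the tile starting at row 12 is skipped entirely. The tiled result still matches the naive reference.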

NOTE: process_tile, prefetch_next, the empty-partition handling, the softmax block, and the epilogue are largely line-for-line shared with mha_decode.mojo. The two files diverge only in the K==V aliasing path, the single-buffer K loop, and the K SMEM split used when bk_smem < BK. Bug fixes to the shared logic (softmax numerics, barrier placement, etc.) must be mirrored to both files until the kernels are re-converged behind a shared helper.