
Mojo module

msa

Graph-op bindings for the MiniMax-M3 block-sparse attention (MSA) kernels.

Target: NVIDIA SM100 (B200), BF16, head_dim 128, paged KV with page_size 128 (== block size BN), single index-K head. Two ops:

  • mo.msa.indexer.ragged.paged -> sparse_indexer_{prefill,decode}
  • mo.msa.attention.ragged.paged -> msa_sm100_decode (decode) / msa_sm100_prefill_{plan,run} (prefill)

Each op takes the same arguments for prefill and decode and picks the kernel at runtime from kv_collection.max_seq_length (the max number of new query tokens in the batch): == 1 is a single-token decode step, anything larger is a prefill / context-encoding step. (Unlike the DeepSeek MLA indexer, the MSA prefill and decode paths take the same operands, so they need only one op each rather than separate prefill/decode entry points.)
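The dispatch rule described above can be sketched in a few lines of Python. This is only an illustration of the stated rule: `select_msa_kernel` and the returned labels are hypothetical names, not part of this module's API, and the check for values below 1 is an assumption about valid batches.

```python
def select_msa_kernel(max_seq_length: int) -> str:
    """Pick the kernel family from the max number of new query tokens in the batch.

    Hypothetical helper: mirrors the documented rule only, not the real op code.
    """
    # Assumption: a batch always carries at least one new query token.
    if max_seq_length < 1:
        raise ValueError("max_seq_length must be at least 1")
    # == 1 is a single-token decode step; anything larger is prefill.
    return "decode" if max_seq_length == 1 else "prefill"


print(select_msa_kernel(1))    # decode
print(select_msa_kernel(512))  # prefill
```

Because the decision depends on a single scalar, one op per stage can serve both paths with identical operands.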

The indexer op emits top-k block ids per (index head, token); the attention op consumes those block ids (d_indices) to gather a sparse band of KV blocks from the main paged cache. Both K caches (index-K and main-KV) are BF16 with no scales, so they build with generic_get_paged_cache (NOT the _with_scales variant the MLA FP8 indexer uses).
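As a rough mental model of the indexer-to-attention handoff, the following Python sketch selects top-k block ids from per-block scores and expands them into the KV token positions they cover. It makes several assumptions: how the indexer computes its scores is not described here, so `block_scores` is a hypothetical input; the helper names are invented for illustration; paging is ignored (positions are logical); and the best-first ordering of the result is an assumption, not a documented property of `d_indices`.

```python
BLOCK = 128  # page_size == block size BN, per the target description


def topk_block_ids(block_scores, k):
    """Return the ids of the k highest-scoring blocks, best first.

    Hypothetical stand-in for the indexer output for one (index head, token).
    Ties are broken by the lower block id.
    """
    order = sorted(range(len(block_scores)), key=lambda b: (-block_scores[b], b))
    return order[:k]


def gather_token_positions(block_ids, seq_len):
    """Expand selected block ids into the logical KV positions they cover."""
    positions = []
    for b in block_ids:
        start = b * BLOCK
        # The last block of a sequence may be only partially filled.
        positions.extend(range(start, min(start + BLOCK, seq_len)))
    return positions


scores = [0.1, 0.9, 0.3, 0.7]          # one score per KV block (made-up values)
ids = topk_block_ids(scores, k=2)      # blocks 1 and 3
band = gather_token_positions(ids, seq_len=500)
print(ids, len(band))
```

In the real ops the gather reads from the main paged cache rather than from logical positions, but the block-granular selection is the same idea.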

Modeled on the MLA FP8 indexer registration in attention.mojo (mo.mla.indexer.ragged.float8.paged) for the comptime cache-param extraction and paged-collection build, and on the in-tree MSA tests (Kernels/test/msa/test_msa_sm100_d128_decode_paged.mojo, test_msa_sm100_d128_prefill_device_csr.mojo) for the exact call shapes.

Causal masking is a no-op for seq_len==1 decode: the single query sits at the END of the sequence, so every selected (past) KV position is causal-valid and nothing is masked. The op therefore passes kv_logical_pos=None (in-kernel causal masking OFF) and only carries q_positions for the future spec-decode path. In-kernel kv_logical_pos masking is only meaningful for seq_len>1 spec decode (multiple query tokens, where a query can precede some selected KV).
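The argument that masking is a no-op for single-token decode can be checked with a small Python sketch. `causal_valid` is a hypothetical helper that applies the usual causal rule (a KV position is visible when it does not come after the query); it does not reproduce the in-kernel `kv_logical_pos` logic.

```python
def causal_valid(q_position, kv_positions):
    """True where a KV position is visible to the query under causal masking."""
    return [kv <= q_position for kv in kv_positions]


# Decode (seq_len == 1): the single query sits at the end of a 300-token
# sequence, so every KV position in any selected block is visible.
decode_mask = causal_valid(299, range(300))
print(all(decode_mask))  # True: nothing would be masked

# Spec decode (seq_len > 1): three new query tokens at positions 299..301.
# The first query precedes the last two KV positions of the diagonal block.
print(causal_valid(299, [298, 299, 300, 301]))  # [True, True, False, False]
```

The second case is the partially masked diagonal block that the TODO below refers to.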

TODO(seq_len>1): enable in-kernel kv_logical_pos masking when spec decode (seq_len>1) is supported -- validate the diagonal (partial) block with a logit check before wiring kv_logical_pos through.

Structs