Mojo module
msa
Graph-op bindings for the MiniMax-M3 block-sparse attention (MSA) kernels.
Target: NVIDIA SM100 (B200), BF16, head_dim 128, paged KV with page_size 128 (== block size BN), single index-K head. Two ops:
- mo.msa.indexer.ragged.paged -> sparse_indexer_{prefill,decode}
- mo.msa.attention.ragged.paged -> msa_sm100_decode (decode) / msa_sm100_prefill_{plan,run} (prefill)
Each op takes the same arguments for prefill and decode and picks the kernel at
runtime from kv_collection.max_seq_length (the max number of new query tokens
in the batch): == 1 is a single-token decode step, anything larger is a
prefill / context-encoding step. (Unlike the DeepSeek MLA indexer, the MSA
prefill and decode paths take the same operands, so they need only one op each
rather than separate prefill/decode entry points.)
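The runtime dispatch can be modeled with a minimal Python sketch, for illustration only. select_kernels and the op/kernel string constants are hypothetical stand-ins, not the module's Mojo API. The plan-then-run ordering for prefill is an assumption inferred from the kernel names.

```python
# Hypothetical model of the MSA ops' runtime kernel selection.
# Op and kernel names come from the module docs; the function is illustrative.

INDEXER_OP = "mo.msa.indexer.ragged.paged"
ATTENTION_OP = "mo.msa.attention.ragged.paged"


def select_kernels(op: str, max_seq_length: int) -> list:
    """Return the kernel(s) an MSA op would launch for one batch.

    max_seq_length mirrors kv_collection.max_seq_length: the max number
    of new query tokens in the batch.
    """
    if max_seq_length < 1:
        raise ValueError("max_seq_length must be >= 1")
    is_decode = max_seq_length == 1  # single-token decode step
    if op == INDEXER_OP:
        return ["sparse_indexer_decode" if is_decode else "sparse_indexer_prefill"]
    if op == ATTENTION_OP:
        if is_decode:
            return ["msa_sm100_decode"]
        # Assumed: prefill runs a planning pass, then the main kernel.
        return ["msa_sm100_prefill_plan", "msa_sm100_prefill_run"]
    raise ValueError(f"unknown op: {op}")
```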
The indexer op emits top-k block ids per (index head, token); the attention
op consumes those block ids (d_indices) to gather a sparse band of KV blocks
from the main paged cache. Both caches (the index-K cache and the main KV cache)
are BF16 with no scales, so both collections are built with
generic_get_paged_cache (NOT the _with_scales variant that the MLA FP8 indexer uses).
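The block-id handoff can be made concrete with a Python sketch under stated assumptions. The indexer's per-block scores are given as a plain list, and the block ids are taken to be logical per-sequence block indices. page_table (logical block to physical page), topk_blocks, and gather_kv_slots are hypothetical names, not the kernels' interfaces. Because page_size equals BN (128), each selected block maps to exactly one KV page. Only the final block of a sequence can be partially filled.

```python
import heapq

BLOCK_SIZE = 128  # BN; equal to the KV page_size, so one block == one page


def topk_blocks(block_scores: list, k: int) -> list:
    """Pick the k highest-scoring block ids (ties go to the lower id), ascending."""
    k = min(k, len(block_scores))
    best = heapq.nlargest(k, range(len(block_scores)),
                          key=lambda b: (block_scores[b], -b))
    return sorted(best)


def gather_kv_slots(block_ids: list, page_table: list, seq_len: int) -> list:
    """Map selected logical blocks to (physical_page, offset_in_page) slots.

    Assumes block b covers tokens [b*BLOCK_SIZE, (b+1)*BLOCK_SIZE) and lives
    entirely in page page_table[b]; the last block is clipped to seq_len.
    """
    slots = []
    for b in block_ids:
        start = b * BLOCK_SIZE
        end = min(start + BLOCK_SIZE, seq_len)
        page = page_table[b]
        for pos in range(start, end):
            slots.append((page, pos - start))
    return slots
```

For example, with scores [0.1, 0.9, 0.5, 0.9] and k=2, the sketch selects blocks [1, 3]. It then reads a full page for block 1 and only the 16 valid slots of block 3 when seq_len is 400.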
Modeled on the MLA FP8 indexer registration in attention.mojo
(mo.mla.indexer.ragged.float8.paged) for the comptime cache-param extraction
and paged-collection build, and on the in-tree MSA tests
(Kernels/test/msa/test_msa_sm100_d128_decode_paged.mojo,
test_msa_sm100_d128_prefill_device_csr.mojo) for the exact call shapes.
Causal masking is a no-op for seq_len==1 decode: the single query sits at the
END of the sequence, so every selected (past) KV position is causal-valid and
nothing is masked. The op therefore passes kv_logical_pos=None (in-kernel
causal masking OFF) and only carries q_positions for the future spec-decode
path. In-kernel kv_logical_pos masking is only meaningful for seq_len>1
spec decode (multiple query tokens, where a query can precede some selected KV).
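The causal argument can be checked with a small Python sketch. It assumes a query at position q may attend to KV position p if and only if p <= q. It also reuses the 128-token block layout described above. causal_mask and selected_positions are hypothetical helpers, not part of the module.

```python
BLOCK_SIZE = 128  # BN == KV page_size


def selected_positions(block_ids, seq_len):
    """KV token positions covered by the selected blocks (last block may be short)."""
    positions = []
    for b in block_ids:
        start = b * BLOCK_SIZE
        positions.extend(range(start, min(start + BLOCK_SIZE, seq_len)))
    return positions


def causal_mask(q_positions, kv_positions):
    """mask[i][j] is True when query i may attend to KV position j (kv <= q)."""
    return [[kv <= q for kv in kv_positions] for q in q_positions]


seq_len = 300
kv = selected_positions([0, 2], seq_len)  # positions 0..127 and 256..299

# seq_len==1 decode: the single query is the last token, so nothing is masked.
decode_mask = causal_mask([seq_len - 1], kv)

# Hypothetical 2-token spec-decode step: the earlier query (298) must not
# see KV position 299, which the later query (299) may see.
spec_mask = causal_mask([seq_len - 2, seq_len - 1], kv)
```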
TODO(seq_len>1): enable in-kernel kv_logical_pos masking when spec decode
(seq_len>1) is supported -- validate the diagonal (partial) block with a logit
check before wiring kv_logical_pos through.
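The TODO mentions a diagonal (partial) block. One plausible way to identify it is to classify each selected block against a query position, shown in the hypothetical Python sketch below. It is not the planned kernel logic. Under this sketch, a block is fully visible when its last valid KV position is <= q and fully masked when its first position is > q. Any other block is partial and would need per-element masking. For seq_len==1 decode the query is the last token, so no selected block is ever partial.

```python
from enum import Enum

BLOCK_SIZE = 128  # BN == KV page_size


class BlockVisibility(Enum):
    FULL = "full"        # every KV slot in the block is <= q
    PARTIAL = "partial"  # diagonal block: needs per-element causal masking
    MASKED = "masked"    # every KV slot in the block is > q


def classify_block(block_id, q_pos, seq_len):
    """Classify one selected block relative to a single query position."""
    start = block_id * BLOCK_SIZE
    last = min(start + BLOCK_SIZE, seq_len) - 1  # last valid KV position in block
    if last <= q_pos:
        return BlockVisibility.FULL
    if start > q_pos:
        return BlockVisibility.MASKED
    return BlockVisibility.PARTIAL
```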
Structs