Mojo module
msa_1q
Unified per-token BLOCK-sparse MHA (MSA) prefill + decode kernel for SM100 (B200), BF16 Q/K/V, D=128.
Instead of marching through KV contiguously, each KV tile issues a bulk TMA
load of the 128-token block named by d_indices (topk counts blocks, not
tokens). One block id maps to BN contiguous tokens starting at block_id * BN,
loaded via the same populate + tma_copy_k/tma_copy_v.
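The block-id-to-token mapping described above can be sketched on the host side in Python. This is only an illustration of the index arithmetic, not the kernel's code: block_token_ranges and the sample d_indices_row are hypothetical names and values, and BN = 128 is taken from the 128-token block size stated in the description.

```python
BN = 128  # tokens per KV block, as stated in the module description


def block_token_ranges(block_ids, bn=BN):
    """Map each selected block id to the half-open token range [start, end)."""
    return [(b * bn, (b + 1) * bn) for b in block_ids]


# Hypothetical top-k selection for one query token: three BLOCKS, not three tokens.
d_indices_row = [5, 0, 2]
ranges = block_token_ranges(d_indices_row)
for block_id, (start, end) in zip(d_indices_row, ranges):
    print(f"block {block_id}: tokens [{start}, {end})")
```

Each entry selects a whole run of BN contiguous tokens, so a topk of 3 here covers 384 tokens.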
One CTA handles one (query token, kv-head) pair. Decode (msa_sm100_dispatch)
is the seqlen-1 case, with one CTA per (batch, kv-head). Prefill
(msa_sm100_prefill_dispatch) enumerates the B*S query tokens through the same
kernel body, one per CTA, with the token axis on grid.x so that a long prompt
is not limited by the 65535 grid.z cap. Split-K is enabled for decode and
unsupported for prefill. Limited to a single GQA group, a fixed number of
topk blocks, and precomputed block indices.
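The reason for putting the token axis on grid.x can be illustrated with a small Python sketch of the launch-shape arithmetic. The grid layout used here (tokens on x, kv-heads on y, nothing on z) is an assumption made for illustration; the source only states that the token axis is on grid.x. The sketch also ignores the split-K axis that decode uses. The limits are the standard CUDA grid dimension caps.

```python
MAX_GRID_X = 2**31 - 1  # CUDA cap on gridDim.x
MAX_GRID_YZ = 65535     # CUDA cap on gridDim.y and gridDim.z


def prefill_grid(batch, seq_len, num_kv_heads):
    """Hypothetical launch shape: one CTA per (query token, kv-head)."""
    num_tokens = batch * seq_len  # B*S query tokens, one per CTA
    grid = (num_tokens, num_kv_heads, 1)  # assumed layout, tokens on x
    assert grid[0] <= MAX_GRID_X and grid[1] <= MAX_GRID_YZ
    return grid


def decode_grid(batch, num_kv_heads):
    """Decode as the seqlen-1 case of the same enumeration (split-K ignored)."""
    return prefill_grid(batch, 1, num_kv_heads)


long_prompt = prefill_grid(batch=4, seq_len=32768, num_kv_heads=8)
print(long_prompt, long_prompt[0] > MAX_GRID_YZ)
```

With 4 * 32768 = 131072 query tokens, the token count is larger than 65535, so it would not fit on grid.z but fits easily on grid.x.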
KV is either flat (page_size == 0) or whole-block paged (page_size == BN; the
page table resolves block_id * BN). A block id of -1 marks an unselected
block, which is skipped: mask_unselected redirects its load to block 0 (so no
OOB page lookup occurs) and the softmax poisons its columns.
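A minimal Python sketch of the two mechanisms in this paragraph follows. It rests on several assumptions: resolve_block and masked_softmax are hypothetical helpers, the page table is modeled as a list mapping a logical block id to a physical page index, "poisoning" is interpreted as setting the scores of unselected columns to negative infinity before the softmax, and at least one block is assumed to be selected. BN is shrunk to 4 to keep the example readable.

```python
import math

BN = 4  # tiny block size for illustration only; the kernel uses 128


def resolve_block(block_id, page_table=None):
    """Return (token offset to load from, whether the block is selected)."""
    selected = block_id >= 0
    safe_id = block_id if selected else 0  # redirect -1 to block 0, no OOB lookup
    if page_table is not None:  # whole-block paging: one page per block
        safe_id = page_table[safe_id]
    return safe_id * BN, selected


def masked_softmax(scores, selected_per_block):
    """Softmax over score columns, with unselected blocks forced to zero weight."""
    masked = [
        s if selected_per_block[i // BN] else -math.inf
        for i, s in enumerate(scores)
    ]
    m = max(masked)  # assumes at least one selected block
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


offset, selected = resolve_block(-1, page_table=[7, 3])
weights = masked_softmax([0.0] * (2 * BN), [True, False])
print(offset, selected, weights)
```

The unselected block is still loaded from a valid address, but its columns end up with zero weight and so do not contribute to the output.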
comptime values
logger
comptime logger = Logger(stdout, prefix=String(""), source_location=False)
Functions
- msa_sm100_dispatch: Dispatch entry for the SM100 block-sparse MHA decode kernel.
- msa_sm100_prefill_dispatch: Per-token sparse MHA prefill for SM100.