IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

msa_1q

Unified per-token BLOCK-sparse MHA (MSA) prefill + decode kernel for SM100 (B200), with BF16 Q/K/V and head dimension D = 128.

Instead of marching contiguously through the sequence, each KV tile bulk-TMA-loads the 128-token block named by d_indices (topk counts blocks, not tokens). One block id maps to BN contiguous tokens starting at block_id * BN, loaded through the same populate + tma_copy_k/tma_copy_v path.
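The block-id-to-token mapping can be sketched on the host in plain Python. This is not the Mojo kernel: `block_token_range` and `gather_selected_tokens` are hypothetical helpers that only illustrate which token positions a tile would load for each entry of d_indices, assuming BN = 128.

```python
from typing import List, Tuple

BN = 128  # tokens per KV block (assumed to equal the 128-token tile size)


def block_token_range(block_id: int, bn: int = BN) -> Tuple[int, int]:
    """Half-open token range [start, end) covered by one KV block."""
    start = block_id * bn
    return start, start + bn


def gather_selected_tokens(d_indices: List[int], bn: int = BN) -> List[int]:
    """Token positions loaded for the selected blocks, in d_indices order.

    topk counts blocks, so k selected blocks load k * bn tokens, and each
    block is a contiguous run of bn tokens even though the blocks
    themselves need not be adjacent or sorted.
    """
    tokens: List[int] = []
    for block_id in d_indices:
        start, end = block_token_range(block_id, bn)
        tokens.extend(range(start, end))
    return tokens
```

For example, `gather_selected_tokens([3, 0], bn=4)` yields `[12, 13, 14, 15, 0, 1, 2, 3]`: each block is one contiguous load, which is what lets a single bulk TMA copy fetch it.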

One CTA handles one (query token, kv-head) pair. Decode (msa_sm100_dispatch) is the seqlen-1 case, with one CTA per (batch, kv-head). Prefill (msa_sm100_prefill_dispatch) enumerates all B*S query tokens through the same kernel body, one token per CTA, and places the token axis on grid.x so that a long prompt is not limited by the 65535 cap on grid.z. Split-K is enabled for decode and unsupported for prefill. The kernel assumes a single GQA group, a fixed number of top-k blocks, and precomputed block indices.
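A minimal Python sketch of the flattened-token launch layout follows. The source only states that the token axis lives on grid.x; putting the kv-head on grid.y is an assumption made here for illustration, and `prefill_grid` / `cta_to_work` are hypothetical names. The limits are the standard CUDA ones (gridDim.x up to 2^31 - 1, gridDim.y and gridDim.z up to 65535).

```python
GRID_X_MAX = 2**31 - 1  # CUDA limit for gridDim.x
GRID_YZ_MAX = 65535     # CUDA limit for gridDim.y and gridDim.z


def prefill_grid(batch: int, seq_len: int, num_kv_heads: int):
    """(grid_x, grid_y, grid_z): one CTA per (query token, kv-head).

    All batch * seq_len query tokens are flattened onto grid.x; the kv-head
    axis on grid.y is an assumption of this sketch.
    """
    num_tokens = batch * seq_len
    if num_tokens > GRID_X_MAX:
        raise ValueError("too many query tokens for grid.x")
    if num_kv_heads > GRID_YZ_MAX:
        raise ValueError("too many kv heads for grid.y")
    return (num_tokens, num_kv_heads, 1)


def cta_to_work(block_x: int, block_y: int, seq_len: int):
    """Recover (batch index, token position, kv head) from a CTA's block index."""
    b, s = divmod(block_x, seq_len)
    return b, s, block_y
```

With `seq_len = 1` the same mapping degenerates to one CTA per (batch, kv-head), mirroring how decode is described as the seqlen-1 case. A prompt such as `prefill_grid(2, 100_000, 8)` produces 200,000 CTAs on grid.x, a count that would not fit under the 65535 limit on grid.z.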

KV is either flat (page_size == 0) or paged in whole blocks (page_size == BN, so the page table resolves block_id * BN). An unselected block (index -1) is skipped: mask_unselected redirects its load to block 0, which avoids an out-of-bounds page lookup, and the softmax poisons (masks out) its columns.
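The redirect-then-mask behavior can be illustrated with a scalar Python reference for one query. This is a sketch, not the kernel: `resolve_block` and `masked_block_attention` are hypothetical, the page table is assumed to map a logical block index to a physical page index, and "poisoning" is modeled as setting the column's score to -inf so its softmax weight becomes exactly zero.

```python
import math
from typing import List, Optional, Sequence

NEG_INF = float("-inf")


def resolve_block(block_id: int, bn: int, page_size: int,
                  page_table: Optional[Sequence[int]] = None) -> int:
    """First physical KV row for a logical block.

    An unselected block (-1) is redirected to block 0 so the lookup stays
    in bounds; its scores are masked afterwards.
    """
    safe_id = block_id if block_id >= 0 else 0
    if page_size == 0:  # flat KV: rows are addressed directly
        return safe_id * bn
    assert page_size == bn, "only whole-block pages are modeled here"
    # Hypothetical page table: logical block index -> physical page index.
    return page_table[safe_id] * page_size


def masked_block_attention(q: List[float], k_cache, v_cache,
                           d_indices: List[int], bn: int,
                           page_size: int = 0, page_table=None) -> List[float]:
    """Softmax attention of one query over the selected KV blocks."""
    scores, values = [], []
    for block_id in d_indices:
        base = resolve_block(block_id, bn, page_size, page_table)
        for t in range(bn):
            k = k_cache[base + t]
            s = sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))
            # Poison the columns of an unselected block: exp(-inf) == 0.
            scores.append(s if block_id >= 0 else NEG_INF)
            values.append(v_cache[base + t])
    m = max(scores)
    if m == NEG_INF:
        # Illustrative choice: the source does not say how an all-unselected row is handled.
        raise ValueError("no selected blocks")
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]
```

With all keys zero, `d_indices=[1, -1]` and `bn=2`, only rows 2 and 3 contribute, so values `[3.0]` and `[4.0]` average to `[3.5]`. The redirected load of rows 0 and 1 still happens but carries zero weight.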

comptime values

logger

comptime logger = Logger(stdout, prefix=String(""), source_location=False)

Functions