
Mojo module

msa_2q

KV-block-major sparse MHA (MSA) forward kernel for SM100 (B200), BF16, D=128.

This module is the inverse of the query-major msa_1q.mojo: each CTA owns one 128-token KV block (a CSR row) and gathers the queries that selected it. The reverse-CSR structure lists, for each (batch, kv-block) pair, the query tokens that attend to that block; one work-item corresponds to one (head_kv, non-empty CSR row) pair. The CTA loads the block once with a bulk TMA copy, then loops over ceil(q_count/BM) Q-tiles against the resident block, running QK -> softmax -> PV for each tile. Each tile's queries are loaded with gather4, and the results are scattered into O_partial/LSE_partial; the combine step LSE-merges these partials into O. Softmax is computed over a single full tile with no online correction, and the epilogue scatters one result per (query, split_slot).
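The data flow described above can be modelled on the host in a few lines. The following is a minimal Python sketch, not the kernel's actual code: the function names (`build_reverse_csr`, `work_items`, `lse_merge`), the list-based data layout, and the default `BM = 128` are all assumptions made for illustration. It shows how a query-major selection might be inverted into per-KV-block query lists, how work-items and Q-tile counts follow from the CSR rows, and how per-split partials can be merged through their log-sum-exp values.

```python
import math
from collections import defaultdict

BM = 128  # assumed Q-tile height; the real value is a kernel parameter


def build_reverse_csr(selected_blocks, num_kv_blocks):
    """Invert a query-major selection into per-KV-block query lists.

    selected_blocks[q] lists the KV-block ids that query token q attends to.
    Returns (row_offsets, query_ids) in CSR form, one row per KV block.
    """
    rows = defaultdict(list)
    for q, blocks in enumerate(selected_blocks):
        for b in blocks:
            rows[b].append(q)
    row_offsets, query_ids = [0], []
    for b in range(num_kv_blocks):
        query_ids.extend(rows[b])
        row_offsets.append(len(query_ids))
    return row_offsets, query_ids


def work_items(row_offsets, num_heads_kv, bm=BM):
    """Emit one (head_kv, kv_block, num_q_tiles) per non-empty CSR row."""
    items = []
    for h in range(num_heads_kv):
        for b in range(len(row_offsets) - 1):
            q_count = row_offsets[b + 1] - row_offsets[b]
            if q_count > 0:  # empty rows produce no work-item
                items.append((h, b, math.ceil(q_count / bm)))
    return items


def lse_merge(partials):
    """Merge per-split (o_partial, lse_partial) pairs for one query row.

    Each o_partial is already normalized by its own softmax denominator,
    so it is re-weighted by exp(lse_partial - lse_total) before summing.
    """
    m = max(lse for _, lse in partials)  # subtract the max for stability
    weights = [math.exp(lse - m) for _, lse in partials]
    total = sum(weights)
    dim = len(partials[0][0])
    out = [
        sum(w * o[d] for w, (o, _) in zip(weights, partials)) / total
        for d in range(dim)
    ]
    return out, m + math.log(total)
```

Because each KV block is processed independently with a full (non-online) softmax, a query that selected several blocks ends up with several partial rows; a merge of this kind is what makes those partials equivalent to one softmax over all selected keys.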

Scope: flat KV (page_size == 0) or whole-block paged KV (page_size == BN), BF16, and grouped-query attention with qheadperkv in {1, 2, 4, 8, 16} (the group's query heads are packed into the M-tile and share one KV load). Causal masking is applied on the diagonal block using each query's logical position.
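To make the masking and head-packing rules concrete, here is a small Python sketch under stated assumptions. The helper names are hypothetical, `BN = 128` is taken from the 128-token block size mentioned in the summary, and the token-major row ordering in `pack_rows` is a guess at one plausible layout rather than the kernel's documented one.

```python
BN = 128  # KV block size in tokens, per the module summary


def is_diagonal(q_pos, kv_block, bn=BN):
    """True when the query's logical position falls inside this KV block."""
    return q_pos // bn == kv_block


def causal_mask_row(q_pos, kv_block, bn=BN):
    """Per-key visibility for one query row against one KV block.

    Key j of the block sits at logical position kv_block * bn + j and is
    visible only if that position does not exceed the query's position.
    """
    start = kv_block * bn
    return [start + j <= q_pos for j in range(bn)]


def pack_rows(token_ids, qheadperkv):
    """Map M-tile rows to (token, head_in_group) pairs.

    Assumed layout: token-major, with the group's heads adjacent, so every
    row of the tile reuses the same resident KV block.
    """
    return [(t, h) for t in token_ids for h in range(qheadperkv)]
```

In this model only the diagonal block produces a mixed mask; blocks that lie entirely before the query's position are fully visible, which is why the mask can be restricted to that one block.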

comptime values

Gather4QTile

comptime Gather4QTile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int] = TMATensorTile[dtype, Int(2), IndexList(BM, _gather4_box_width[dtype, depth, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[dtype, depth, swizzle_mode](), __list_literal__=NoneType(None))]

Parameters

logger

comptime logger = Logger(stdout, prefix=String(""), source_location=False)

Structs

Functions