Mojo module
msa_2q
KV-block-major sparse MHA (MSA) forward kernel for SM100 (B200), BF16, D=128.
This is the inverse of the query-major msa_1q.mojo: each CTA owns one
128-token KV block (a CSR row) and gathers the queries that selected it. The
reverse CSR lists, per (batch, kv-block), the query tokens attending to that
block; one work item is one (head_kv, non-empty CSR row) pair. The CTA
bulk-TMA-loads the block once, then loops over ceil(q_count/BM) Q-tiles
(QK -> softmax -> PV per tile) against the resident block, gather4-loading each
tile's queries and scattering results into O_partial/LSE_partial; the combine
kernel then LSE-merges them into O. Softmax runs over a single full tile (no
online correction), and the epilogue scatters per (query, split_slot).
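The reverse-CSR and work-item layout described above can be sketched on the host in plain Python. This is only an illustrative sketch, not the module's device CSR builder: the names `build_reverse_csr`, `work_items`, `row_ptr`, and `q_idx` are hypothetical, a single batch is assumed, and the tile count ignores how grouped query heads are packed into the M-tile.

```python
from math import ceil


def build_reverse_csr(selected_blocks, num_kv_blocks):
    """Invert a query -> KV-block selection into KV-block -> query lists (CSR)."""
    # Count how many queries selected each KV block.
    counts = [0] * num_kv_blocks
    for blocks in selected_blocks:
        for b in blocks:
            counts[b] += 1
    # Exclusive prefix sum of the counts gives the row offsets.
    row_ptr = [0] * (num_kv_blocks + 1)
    for b in range(num_kv_blocks):
        row_ptr[b + 1] = row_ptr[b] + counts[b]
    # Fill each row with the query tokens that selected that block.
    q_idx = [0] * row_ptr[-1]
    cursor = row_ptr[:-1].copy()
    for q, blocks in enumerate(selected_blocks):
        for b in blocks:
            q_idx[cursor[b]] = q
            cursor[b] += 1
    return row_ptr, q_idx


def work_items(row_ptr, num_kv_heads, BM):
    """One item per (head_kv, non-empty row), with its Q-tile loop count."""
    items = []
    for head_kv in range(num_kv_heads):
        for b in range(len(row_ptr) - 1):
            q_count = row_ptr[b + 1] - row_ptr[b]
            if q_count > 0:  # empty rows produce no work
                items.append((head_kv, b, ceil(q_count / BM)))
    return items


# Hypothetical selection: query token i attends to the listed KV blocks.
selected = [[0], [0, 1], [1, 3], [0, 3]]
row_ptr, q_idx = build_reverse_csr(selected, num_kv_blocks=4)
```

Here KV block 2 is selected by no query, so its row is empty and it contributes no work item.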
Scope: flat KV (page_size == 0) or whole-block paged KV (page_size == BN),
BF16, grouped-query qheadperkv in {1, 2, 4, 8, 16} (the group's query heads pack
into the M-tile, sharing one KV load). Diagonal-block causal via per-query
logical position.
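As a rough illustration of causal masking on a diagonal block driven by per-query logical positions, the sketch below assumes the standard causal rule (a query at logical position p sees key position k only when k <= p). The function name `diagonal_block_mask` and its arguments are hypothetical; the kernel's actual masking code is not shown on this page.

```python
def diagonal_block_mask(q_positions, kv_block_start, BN=128):
    """Return mask[i][j], True when gathered query i may attend to key j.

    q_positions holds the logical position of each gathered query token;
    key j of the block sits at logical position kv_block_start + j.
    Assumes the standard causal rule: key position <= query position.
    """
    return [
        [kv_block_start + j <= pos for j in range(BN)]
        for pos in q_positions
    ]


# Tiny example with BN=4: the block covers key positions 4..7.
mask = diagonal_block_mask([5, 6, 9], kv_block_start=4, BN=4)
```

Because the queries in a tile are gathered rather than contiguous, each row carries its own position; a query past the end of the block (position 9 above) sees the whole block.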
comptime values
Gather4QTile
comptime Gather4QTile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int] = TMATensorTile[dtype, Int(2), IndexList(BM, _gather4_box_width[dtype, depth, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[dtype, depth, swizzle_mode](), __list_literal__=NoneType(None))]
Parameters
- dtype (DType)
- swizzle_mode (TensorMapSwizzle)
- BM (Int)
- depth (Int)
logger
comptime logger = Logger(stdout, prefix=String(""), source_location=False)
Structs
- PrebuiltSchedule: The reusable scheduler overlay, already on the device.
Functions
- msa_sm100_block_major_dispatch: Dispatch for the block-major MSA forward. Grid is (grid_work, 1, 1), a host-known capacity >= the on-device work_count, so no host readback / sync is needed. CTAs past work_count_ptr[0] are idle and PDL-early-exit. Produces O_partial/LSE_partial; PDL (OVERLAP_AT_END) chains this fwd with the combine kernel that merges them into O.
- msa_sm100_prefill_b_device_csr_dispatch: End-to-end sparse MHA prefill with the DEVICE CSR builder.
- msa_sm100_prefill_b_dispatch: End-to-end KV-block-major sparse MHA prefill for SM100.
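The merge of O_partial/LSE_partial into O can be illustrated with the usual log-sum-exp combination of partial softmax results. This is a minimal sketch under assumptions, not the combine kernel itself: `attend` and `lse_merge` are hypothetical names, natural-log LSE values are assumed (a kernel may well use base 2), each partial is assumed to be normalized within its own block, and empty split slots are not handled.

```python
import math


def attend(scores, values):
    """Softmax-weighted sum over one block; returns (output, lse)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    depth = len(values[0])
    out = [sum(wi * v[d] for wi, v in zip(w, values)) / total
           for d in range(depth)]
    return out, m + math.log(total)


def lse_merge(o_partials, lse_partials):
    """Merge per-split partial outputs using their log-sum-exp values."""
    m = max(lse_partials)
    # Each partial is weighted by its share of the global softmax mass.
    weights = [math.exp(lse - m) for lse in lse_partials]
    total = sum(weights)
    depth = len(o_partials[0])
    out = [sum(w * o[d] for w, o in zip(weights, o_partials)) / total
           for d in range(depth)]
    return out, m + math.log(total)


# One query attending to two KV blocks, processed independently.
o_a, lse_a = attend([0.5, -1.0], [[1.0, 0.0], [0.0, 1.0]])
o_b, lse_b = attend([2.0], [[3.0, 3.0]])
merged, merged_lse = lse_merge([o_a, o_b], [lse_a, lse_b])

# Reference: softmax over all three keys at once.
full, full_lse = attend([0.5, -1.0, 2.0],
                        [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]])
```

Up to floating-point rounding, the merged result matches the softmax taken over all keys at once, which is why each block can use a single full-tile softmax with no online correction and leave the cross-block rescaling to the combine step.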