Mojo module
msa_combine
MSA combine (LSE-merge) kernel for SM100 (B200).
The block-major forward (msa_2q._msa_sm100_block_major) emits, per selected
(query, kv-block) edge, a split_slot partial: O_partial[split_slot, q, h, :]
(the block-local softmax(scale*Q.Kᵀ)*V) and LSE_partial[split_slot, q, h] (the
block's natural-log log-sum-exp). A query that selected C distinct blocks has
its C partials in slots 0..C-1. This kernel reduces them across slots into
the final O via the standard streaming-softmax merge:
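$$
\begin{aligned}
m &= \max_s \mathrm{LSE}_s \\
w_s &= \exp(\mathrm{LSE}_s - m) \\
Z &= \sum_s w_s \\
O &= \frac{1}{Z} \sum_s w_s \, O_{\text{partial}}[s] \qquad \text{(fp32 accumulation, then BF16)} \\
\mathrm{LSE} &= \ln Z + m
\end{aligned}
$$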
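As a reference for the merge above, here is a minimal Python sketch (not the Mojo kernel) of the math for a single (q, h) row. It uses Python floats and does not model the fp32 accumulation or the BF16 cast; `block_partial` and `merge_partials` are hypothetical helper names. The demo splits one key range into two kv-blocks and checks that merging the two block-local partials reproduces softmax attention over the whole range.

```python
import math


def block_partial(scores, values):
    """Block-local softmax(scores) @ values and the block's natural-log LSE.

    `scores` are assumed to already include the softmax scale.
    """
    mx = max(scores)
    e = [math.exp(x - mx) for x in scores]
    denom = sum(e)
    depth = len(values[0])
    o = [sum(e[j] * values[j][d] for j in range(len(e))) / denom for d in range(depth)]
    return o, math.log(denom) + mx


def merge_partials(lse_parts, o_parts):
    """Streaming-softmax (LSE) merge of per-slot partials for one (q, h) row."""
    m = max(lse_parts)                          # m = max_s LSE_s
    w = [math.exp(l - m) for l in lse_parts]    # w_s = exp(LSE_s - m)
    z = sum(w)                                  # Z = sum_s w_s
    depth = len(o_parts[0])
    o = [sum(w[s] * o_parts[s][d] for s in range(len(w))) / z for d in range(depth)]
    return o, math.log(z) + m                   # LSE = ln(Z) + m


# One query row against 4 keys, split into two kv-blocks of 2 keys each.
scores = [0.5, -1.0, 2.0, 0.3]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

full_o, full_lse = block_partial(scores, values)      # attention over all keys
o0, lse0 = block_partial(scores[:2], values[:2])      # split slot 0
o1, lse1 = block_partial(scores[2:], values[2:])      # split slot 1
merged_o, merged_lse = merge_partials([lse0, lse1], [o0, o1])
```

Because each w_s equals Z_s / e^m (Z_s being block s's unnormalised softmax sum), the merge recovers the unnormalised sums exactly, so `merged_o` and `merged_lse` agree with the full-range result up to rounding.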
C comes from split_counts[b, qloc, h]. In the degenerate cases (C == 0, meaning
no block selected this query, or Z zero or non-finite), the kernel writes O = 0
and LSE = -inf, the numerator-zero convention shared by the forward and the oracle.
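The slot mask and the degenerate convention can be sketched the same way. `combine_row` is a hypothetical Python helper for one (q, h) row: it folds slots s >= count to -inf, as the LSE load is described as doing, and it additionally accumulates O only over s < count so stale data in unused slots (for example NaN, since 0 * NaN is NaN) cannot leak in. That second choice is an assumption of this sketch, not a statement about the kernel.

```python
import math

NEG_INF = float("-inf")


def combine_row(o_partial, lse_partial, count):
    """Merge slots 0..count-1 of one (q, h) row (hypothetical reference helper).

    o_partial: list of per-slot O rows; lse_partial: per-slot natural-log LSEs.
    Returns (O, LSE); degenerate rows give O = 0 and LSE = -inf.
    """
    depth = len(o_partial[0])
    if count == 0:                        # no block selected this query
        return [0.0] * depth, NEG_INF
    # s < count mask: unused slots are folded to -inf, so their weight is 0.
    lse = [lse_partial[s] if s < count else NEG_INF for s in range(len(lse_partial))]
    m = max(lse)
    if not math.isfinite(m):              # e.g. every valid slot holds -inf
        return [0.0] * depth, NEG_INF
    w = [math.exp(l - m) for l in lse]
    z = sum(w)
    if z == 0.0 or not math.isfinite(z):  # defensive: the Z == 0 / non-finite rule
        return [0.0] * depth, NEG_INF
    # Accumulate only valid slots so stale data in unused ones cannot leak in.
    o = [sum(w[s] * o_partial[s][d] for s in range(count)) / z for d in range(depth)]
    return o, math.log(z) + m
```

For example, with two valid slots of equal LSE 0.0 and a third, unused slot full of NaN, the result is the plain average of the first two O rows and LSE = ln 2.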
Row-packed grid: each CTA reduces a tile of tile_m = 64 flattened (q, h) rows
(batch is on grid.z), launching ~64x fewer CTAs than one CTA per (q, h).
O_partial loads are pipelined with cp.async; LSE is a single cp.async load whose
s < count mask folds unused slots to -inf. The split reduction is a warp shuffle
performed in the base-2 domain the forward emits, with a max_valid_split
short-circuit. Because the forward stored each natural depth column at a "fake"
column laid out for STG.128 stores (see real_to_fake), the write-back first does
a fake->real scatter in SMEM and then issues STG.128 stores, so O lands in
natural depth order.
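A warp-shuffle reduction in the base-2 domain can be modelled in Python by simulating the usual XOR butterfly: after five rounds every one of the 32 lanes holds the full max (then the full sum). This sketch assumes one split slot per lane and converts natural-log inputs to base 2 by multiplying by log2(e); the lane layout, the point where the conversion happens, and the helper names `shfl_xor_reduce` and `warp_merge_lse` are assumptions, and the max_valid_split short-circuit is not modelled.

```python
import math

WARP = 32
LOG2E = 1.0 / math.log(2.0)   # multiply a natural-log value by this to get base 2
LN2 = math.log(2.0)


def shfl_xor_reduce(vals, op):
    """Simulate a 32-lane XOR-butterfly reduction (shfl.sync.bfly style).

    In each round lane i combines its value with lane i ^ offset; after
    log2(32) = 5 rounds every lane holds op() over all 32 inputs.
    """
    vals = list(vals)
    offset = WARP // 2
    while offset > 0:
        vals = [op(vals[i], vals[i ^ offset]) for i in range(WARP)]
        offset //= 2
    return vals


def warp_merge_lse(lse_nat, count):
    """Merged natural-log LSE per lane; assumes lane s holds slot s, 1 <= count <= 32."""
    # Base-2 domain: lanes past `count` are masked to -inf (weight 2**-inf == 0).
    x2 = [lse_nat[s] * LOG2E if s < count else float("-inf") for s in range(WARP)]
    m2 = shfl_xor_reduce(x2, max)                       # warp-wide max, base 2
    w = [2.0 ** (x2[i] - m2[i]) for i in range(WARP)]   # exp2 weights
    z = shfl_xor_reduce(w, lambda a, b: a + b)          # warp-wide sum
    # log2(Z) + m2 is the base-2 LSE; multiply by ln 2 to return to natural log.
    return [(math.log2(z[i]) + m2[i]) * LN2 for i in range(WARP)]


lanes = warp_merge_lse([0.5, -1.0, 2.0], count=3)
expected = math.log(math.exp(0.5) + math.exp(-1.0) + math.exp(2.0))
```

Working in base 2 lets a GPU kernel use exp2/log2, which are typically cheaper than exp/log; since 2^(x·log2 e) = e^x, the merged LSE is unchanged.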
Functions
- combine_max_splits
- fake_to_real
- msa_combine_dispatch: Launches the MSA combine: grid(ceil(max_seqlen_q*head_q/64), 1, batch), one CTA per row tile per batch, 256 threads each. Reduces the O_partial/LSE_partial split slots into the final O / LSE. max_splits is the comptime SMEM/register sizing (>= the runtime topk); the default 32 covers all currently dispatched topk values.
- real_to_fake
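The launch shape described for msa_combine_dispatch, and the ~64x CTA reduction claimed for the row-packed grid, can be checked with a little host-side arithmetic. This Python sketch uses only the numbers stated on this page (tile_m = 64, 256 threads, default max_splits = 32); `combine_grid` and `check_topk` are hypothetical helpers, not part of the module's API.

```python
TILE_M = 64              # flattened (q, h) rows reduced by one CTA
THREADS_PER_CTA = 256    # threads per CTA, per the dispatch description
DEFAULT_MAX_SPLITS = 32  # default comptime max_splits


def combine_grid(max_seqlen_q, head_q, batch):
    """(x, y, z) grid as described: (ceil(max_seqlen_q * head_q / 64), 1, batch)."""
    rows = max_seqlen_q * head_q
    return ((rows + TILE_M - 1) // TILE_M, 1, batch)   # integer ceil division


def check_topk(topk, max_splits=DEFAULT_MAX_SPLITS):
    """Hypothetical host-side guard: comptime max_splits must cover runtime topk."""
    if topk > max_splits:
        raise ValueError(f"topk={topk} exceeds max_splits={max_splits}")


gx, gy, gz = combine_grid(max_seqlen_q=1000, head_q=16, batch=2)
packed_ctas = gx * gy * gz                  # 250 * 1 * 2 = 500
one_per_row_ctas = 1000 * 16 * 2            # one CTA per (q, h) per batch = 32000
total_threads = packed_ctas * THREADS_PER_CTA
```

The ratio is exactly 64 whenever max_seqlen_q * head_q is a multiple of 64; otherwise the last tile of each batch is only partially filled.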