IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

msa_combine

MSA combine (LSE-merge) kernel for SM100 (B200).

The block-major forward (msa_2q._msa_sm100_block_major) emits, per selected (query, kv-block) edge, a split_slot partial: O_partial[split_slot, q, h, :] (the block-local softmax(scale*Q.Kᵀ)*V) and LSE_partial[split_slot, q, h] (the block's natural-log log-sum-exp). A query that selected C distinct blocks has its C partials in slots 0..C-1. This kernel reduces them across slots into the final O via the standard streaming-softmax merge:

m = max_s LSE_s; w_s = exp(LSE_s - m); Z = sum_s w_s

O = (1/Z) * sum_s w_s * O_partial[s] (fp32 accum -> BF16)

LSE = ln(Z) + m
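As a sanity check on the merge formulas, here is a minimal pure-Python reference sketch, not the kernel itself. The helper names `block_partial` and `merge_partials`, the toy logits, and the list-of-floats layout are all illustrative assumptions; the sketch ignores BF16 rounding and works on a single (q, h) row.

```python
import math

def block_partial(scores, values):
    """Block-local softmax(scores) applied to values, plus the natural-log LSE.

    scores: already-scaled logits for one block (hypothetical toy input).
    values: one value vector (list of floats) per logit.
    """
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    depth = len(values[0])
    o = [sum(w[j] * values[j][d] for j in range(len(w))) / z for d in range(depth)]
    return o, math.log(z) + m

def merge_partials(o_partials, lse_partials):
    """Merge per-slot partials: m = max LSE, w_s = exp(LSE_s - m), Z = sum w_s."""
    m = max(lse_partials)
    w = [math.exp(lse - m) for lse in lse_partials]
    z = sum(w)
    depth = len(o_partials[0])
    o = [sum(w[s] * o_partials[s][d] for s in range(len(w))) / z for d in range(depth)]
    return o, math.log(z) + m

# Two KV blocks selected by one (query, head) row.
scores_a, values_a = [0.5, -1.0], [[1.0, 2.0], [3.0, 4.0]]
scores_b, values_b = [2.0, 0.0, 1.5], [[-1.0, 0.0], [0.5, 0.5], [2.0, -2.0]]

o_a, lse_a = block_partial(scores_a, values_a)
o_b, lse_b = block_partial(scores_b, values_b)
o_merged, lse_merged = merge_partials([o_a, o_b], [lse_a, lse_b])

# Reference: one softmax over the concatenation of both blocks.
o_full, lse_full = block_partial(scores_a + scores_b, values_a + values_b)
```

The merged output matches the single-pass result up to floating-point rounding because exp(LSE_s) is exactly the block's softmax denominator, so weighting each partial by it restores the un-normalized numerator.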

C comes from split_counts[b, qloc, h]. In the degenerate cases, C == 0 (no block selected for this query) or Z == 0 / non-finite, the kernel writes O = 0 and LSE = -inf (the numerator-zero convention shared by the forward and the oracle).
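The count handling and the degenerate convention can be sketched in pure Python as follows. This is a hedged reference under assumptions, not the kernel's code: `merge_with_count` and its `depth` argument are hypothetical names, and restricting the reduction to the first `count` slots stands in for treating unused slots as LSE = -inf.

```python
import math

NEG_INF = float("-inf")

def merge_with_count(o_partials, lse_partials, count, depth):
    """Merge the first `count` slots of one (q, h) row.

    Slots with index >= count are ignored, which is equivalent to giving
    them LSE = -inf and therefore weight exp(-inf - m) = 0.
    Returns (O, LSE); degenerate rows return (zeros, -inf).
    """
    valid = lse_partials[:count]
    m = max(valid, default=NEG_INF)
    if not math.isfinite(m):
        # C == 0, or every valid slot is -inf / non-finite.
        return [0.0] * depth, NEG_INF
    w = [math.exp(lse - m) for lse in valid]
    z = sum(w)
    if z == 0.0 or not math.isfinite(z):
        return [0.0] * depth, NEG_INF
    o = [sum(w[s] * o_partials[s][d] for s in range(count)) / z for d in range(depth)]
    return o, math.log(z) + m

nan = float("nan")
# Slot 2 is unused garbage and must not influence the result.
o_partials = [[1.0, 2.0], [3.0, 4.0], [nan, nan]]
lse_partials = [0.0, 0.0, 5.0]

o_two, lse_two = merge_with_count(o_partials, lse_partials, count=2, depth=2)
o_none, lse_none = merge_with_count(o_partials, lse_partials, count=0, depth=2)
o_dead, lse_dead = merge_with_count([[1.0, 1.0]], [NEG_INF], count=1, depth=2)
```

Skipping unused slots in the accumulation, rather than multiplying them by a zero weight, matters in this sketch because 0 * NaN is NaN.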

Row-packed grid: one CTA reduces a tile of tile_m = 64 flat (q,h) rows (batch on grid.z), so the launch uses ~64x fewer CTAs than one CTA per (q,h). O_partial loads are cp.async-pipelined; LSE is a single cp.async load, with the s < count mask folding unused slots to -inf. The split reduce is a warp shuffle in the base-2 domain the forward emits, with a max_valid_split short-circuit. The forward stored each natural depth column at a STG.128 fake column (see real_to_fake), so the write-back does a fake->real SMEM scatter followed by STG.128, landing O in natural depth order.
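The CTA-count arithmetic behind the row-packed grid can be illustrated with a small Python sketch. The function names, the choice of grid.x for the tile index, and the head-fastest flattening of (q, h) are assumptions made for illustration; only tile_m = 64 and batch on grid.z come from the description above.

```python
def packed_grid(num_q, num_heads, batch, tile_m=64):
    """Grid shape when each CTA reduces tile_m flat (q, h) rows.

    Assumption: tiles are laid out along grid.x; batch is on grid.z.
    """
    rows = num_q * num_heads
    tiles = (rows + tile_m - 1) // tile_m  # ceil-divide so the tail tile is covered
    return (tiles, 1, batch)

def flat_row_to_qh(row, num_heads):
    """Map a flat row index back to (q, h), assuming h varies fastest."""
    return divmod(row, num_heads)

def rows_in_tile(tile, num_q, num_heads, tile_m=64):
    """Flat rows a tile actually owns; the last tile may be partial."""
    start = tile * tile_m
    stop = min(start + tile_m, num_q * num_heads)
    return range(start, stop)

grid = packed_grid(num_q=1024, num_heads=16, batch=2)
naive_ctas_per_batch = 1024 * 16           # one CTA per (q, h)
ratio = naive_ctas_per_batch // grid[0]    # reduction factor in CTA count

tail_grid = packed_grid(num_q=100, num_heads=3, batch=1)
tail_rows = rows_in_tile(tail_grid[0] - 1, num_q=100, num_heads=3)
```

When the row count is not a multiple of 64, the last tile is only partly filled, so a kernel laid out this way would need to bounds-check rows in that tile.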

Functions