IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

msa_sm100_prefill_b_dispatch

def msa_sm100_prefill_b_dispatch[q_type: DType, KVType: MHAOperand, output_type: DType, //, config: MHAConfig[config.dtype], group: Int](o: DeviceBuffer[output_type], lse: DeviceBuffer[DType.float32], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, q2k: List[Int32], cu_seqlens_q: List[Int32], cu_seqlens_k: List[Int32], topk: Int, scale: Float32, ctx: DeviceContext, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, seqused_k: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, prebuilt: Optional[PrebuiltSchedule] = None)

End-to-end KV-block-major sparse MHA prefill for SM100.

Inverts the query-major selection q2k into a reverse CSR (KV block -> queries) on the host with build_k2q_csr, uploads it to the device, and launches the block-major forward kernel, which writes O_partial and LSE_partial for each (query, split-slot) pair. It then launches the combine kernel (msa_combine), which LSE-merges each query's slots into the final O. The external contract remains query-major: q2k, Q/K/V, cu_seqlens, and scale (plus optional q_positions for in-kernel causal masking) go in, and O comes out.
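The following Python sketch models the two host-visible ideas in the pipeline on plain lists: inverting q2k into a reverse CSR, and LSE-merging per-slot partial outputs. It is a reference model, not the Mojo implementation. It assumes q2k is a flat, row-major [num_q, topk] list of global KV-block ids with negative entries marking unused slots, that the split-slot of a (query, block) pair is its column j in q2k, and that LSE uses the natural log. The names build_k2q_csr_ref, partial_attention, and lse_combine are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def build_k2q_csr_ref(
    q2k: Sequence[int], num_q: int, topk: int, num_kv_blocks: int
) -> Tuple[List[int], List[int], List[int]]:
    """Invert a query-major selection into a KV-block-major reverse CSR.

    Assumption: q2k is a flat, row-major [num_q, topk] list of KV block ids,
    a negative id marks an unused slot, and the split-slot of each
    (query, block) pair is its column j in q2k.
    Entries row_ptr[b]:row_ptr[b + 1] of q_idx/slot list the queries that
    selected block b, in ascending query order.
    """
    # Pass 1: count how many queries selected each KV block.
    counts = [0] * num_kv_blocks
    for qi in range(num_q):
        for j in range(topk):
            b = q2k[qi * topk + j]
            if b >= 0:
                counts[b] += 1
    # Exclusive prefix sum gives the CSR row pointers.
    row_ptr = [0] * (num_kv_blocks + 1)
    for b in range(num_kv_blocks):
        row_ptr[b + 1] = row_ptr[b] + counts[b]
    # Pass 2: scatter (query, slot) pairs into each block's row.
    cursor = list(row_ptr[:-1])
    q_idx = [0] * row_ptr[-1]
    slot = [0] * row_ptr[-1]
    for qi in range(num_q):
        for j in range(topk):
            b = q2k[qi * topk + j]
            if b >= 0:
                q_idx[cursor[b]] = qi
                slot[cursor[b]] = j
                cursor[b] += 1
    return row_ptr, q_idx, slot


def partial_attention(
    scores: Sequence[float], values: Sequence[Sequence[float]]
) -> Tuple[List[float], float]:
    """Softmax-weighted output and natural-log LSE over one key subset."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    denom = sum(p)
    dim = len(values[0])
    o = [sum(p[k] * values[k][i] for k in range(len(p))) / denom for i in range(dim)]
    return o, m + math.log(denom)


def lse_combine(
    o_parts: Sequence[Sequence[float]], lse_parts: Sequence[float]
) -> Tuple[List[float], float]:
    """LSE-merge one query's per-slot partials; -inf marks an empty slot."""
    valid = [s for s, l in enumerate(lse_parts) if l != -math.inf]
    dim = len(o_parts[0])
    if not valid:
        return [0.0] * dim, -math.inf
    # Rescale each slot by exp(lse_s - max) so the merge is numerically stable.
    m = max(lse_parts[s] for s in valid)
    w = {s: math.exp(lse_parts[s] - m) for s in valid}
    total = sum(w.values())
    o = [sum(w[s] * o_parts[s][i] for s in valid) / total for i in range(dim)]
    return o, m + math.log(total)
```

For example, with q2k = [0, 2, 1, 2, 0, -1] (three queries, topk = 2, three blocks), the sketch returns row_ptr = [0, 2, 3, 5], and block 2's row lists query 0 and query 1, both in slot 1. Merging partials computed over disjoint key subsets with lse_combine reproduces attention over the union of those keys, which is why the block-major forward can emit per-slot partials and defer normalization to a separate combine pass.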

The CSR builder caps each work item at a load-balanced q_per_cta, computed by balanced_target_q_per_cta to give roughly num_sms * 2 work items. Each forward CTA then loops over ceil(q_count / BM) Q-tiles against its resident KV block, so no query is dropped. Supported configuration: group == 1, BF16, and non-paged KV only.
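A minimal Python sketch of this work split, for illustration only. It assumes the load-balancing target counts (query, block) pairs and aims for about num_sms * 2 work items; the exact heuristic inside balanced_target_q_per_cta is not documented here, so target_q_per_cta, split_work_items, and cta_q_tiles are hypothetical stand-ins.

```python
import math
from typing import List, Sequence, Tuple

WorkItem = Tuple[int, int, int]  # (kv_block, csr_start, csr_end)


def target_q_per_cta(total_pairs: int, num_sms: int, items_per_sm: int = 2) -> int:
    """Hypothetical heuristic: pick a cap that yields about num_sms * items_per_sm items."""
    target_items = max(1, num_sms * items_per_sm)
    return max(1, math.ceil(total_pairs / target_items))


def split_work_items(row_ptr: Sequence[int], q_per_cta: int) -> List[WorkItem]:
    """Cut each KV block's CSR row into work items of at most q_per_cta queries."""
    items: List[WorkItem] = []
    for b in range(len(row_ptr) - 1):
        start, end = row_ptr[b], row_ptr[b + 1]
        while start < end:  # empty rows produce no work item
            stop = min(start + q_per_cta, end)
            items.append((b, start, stop))
            start = stop
    return items


def cta_q_tiles(item: WorkItem, bm: int) -> List[Tuple[int, int]]:
    """Q-tile ranges one CTA walks for its work item: ceil(q_count / BM) tiles."""
    _, start, end = item
    n_tiles = (end - start + bm - 1) // bm  # ceil(q_count / bm)
    return [(start + t * bm, min(start + (t + 1) * bm, end)) for t in range(n_tiles)]
```

With row_ptr = [0, 5, 5, 12] and num_sms = 2, this heuristic picks a cap of 3 and produces five work items; block boundaries push the count above the four-item target, so the balance is approximate. With BM = 2, every CSR entry falls in exactly one Q-tile, which mirrors the "no query dropped" property: capping only bounds per-CTA work, while the tile loop ensures a cap larger than BM still covers every query.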