Mojo function
msa_sm100_prefill_b_dispatch
def msa_sm100_prefill_b_dispatch[q_type: DType, KVType: MHAOperand, output_type: DType, //, config: MHAConfig[config.dtype], group: Int](o: DeviceBuffer[output_type], lse: DeviceBuffer[DType.float32], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, q2k: List[Int32], cu_seqlens_q: List[Int32], cu_seqlens_k: List[Int32], topk: Int, scale: Float32, ctx: DeviceContext, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, seqused_k: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, prebuilt: Optional[PrebuiltSchedule] = None)
End-to-end KV-block-major sparse MHA prefill for SM100.
Inverts the query-major selection q2k into a reverse-CSR on the host
(build_k2q_csr), uploads it, and launches the block-major forward kernel,
which writes O_partial/LSE_partial per (query, split-slot). The combine
step (msa_combine) then LSE-merges each query's slots into the final O.
The external contract is the query-major one: q2k + Q/K/V + cu_seqlens +
scale (+ optional q_positions for in-kernel causal masking) -> O.
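The LSE merge performed by the combine step can be illustrated with a small host-side reference. This is a minimal Python sketch of the standard log-sum-exp merge, not the msa_combine kernel itself. The helper names `attend` and `lse_merge` are hypothetical. It also assumes that each partial output is already softmax-normalized over its own keys and that an empty slot is marked with an LSE of negative infinity.

```python
import math

def attend(scores, values):
    """Softmax attention for one query over one set of keys.

    Returns (o, lse): the normalized output vector and the
    log-sum-exp of the scores. Hypothetical reference helper.
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    o = [sum(w * v[d] for w, v in zip(weights, values)) / total
         for d in range(dim)]
    return o, m + math.log(total)

def lse_merge(partials):
    """Merge per-slot (o_partial, lse_partial) pairs for one query.

    Each pair covers a disjoint subset of keys. Slots with
    lse == -inf are treated as empty (an assumed convention).
    """
    live = [(o, lse) for o, lse in partials if lse != -math.inf]
    if not live:
        return None, -math.inf
    m = max(lse for _, lse in live)  # subtract the max for stability
    weights = [math.exp(lse - m) for _, lse in live]
    total = sum(weights)
    dim = len(live[0][0])
    o = [sum(w * o_p[d] for w, (o_p, _) in zip(weights, live)) / total
         for d in range(dim)]
    return o, m + math.log(total)
```

Merging the partials from disjoint key subsets should reproduce, up to floating-point error, the result of attending over all selected keys at once. This is why the block-major forward can process each KV block independently.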
The CSR builder caps each work item at a load-balanced q_per_cta
(balanced_target_q_per_cta, targeting roughly num_sms*2 work items). Each
forward CTA loops over ceil(q_count/BM) Q-tiles against its resident KV
block, so no query is dropped.
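The inversion and work-item capping described above can be sketched on the host as follows. This is an illustrative Python reference, not the build_k2q_csr implementation. The function names, the row-major layout `q2k[q * topk + j]`, the skipping of negative entries as unused slots, and the `(kv_block, start, end)` work-item tuple are all assumptions. How q_per_cta is chosen (balanced_target_q_per_cta) is not reproduced, so it is taken as a parameter.

```python
def build_k2q_work_items(q2k, topk, num_kv_blocks, q_per_cta):
    """Invert a query-major selection into a reverse CSR plus work items.

    q2k is assumed flat and row-major: q2k[q * topk + j] is the j-th
    KV block selected by query q. Negative entries are skipped as
    unused slots (an assumed convention).
    """
    num_q = len(q2k) // topk

    # Pass 1: count how many queries selected each KV block.
    counts = [0] * num_kv_blocks
    for b in q2k:
        if b >= 0:
            counts[b] += 1

    # Exclusive prefix sum gives the CSR row offsets.
    row_ptr = [0] * (num_kv_blocks + 1)
    for b in range(num_kv_blocks):
        row_ptr[b + 1] = row_ptr[b] + counts[b]

    # Pass 2: scatter query ids into their KV block's row.
    q_idx = [0] * row_ptr[-1]
    cursor = row_ptr[:-1]  # slicing copies, so row_ptr is untouched
    for q in range(num_q):
        for j in range(topk):
            b = q2k[q * topk + j]
            if b >= 0:
                q_idx[cursor[b]] = q
                cursor[b] += 1

    # Split each row into work items of at most q_per_cta queries.
    items = []
    for b in range(num_kv_blocks):
        for start in range(row_ptr[b], row_ptr[b + 1], q_per_cta):
            end = min(start + q_per_cta, row_ptr[b + 1])
            items.append((b, start, end))
    return row_ptr, q_idx, items

def num_q_tiles(q_count, bm):
    """ceil(q_count / bm): Q-tiles a CTA loops over for one work item."""
    return (q_count + bm - 1) // bm
```

Because each row is split rather than truncated, every (query, KV block) pair appears in exactly one work item. A query selected by a heavily used KV block may therefore land in a later split of that block, but it is never discarded.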
Supported configuration: group == 1, BF16, non-paged KV.