Mojo function

msa_sm100_prefill_b_device_csr_dispatch

def msa_sm100_prefill_b_device_csr_dispatch[q_type: DType, KVType: MHAOperand, output_type: DType, //, config: MHAConfig[config.dtype], group: Int, topk: Int](o: DeviceBuffer[output_type], lse: DeviceBuffer[DType.float32], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, q2k: DeviceBuffer[DType.int32], cu_seqlens_q: List[Int32], cu_seqlens_k: List[Int32], scale: Float32, ctx: DeviceContext, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, seqused_k: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, prebuilt: Optional[PrebuiltSchedule] = None)

End-to-end sparse MHA prefill with the DEVICE CSR builder.

Same external contract as msa_sm100_prefill_b_dispatch, with two differences: the query-major selection q2k is already on the device, and the reverse CSR is built on the device (build_k2q_csr_device) instead of being built on the host and uploaded. The forward and combine steps are byte-for-byte the same as in the host-CSR path, because they consume the same contract tensors. topk is a comptime parameter because the device builder is templated on it.
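The following is a minimal host-side sketch, in Python rather than Mojo, of what inverting a query-major selection into a key-major (reverse) CSR can look like. It is not the implementation of build_k2q_csr_device, and several things are assumptions made for illustration: the function name build_k2q_csr_reference, the output names row_ptr and q_idx, a flat row-major [num_q, topk] layout for q2k, and the treatment of negative entries as unused slots.

```python
def build_k2q_csr_reference(q2k, num_q, topk, num_k):
    """Invert a query-major selection into a key-major CSR.

    Hypothetical reference helper, not part of the documented API.
    q2k is assumed to be a flat row-major [num_q, topk] list of key indices.
    Returns (row_ptr, q_idx): the queries that selected key k are
    q_idx[row_ptr[k]:row_ptr[k + 1]].
    """
    assert len(q2k) == num_q * topk

    # Pass 1: count how many queries selected each key.
    row_ptr = [0] * (num_k + 1)
    for k in q2k:
        if k >= 0:  # assumption: negative entries mark unused slots
            row_ptr[k + 1] += 1

    # Running sum turns per-key counts into row start offsets.
    for k in range(num_k):
        row_ptr[k + 1] += row_ptr[k]

    # Pass 2: scatter each query index into the row of every key it selected.
    cursor = list(row_ptr[:-1])  # next free slot per key
    q_idx = [0] * row_ptr[num_k]
    for q in range(num_q):
        for j in range(topk):
            k = q2k[q * topk + j]
            if k >= 0:
                q_idx[cursor[k]] = q
                cursor[k] += 1
    return row_ptr, q_idx


# 3 queries, topk = 2, 4 keys: query 0 selects keys {0, 2},
# query 1 selects {2, 1}, query 2 selects {0, 1}.
q2k = [0, 2, 2, 1, 0, 1]
row_ptr, q_idx = build_k2q_csr_reference(q2k, num_q=3, topk=2, num_k=4)
print(row_ptr)  # [0, 2, 4, 6, 6]
print(q_idx)    # [0, 2, 1, 2, 0, 1]
```

In this sketch the inner loop bound is topk, fixed per call. That is one plausible reason a device builder would want topk known at compile time, though the source only says that the builder is templated on it.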