
Mojo function

msa_sm100_prefill_dispatch

def msa_sm100_prefill_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, output_type: DType, PartitionType: MHAPartitionScheme, //, config: MHAConfig[config.dtype], group: Int, ragged: Bool, _is_cache_length_accurate: Bool](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, d_indices: UnsafePointer[Int32, MutAnyOrigin], indices_stride: Int, total_q: Int, mask: MaskType, valid_length: DeviceBuffer[DType.uint32], scale: Float32, kv_input_row_offsets: OptionalReg[TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]], partition: PartitionType, ctx: DeviceContext, kv_logical_pos: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, valid_key: OptionalReg[UInt32] = None)

Per-token sparse MHA prefill for SM100.

The kernel launches one CTA per (query token, kv-head) pair: the total_q query tokens are enumerated along grid.x and run through the _msa_sm100 body with StaticInt[1] tiles. Each CTA reads its own index row from d_indices (shape [head_kv, total_q, topk]), keyed on the global query row. Causality is handled in one of two ways: leave both kv_logical_pos and q_positions as None when the indices already encode causality, or pass kv_logical_pos and q_positions to have the kernel apply it.
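The index-row lookup can be illustrated with a small host-side sketch. It assumes d_indices is a contiguous row-major [head_kv, total_q, topk] array; the real kernel also takes indices_stride, whose exact meaning is not stated on this page, so the row stride used here is an assumption. The helper index_row_offset is hypothetical and is not part of the MAX API.

```mojo
# Hypothetical helper, not part of the MAX API.
# Returns the flat offset of the first entry in the index row owned by
# (kv_head, q_row), ASSUMING d_indices is a contiguous row-major
# [head_kv, total_q, topk] array (the row stride is an assumption).
def index_row_offset(kv_head: Int, q_row: Int, total_q: Int, topk: Int) -> Int:
    return (kv_head * total_q + q_row) * topk


def main():
    # Example sizes: 2 kv-heads, 4 query tokens, 8 entries per index row.
    print(index_row_offset(0, 0, 4, 8))  # 0
    print(index_row_offset(0, 3, 4, 8))  # 24
    print(index_row_offset(1, 2, 4, 8))  # 48
```

Under this assumed layout, two CTAs that share a kv-head but handle adjacent query rows read index rows exactly topk entries apart.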

ragged affects only how the host packs its inputs. Because each query row owns its own index row, the kernel maps block_idx.x directly to the global row and never consults cu_seqlens. Placing total_q on grid.x also avoids the 65535 limit on grid.z. The kernel does not use split-K, supports BF16 only, and uses fixed topk-blocks.
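To make the host-packing point concrete, the sketch below flattens a ragged batch of three sequences (lengths 3, 1, and 2, chosen only for illustration) into total_q global rows. The helper global_row and the variable names are hypothetical; the page does not describe how callers actually pack their inputs.

```mojo
# Hypothetical helper, not part of the MAX API.
# Maps a token's position within its sequence to its global query row,
# given the row at which that sequence starts in the packed batch.
def global_row(seq_start: Int, local_pos: Int) -> Int:
    return seq_start + local_pos


def main():
    # Illustrative ragged batch: sequence lengths 3, 1, and 2.
    # Each start is the running sum of the preceding lengths.
    var start0 = 0
    var start1 = start0 + 3
    var start2 = start1 + 1
    var total_q = start2 + 2

    print(total_q)                # 6
    print(global_row(start1, 0))  # 3
    print(global_row(start2, 1))  # 5
```

Once rows are packed this way, the kernel described above needs only the global row (block_idx.x) to find its index row, which is why the sequence offsets stay on the host.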