Mojo function
msa_sm100_prefill_dispatch
def msa_sm100_prefill_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, output_type: DType, PartitionType: MHAPartitionScheme, //, config: MHAConfig[config.dtype], group: Int, ragged: Bool, _is_cache_length_accurate: Bool](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, d_indices: UnsafePointer[Int32, MutAnyOrigin], indices_stride: Int, total_q: Int, mask: MaskType, valid_length: DeviceBuffer[DType.uint32], scale: Float32, kv_input_row_offsets: OptionalReg[TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]], partition: PartitionType, ctx: DeviceContext, kv_logical_pos: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, valid_key: OptionalReg[UInt32] = None)
Per-token sparse MHA prefill for SM100.
Launches one CTA per (query token, kv-head) pair: the total_q query tokens
are enumerated on grid.x, and each CTA runs the _msa_sm100 body with
StaticInt[1] query tiles. Each CTA reads its own index row from d_indices
([head_kv, total_q, topk]), keyed on the global query row. Causality is
either baked into the indices (leave both causal inputs, kv_logical_pos and
q_positions, as None) or applied in-kernel (pass kv_logical_pos and
q_positions).
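The per-CTA index lookup and the two causality modes can be modeled on the host. The sketch below is a pure-Python reference, not the kernel. It assumes d_indices is a row-major [head_kv, total_q, topk] buffer whose rows are indices_stride elements apart, and it assumes the in-kernel causal rule keeps a selected entry only when its kv_logical_pos is at or before the query's q_positions value. Both are hypothetical readings of the parameters, and the helper names `index_row_offset` and `select_keys` are illustrative only.

```python
from typing import List, Optional, Sequence


def index_row_offset(kv_head: int, q_row: int, total_q: int, indices_stride: int) -> int:
    """Flat offset of the index row read by the (q_row, kv_head) CTA.

    Assumption: row-major [head_kv, total_q, topk] layout with rows spaced
    `indices_stride` elements apart.
    """
    return (kv_head * total_q + q_row) * indices_stride


def select_keys(
    d_indices: Sequence[int],
    kv_head: int,
    q_row: int,
    total_q: int,
    topk: int,
    indices_stride: int,
    kv_logical_pos: Optional[Sequence[int]] = None,
    q_positions: Optional[Sequence[int]] = None,
) -> List[int]:
    """Return the KV entries one (query token, kv-head) CTA would attend to."""
    start = index_row_offset(kv_head, q_row, total_q, indices_stride)
    row = list(d_indices[start:start + topk])
    if kv_logical_pos is None and q_positions is None:
        # Causality is baked into the indices: use the row unchanged.
        return row
    if kv_logical_pos is None or q_positions is None:
        # Assumption: in-kernel causality needs both position arrays.
        raise ValueError("pass both kv_logical_pos and q_positions, or neither")
    # In-kernel causality (assumed rule): keep entries at or before the query.
    q_pos = q_positions[q_row]
    return [i for i in row if kv_logical_pos[i] <= q_pos]


if __name__ == "__main__":
    # One kv-head, three query rows, topk = 2, rows packed densely (stride 2).
    d_indices = [0, 1, 0, 2, 1, 3]
    kv_pos = [0, 1, 2, 3]
    q_pos = [1, 1, 2]
    print(select_keys(d_indices, 0, 1, 3, 2, 2))                  # [0, 2]
    print(select_keys(d_indices, 0, 1, 3, 2, 2, kv_pos, q_pos))   # [0]
```

In the second call, entry 2 has logical position 2, which is after the query's position 1, so it is dropped; with both position arrays left as None the row is used exactly as stored.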
ragged is a host-packing concern only: because each query row owns its own
index row, the kernel maps block_idx.x directly to the global row and never
consults cu_seqlens. Since total_q is placed on grid.x, it is not subject to
the 65535 limit on grid.z. No split-K. BF16 only; fixed topk-blocks.
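Because every query row carries its own index row, a ragged batch reduces to concatenating per-sequence rows on the host. The hypothetical helpers below (`pack_ragged`, `cu_seqlens`, `locate_row`) sketch that for a single kv-head and contrast it with the (batch, local position) search that a cu_seqlens-based kernel would need per block. Index values are copied verbatim; whether they must first be rebased to global KV rows is an assumption left to the caller. The grid constants are the CUDA limits: gridDim.x allows up to 2**31 - 1 blocks, while gridDim.y and gridDim.z are capped at 65535.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence, Tuple

GRID_X_MAX = 2**31 - 1  # CUDA limit on gridDim.x
GRID_YZ_MAX = 65535     # CUDA limit on gridDim.y and gridDim.z


def pack_ragged(per_seq_rows: Sequence[Sequence[Sequence[int]]]) -> List[List[int]]:
    """Concatenate each sequence's per-token index rows into total_q global rows."""
    return [list(row) for seq in per_seq_rows for row in seq]


def cu_seqlens(seq_lens: Sequence[int]) -> List[int]:
    """Cumulative sequence offsets, e.g. [3, 2] -> [0, 3, 5]."""
    return [0] + list(accumulate(seq_lens))


def locate_row(cu: Sequence[int], global_row: int) -> Tuple[int, int]:
    """(batch, local_pos) search a cu_seqlens-based kernel would perform.

    The kernel documented here skips this: block_idx.x is the global row.
    """
    batch = bisect_right(cu, global_row) - 1
    return batch, global_row - cu[batch]


if __name__ == "__main__":
    seqs = [
        [[0, 1], [0, 2], [1, 2]],  # sequence 0: 3 tokens, topk = 2
        [[0, 1], [1, 0]],          # sequence 1: 2 tokens
    ]
    packed = pack_ragged(seqs)                 # total_q = 5 rows on grid.x
    cu = cu_seqlens([len(s) for s in seqs])
    print(cu)                                  # [0, 3, 5]
    print(packed[4], locate_row(cu, 4))        # [1, 0] (1, 1)
    # A total_q of 100000 fits on grid.x but would exceed the grid.z cap.
    print(GRID_YZ_MAX < 100_000 <= GRID_X_MAX)  # True
```

Row 4 of the packed buffer is already the second token of sequence 1, so the kernel can read it directly without the search that `locate_row` performs.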