IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

msa_sm100_dispatch

def msa_sm100_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig[config.dtype], group: Int, ragged: Bool, _is_cache_length_accurate: Bool, mask_unselected: Bool = False](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, d_indices: UnsafePointer[Int32, MutAnyOrigin], indices_stride: Int, num_rows_q: Int, mask: MaskType, valid_length: DeviceBuffer[DType.uint32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, kv_logical_pos: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, valid_key: OptionalReg[UInt32] = None)

Dispatch entry for the SM100 block-sparse MHA decode kernel.

The KV tiles loaded by bulk TMA are dense, but each tile's row base is chosen by a block id read from d_indices (indices_stride equals topk, counted in blocks rather than rows). Because KVType is generic over MHAOperand, the same path serves either a flat LayoutTensorMHAOperand or a whole-block paged KVCacheMHAOperand (with page_size == BN, so the page table resolves each block id to its page). Split-K is supported: pass a SplitKPartition, then reduce the partial results with the shared mha_splitk_reduce. Only BF16 is supported, and the number of top-k blocks is fixed.
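The sketch below is a plain-Python reference model of the indexing and reduction described above. It is not the kernel's code. It shows how a block id from d_indices picks a tile's row base for a flat layout and for a paged layout with page_size == BN, how one query row attends over only its top-k blocks, and how split-K partials merge with the standard log-sum-exp rule. Several details are assumptions: the value of BN, the per-query-row layout of d_indices (row r reads d_indices[r*indices_stride : r*indices_stride + topk]), the page-table representation, and the way each split takes a contiguous slice of the selected rows. All helper names (row_base_flat, row_base_paged, selected_rows, partial_attention, merge_splits) are hypothetical.

```python
import math

BN = 4  # hypothetical tile height: rows per KV block (== page_size when paged)


def row_base_flat(block_id: int) -> int:
    # Flat KV layout: block b covers rows [b*BN, (b+1)*BN).
    return block_id * BN


def row_base_paged(page_table: list[int], block_id: int) -> int:
    # Paged KV with page_size == BN: each page holds exactly one block,
    # so the page table maps a logical block id straight to a physical page.
    return page_table[block_id] * BN


def selected_rows(d_indices, indices_stride, q_row, topk, row_base):
    # Assumed layout: query row q_row reads topk block ids starting at
    # q_row * indices_stride. Each block expands to BN dense KV rows.
    start = q_row * indices_stride
    rows = []
    for block_id in d_indices[start:start + topk]:
        base = row_base(block_id)
        rows.extend(range(base, base + BN))
    return rows


def partial_attention(q, K, V, rows, scale):
    # One split: returns (unnormalized output, local max m, local sum l).
    scores = [scale * sum(qi * ki for qi, ki in zip(q, K[r])) for r in rows]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    out = [sum(w * V[r][d] for w, r in zip(weights, rows)) for d in range(len(V[0]))]
    return out, m, l


def merge_splits(partials):
    # Log-sum-exp merge: rescale each split to the global max, then normalize.
    m = max(p[1] for p in partials)
    l = sum(p[2] * math.exp(p[1] - m) for p in partials)
    dim = len(partials[0][0])
    return [sum(p[0][d] * math.exp(p[1] - m) for p in partials) / l for d in range(dim)]
```

For example, with d_indices = [3, 1, 0, 2] and topk = indices_stride = 2, query row 0 reads blocks 3 and 1 and so touches flat rows 12 to 15 and 4 to 7. Splitting those rows across two partial_attention calls and merging them reproduces the single-pass result. That equivalence is what lets a split-K launch be reduced afterwards.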