Mojo function
msa_sm100_dispatch
def msa_sm100_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig[config.dtype], group: Int, ragged: Bool, _is_cache_length_accurate: Bool, mask_unselected: Bool = False](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, d_indices: UnsafePointer[Int32, MutAnyOrigin], indices_stride: Int, num_rows_q: Int, mask: MaskType, valid_length: DeviceBuffer[DType.uint32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, kv_logical_pos: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, valid_key: OptionalReg[UInt32] = None)
Dispatch entry for the SM100 block-sparse MHA decode kernel.
The KV tiles loaded with bulk TMA are dense, but the row base of each tile is
chosen by a block id read from d_indices, where indices_stride equals topk
(counted in blocks). KVType is generic over MHAOperand, so the same path
serves either a flat LayoutTensorMHAOperand or a whole-block paged
KVCacheMHAOperand (page_size == BN, so each block is exactly one page, which
the page table resolves). Split-K is supported: pass a SplitKPartition and
reduce the partial results with the shared mha_splitk_reduce.
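To make the block-id addressing concrete, the following is a minimal reference sketch in plain Python of what one query row computes. It is not the kernel's implementation and not a Mojo API. It assumes d_indices is laid out row-major as [num_rows_q, indices_stride], that a flat KV buffer stores block b in rows b*BN through (b+1)*BN - 1, and that paging is modeled by a hypothetical page_table list mapping a logical block id to a physical page. The helper names kv_row_base and block_sparse_decode are hypothetical. Masking, valid_length, ragged batches, and TMA tiling are omitted.

```python
import math


def kv_row_base(block_id, BN, page_table=None):
    """Hypothetical helper: first KV row of the BN-row tile for a block id.

    Flat KV: block b occupies rows [b*BN, (b+1)*BN).
    Paged KV with page_size == BN: an assumed page_table maps logical
    block b to physical page p, whose rows are [p*BN, (p+1)*BN).
    """
    page = block_id if page_table is None else page_table[block_id]
    return page * BN


def block_sparse_decode(q, K, V, d_indices, indices_stride, row, BN, scale,
                        page_table=None):
    """Hypothetical reference: attention for one query row over its top-k blocks.

    ASSUMPTION: d_indices is row-major [num_rows_q, indices_stride], so
    query row `row` reads block ids d_indices[row*indices_stride + j].
    """
    rows = []
    for j in range(indices_stride):
        block_id = d_indices[row * indices_stride + j]
        base = kv_row_base(block_id, BN, page_table)
        rows.extend(range(base, base + BN))  # each selected tile is dense

    # Ordinary softmax attention, restricted to the gathered rows.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, K[r])) for r in rows]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(V[0])
    return [sum(w * V[r][d] for w, r in zip(weights, rows)) / total
            for d in range(dim)]


if __name__ == "__main__":
    BN = 2
    K = [[float(i), 1.0] for i in range(8)]   # 4 blocks of BN=2 rows
    V = [[float(i), -float(i)] for i in range(8)]
    q = [0.5, 0.25]
    # Two query rows, topk = indices_stride = 2 blocks each.
    d_indices = [2, 0, 1, 3]
    print(block_sparse_decode(q, K, V, d_indices, 2, 0, BN, 1.0))
```

Because page_size == BN, paging in this sketch only changes which physical rows a block id points to, so a flat buffer and a paged buffer holding the same logical blocks give identical results.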
Supports BF16 only, with a fixed top-k block count.
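On the split-K path, each partition attends to a disjoint subset of the selected KV rows and the partial results are merged afterward. The sketch below, again in plain Python, shows the standard log-sum-exp merge: each split keeps its maximum score m, the sum l of exp(score - m), and the unnormalized weighted sum of V rows, and the merge rescales every split to the global maximum. This is a generic illustration of the technique; the partial-result layout and arithmetic inside mha_splitk_reduce may differ, and partial_attention and splitk_merge are hypothetical names.

```python
import math


def partial_attention(q, K, V, rows, scale):
    """Hypothetical helper: unnormalized attention over one split's rows.

    Returns (m, l, acc) where m is the max score, l = sum(exp(s - m)),
    and acc[d] = sum(exp(s - m) * V[row][d]).
    """
    scores = [scale * sum(a * b for a, b in zip(q, K[r])) for r in rows]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    acc = [sum(wi * V[r][d] for wi, r in zip(w, rows)) for d in range(len(V[0]))]
    return m, sum(w), acc


def splitk_merge(partials):
    """Hypothetical reduce: combine (m, l, acc) triples into the final output."""
    m_all = max(m for m, _, _ in partials)
    l_all = 0.0
    out = [0.0] * len(partials[0][2])
    for m, l, acc in partials:
        c = math.exp(m - m_all)  # rescale this split to the global max
        l_all += c * l
        for d, a in enumerate(acc):
            out[d] += c * a
    return [o / l_all for o in out]  # normalize once, after merging


if __name__ == "__main__":
    K = [[float(i), 1.0] for i in range(8)]
    V = [[float(i), -float(i)] for i in range(8)]
    q = [0.5, 0.25]
    splits = [[0, 1, 2], [3, 4, 5, 6, 7]]  # two partitions of the KV rows
    print(splitk_merge([partial_attention(q, K, V, r, 1.0) for r in splits]))
```

Deferring normalization until after the merge means no split needs to know the others' maxima while it runs, which is what lets the partitions proceed independently.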