Mojo struct
MLA_SM100_Decode_Sparse
struct MLA_SM100_Decode_Sparse[q_type: DType, KVLUTType: MHAOperand, output_type: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False, has_attn_sink: Bool = False, has_extra_kv: Bool = False, has_variable_topk: Bool = False]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
AccumType
comptime AccumType = get_accum_type[q_type]()
bf16_bytes_per_element
comptime bf16_bytes_per_element = size_of[DType.bfloat16]()
bf16_type
comptime bf16_type = DType.bfloat16
BlockElems
comptime BlockElems = (config * config)
bytes_per_element
comptime bytes_per_element = size_of[q_type]()
Common_MLA_Op
comptime Common_MLA_Op = MLA_SM100_Decode_Common[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged]
fp8_bytes_per_element
comptime fp8_bytes_per_element = size_of[DType.float8_e4m3fn]()
fp8_type
comptime fp8_type = DType.float8_e4m3fn
gather4_indices_bytes
comptime gather4_indices_bytes = (config * size_of[Int32]())
gather4_num_4row_chunks
comptime gather4_num_4row_chunks = (config // 4)
kv_nope_type
comptime kv_nope_type = KVLUTType.dtype
kv_rope_type
comptime kv_rope_type = DType.bfloat16
kv_type
comptime kv_type = KVLUTType.dtype
KVStageElems
comptime KVStageElems = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].NumQKBlocks * MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].BlockElems)
KVStageTotalBytes
comptime KVStageTotalBytes = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].NopeStageBytes + MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].RopeStageBytes)
load2cvt_num_producer
comptime load2cvt_num_producer = (1 + 32 if (config > 0) else 0)
nope_gather4_box_w
comptime nope_gather4_box_w = _gather4_box_width[DType.int64, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_tile_width, TensorMapSwizzle.SWIZZLE_NONE]()
nope_gather4_num_col_groups
comptime nope_gather4_num_col_groups = ceildiv((config // 8), MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w)
nope_gather4_tile_width
comptime nope_gather4_tile_width = (config // 8)
NopeStageBytes
comptime NopeStageBytes = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].NopeStageElems * MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].fp8_bytes_per_element)
NopeStageElems
comptime NopeStageElems = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].NumNopeBlocks * MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].BlockElems)
NumNopeBlocks
comptime NumNopeBlocks = (config // config)
NumQKBlocks
comptime NumQKBlocks = (config // config)
NumRopeBlocks
comptime NumRopeBlocks = (config // config)
NumVOBlocks
comptime NumVOBlocks = (config // config)
output_tile_width
comptime output_tile_width = ((config // 2) * (4 // size_of[output_type]()))
rope_gather4_box_w
comptime rope_gather4_box_w = _gather4_box_width[DType.bfloat16, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_tile_width, TensorMapSwizzle.SWIZZLE_128B]()
rope_gather4_num_col_groups
comptime rope_gather4_num_col_groups = ceildiv(config.rope_depth, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w)
rope_gather4_tile_width
comptime rope_gather4_tile_width = ((config + (config * 2)) // 2)
RopeStageBytes
comptime RopeStageBytes = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].RopeStageElems * MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].bf16_bytes_per_element)
RopeStageElems
comptime RopeStageElems = (MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].NumRopeBlocks * MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].BlockElems)
UMMAPVSS
comptime UMMAPVSS = DecodeSM100PVSS[q_type, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].AccumType, config=config]
UMMAQKTSS
comptime UMMAQKTSS = DecodeSM100QKTSS[q_type, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].AccumType, config=config]
Methods
kernel
static kernel(q_tma: TMATensorTile[q_type, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, q_type, IndexList(config, config, __list_literal__=Tuple()), config.swizzle_mode]()], k_nope_tma: TMATensorTile[DType.int64, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple())], k_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple())], o_tma: TMATensorTile[output_type, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, output_type, IndexList(config, config, __list_literal__=Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, mla_decode_pack: MLA_Decode_Pack[ValidLengthType, MaskType, SplitAccumType], d_indices: UnsafePointer[Int32, MutAnyOrigin], indices_stride: Int, topk_lengths: UnsafePointer[Int32, MutAnyOrigin], scales_ptr: UnsafePointer[Float32, MutAnyOrigin], attn_sink_ptr: UnsafePointer[Float32, MutAnyOrigin], extra_k_nope_tma: TMATensorTile[DType.int64, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, 
ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple())], extra_k_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple())], extra_kv_lut: KVLUTType, extra_d_indices: UnsafePointer[Int32, MutAnyOrigin], extra_topk_lengths: UnsafePointer[Int32, MutAnyOrigin], extra_indices_stride: Int, extra_scales_ptr: UnsafePointer[Float32, MutAnyOrigin], scalar_args: TileTensor[DType.int64, Layout[*?, *?], MutAnyOrigin])
idx_producer
static idx_producer(idx_bars: DecodeSM100MiscMBars[2, 32, 32], idx_smem_base: UnsafePointer[Int32, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_lut: KVLUTType, d_indices: UnsafePointer[Int32, MutAnyOrigin], topk: Int, scales_ptr: UnsafePointer[Float32, MutAnyOrigin], scale_smem_base: UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k, True, has_extra_kv, has_variable_topk], num_orig_blocks: Int, extra_kv_lut: KVLUTType, extra_d_indices: UnsafePointer[Int32, MutAnyOrigin], extra_topk: Int, extra_scales_ptr: UnsafePointer[Float32, MutAnyOrigin])
Index-transform producer that runs on warp 11 (32 threads).
For each KV tile, it transforms d_indices into TMA rows and, in the blockwise case, loads the FP32 scales into scale SMEM. It signals idx_bars once each tile's data is ready, and it runs one tile ahead of warp 8.
load
static load(q_tma: TMATensorTile[q_type, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, q_type, IndexList(config, config, __list_literal__=Tuple()), config.swizzle_mode]()], k_nope_tma: TMATensorTile[DType.int64, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple())], k_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple())], q_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_nope_smem_fp8: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_rope_smem_bf16_base: UnsafePointer[BFloat16, MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].load2cvt_num_producer, (WARPGROUP_SIZE + 2)], offset_position: 
OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k, True, has_extra_kv, has_variable_topk], idx_bars: DecodeSM100MiscMBars[2, 32, 32], idx_smem_base: UnsafePointer[Int32, MutAnyOrigin, address_space=AddressSpace.SHARED], num_orig_blocks: Int, topk: Int, extra_k_nope_tma: TMATensorTile[DType.int64, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].nope_gather4_box_w, __list_literal__=Tuple())], extra_k_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple()), IndexList(1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].rope_gather4_box_w, __list_literal__=Tuple())], extra_topk: Int)
convertFP8ToBF16
static convertFP8ToBF16(kv_nope_smem_fp8: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem_bf16: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].load2cvt_num_producer, (WARPGROUP_SIZE + 2)], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], num_k_tiles: Int, scale_smem_base: UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED])
mmaQK
static mmaQK(tmem_addr: UInt32, q_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], s_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].load2cvt_num_producer, (WARPGROUP_SIZE + 2)], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k, True, has_extra_kv, has_variable_topk])
mmaPV
static mmaPV(tmem_addr: UInt32, kv_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], p_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, 1], o_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, MLA_SM100_Decode_Sparse[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_attn_sink, has_extra_kv, has_variable_topk].load2cvt_num_producer, (WARPGROUP_SIZE + 2)], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k, True, has_extra_kv, has_variable_topk])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!