Mojo struct
MLA_SM100_Decode_Common
struct MLA_SM100_Decode_Common[q_type: DType, KVLUTType: MHAOperand, output_dtype: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
AccumType
comptime AccumType = get_accum_type[q_type]()
BlockElems
comptime BlockElems = (config * config)
bytes_per_element
comptime bytes_per_element = size_of[q_type]()
kv_type
comptime kv_type = KVLUTType.dtype
KVStageElems
comptime KVStageElems = (Self.NumQKBlocks * Self.BlockElems)
NumQKBlocks
comptime NumQKBlocks = (config // config)
NumVOBlocks
comptime NumVOBlocks = (config // config)
O_M
comptime O_M = (config * 2)
O_N
comptime O_N = (config // 2)
output_tile_width
comptime output_tile_width = ((config // 2) * (4 // size_of[output_dtype]()))
S_M
comptime S_M = (config * 2)
S_N
comptime S_N = (config // 2)
UMMAPVSS
comptime UMMAPVSS = DecodeSM100PVSS[q_type, Self.AccumType, config=config]
UMMAQKTSS
comptime UMMAQKTSS = DecodeSM100QKTSS[q_type, Self.AccumType, config=config]
Methods
pdl_early_exit
static pdl_early_exit(split_idx: Int, batch_idx: Int, max_seq_len: Int, out_row_offset: Int, batch_size: Int, lse_accum_split_ptr: SplitAccumType, o_tma: TMATensorTile[output_dtype, 2, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), _default_desc_shape[2, output_dtype, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), config.swizzle_mode]()])
load_kv
static load_kv(tma: TMATensorTile[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, 3, _padded_shape[3, MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(VariadicList(config.BK1, 1, config.BK0), Tuple()), config.kv_tma_swizzle_mode](), _ragged_shape[3, MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(VariadicList(config.BK1, 1, config.BK0), Tuple()), config.kv_tma_swizzle_mode]()], smem: UnsafePointer[Scalar[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: UInt, row_start: UInt)
load_q
static load_q(tma: TMATensorTile[q_type, 2, IndexList(VariadicList(config.BM, config.BK0), Tuple()), _default_desc_shape[2, q_type, IndexList(VariadicList(config.BM, config.BK0), Tuple()), config.swizzle_mode]()], smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: UInt, row_start: UInt)
apply_mask
static apply_mask[half_load: Int, NonCausalMask: Bool, CausalMask: Bool](tiles_done: Int, col0: Int, num_keys: Int, s_row: TileTensor[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType, Layout[ComptimeInt[half_load], ComptimeInt[1]], MutExternalOrigin, address_space=AddressSpace.LOCAL], mask: MaskType, prompt_idx: UInt32, q_head_idx: UInt32, score_row: UInt32, cache_len: Int, start_pos: UInt32, cache_start_pos: UInt32, kv_start_row: Int = 0) -> Scalar[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType]
Returns: `Scalar[Self.AccumType]` (the rendered page gives no further description of this value).
Softmax
static Softmax[native_fp8: Bool = False, num_sp_stages: Int = 2, fp8_p_stage_stride: Int = 0, has_per_token_scales: Bool = False](tmem_addr: UInt32, s_bars: DecodeSM100MiscMBars[num_sp_stages, 1, WARPGROUP_SIZE], p_bars: DecodeSM100MiscMBars[num_sp_stages, WARPGROUP_SIZE, 1], p_smem_ptr: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], max_smem: UnsafePointer[Scalar[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType], MutAnyOrigin, address_space=AddressSpace.SHARED], li_smem: UnsafePointer[Scalar[MLA_SM100_Decode_Common[q_type, KVLUTType, output_dtype, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType], MutAnyOrigin, address_space=AddressSpace.SHARED], out_smem: UnsafePointer[Scalar[output_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], c_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], corr_done_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, WARPGROUP_SIZE], out_pipeline: OutPipeline[DecodeOutProducer[output_dtype, config].num_out_stages, WARPGROUP_SIZE, 1], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k], scale: Float32, mask: MaskType, prompt_idx: UInt32, lse_accum_split_ptr: SplitAccumType, batch_size: Int, scale_k_smem: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED] = UnsafePointer(), q_scale_ptr: UnsafePointer[Float32, MutAnyOrigin] = UnsafePointer())
Correction
static Correction(tmem_addr: UInt32, o_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], c_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], corr_done_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, WARPGROUP_SIZE], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])
store
static store(out_pipeline: OutPipeline[DecodeOutProducer[output_dtype, config].num_out_stages, WARPGROUP_SIZE, 1], out_smem: UnsafePointer[Scalar[output_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], o_tma: TMATensorTile[output_dtype, 2, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), _default_desc_shape[2, output_dtype, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), config.swizzle_mode]()], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!