Mojo struct: MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware
struct MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type: DType, KVLUTType: MHAOperand, output_type: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False, has_per_token_scales: Bool = False]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
Compile-time (`comptime`) members
AccumType
comptime AccumType = get_accum_type[q_type]()
bf16_bytes_per_element
comptime bf16_bytes_per_element = size_of[DType.bfloat16]()
bf16_type
comptime bf16_type = DType.bfloat16
BlockElems
comptime BlockElems = (config * config) — NOTE(review): both operands render as the bare parameter `config`; the original almost certainly reads `config.<field> * config.<field>` (field accessors lost in extraction). Verify against the upstream Modular docs.
Common_MLA_Op
comptime Common_MLA_Op = MLA_SM100_Decode_Common[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged]
ContentStageBytes
comptime ContentStageBytes = (MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].ContentStageElems * MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].fp8_bytes_per_element)
ContentStageElems
comptime ContentStageElems = (MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].NumContentBlocks * MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].BlockElems)
fp8_bytes_per_element
comptime fp8_bytes_per_element = size_of[DType.float8_e4m3fn]()
fp8_type
comptime fp8_type = DType.float8_e4m3fn
kv_type
comptime kv_type = KVLUTType.dtype
KVStageTotalBytes
comptime KVStageTotalBytes = (MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].ContentStageBytes + MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].RopeStageBytes)
num_kv_producer
comptime num_kv_producer = 1
num_stages
comptime num_stages = config.num_kv_stages
NumContentBlocks
comptime NumContentBlocks = (config // config)
NumRopeBlocks
comptime NumRopeBlocks = (config // config)
NumVOBlocks
comptime NumVOBlocks = (config // config)
NOTE(review): as rendered, each `(config // config)` would evaluate to 1 and all three members would be identical, which cannot be intended — the `config.<field>` accessor names (e.g. a dimension divided by a block size) appear to have been stripped during extraction. Verify against the upstream Modular docs.
output_tile_width
comptime output_tile_width = ((config // 2) * (4 // size_of[output_type]())) — NOTE(review): `config // 2` likely reads `config.<field> // 2` in the original (field accessor lost in extraction); verify against the upstream docs.
PStageElems
comptime PStageElems = MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].BlockElems
RopeStageBytes
comptime RopeStageBytes = (MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].RopeStageElems * MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].bf16_bytes_per_element)
RopeStageElems
comptime RopeStageElems = (MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].NumRopeBlocks * MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].BlockElems)
UMMAPVSS
comptime UMMAPVSS = DecodeSM100PVSS_FP8[DType.float8_e4m3fn, MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].AccumType, config=config]
UMMAQKTSS_Content
comptime UMMAQKTSS_Content = DecodeSM100QKTSS_Content_FP8[DType.float8_e4m3fn, MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].AccumType, config=config]
UMMAQKTSS_Rope
comptime UMMAQKTSS_Rope = DecodeSM100QKTSS_Rope_BF16[DType.bfloat16, MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].AccumType, config=config]
Methods
kernel
static kernel(q_nope_tma: TMATensorTile[DType.float8_e4m3fn, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.float8_e4m3fn, IndexList(config, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B]()], q_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.bfloat16, IndexList(config, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B]()], k_content_tma: TMATensorTile[DType.float8_e4m3fn, 3, _padded_shape[3, DType.float8_e4m3fn, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B](), _ragged_shape[3, DType.float8_e4m3fn, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B]()], k_rope_tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B](), _ragged_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B]()], scale_tma: TMATensorTile[DType.float32, 2, IndexList(1, config, __list_literal__=Tuple())], o_tma: TMATensorTile[output_type, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, output_type, IndexList(config, config, __list_literal__=Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, mla_decode_pack: MLA_Decode_Pack[ValidLengthType, MaskType, SplitAccumType], q_scale_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], scalar_args: TileTensor[DType.int64, Layout[*?, *?], MutAnyOrigin])
load
static load(q_nope_tma: TMATensorTile[DType.float8_e4m3fn, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.float8_e4m3fn, IndexList(config, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B]()], q_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.bfloat16, IndexList(config, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B]()], k_content_tma: TMATensorTile[DType.float8_e4m3fn, 3, _padded_shape[3, DType.float8_e4m3fn, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B](), _ragged_shape[3, DType.float8_e4m3fn, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_64B]()], k_rope_tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B](), _ragged_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), TensorMapSwizzle.SWIZZLE_128B]()], scale_tma: TMATensorTile[DType.float32, 2, IndexList(1, config, __list_literal__=Tuple())], kv_lut: KVLUTType, q_nope_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], q_rope_smem: UnsafePointer[BFloat16, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_content_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_rope_smem: UnsafePointer[BFloat16, MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k], scale_smem_base: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED])
mmaQK
static mmaQK(tmem_addr: UInt32, q_nope_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], q_rope_smem: UnsafePointer[BFloat16, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_content_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_rope_smem: UnsafePointer[BFloat16, MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], s_bars: DecodeSM100MiscMBars[MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].num_stages, 1, WARPGROUP_SIZE], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])
mmaPV
static mmaPV(tmem_addr: UInt32, kv_content_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], p_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], p_bars: DecodeSM100MiscMBars[MLA_SM100_Decode_QKV_FP8_PerTokenScale_RopeAware[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged, has_per_token_scales].num_stages, WARPGROUP_SIZE, 1], o_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!