Skip to main content

Mojo struct

MLA_SM100_Decode_QKV_FP8

@register_passable(trivial) struct MLA_SM100_Decode_QKV_FP8[q_type: DType, KVLUTType: MHAOperand, output_type: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copy_ctor_is_trivial

comptime __copy_ctor_is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__move_ctor_is_trivial

comptime __move_ctor_is_trivial = True

AccumType

comptime AccumType = get_accum_type[q_type]()

bf16_bytes_per_element

comptime bf16_bytes_per_element = size_of[q_type]()

BlockElems

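comptime BlockElems = (config * config)

Note: the rendered expression has lost its field accessors. `config` is an `MLA_SM100_Decode_Config` value, so the two operands are presumably two fields of `config` (which fields is not shown on this page — consult the struct's source).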

Common_MLA_Op

comptime Common_MLA_Op = MLA_SM100_Decode_Common[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged]

fp8_bytes_per_element

comptime fp8_bytes_per_element = size_of[DType.float8_e4m3fn]()

fp8_type

comptime fp8_type = DType.float8_e4m3fn

kv_type

comptime kv_type = KVLUTType.dtype

KVStageElems

comptime KVStageElems = (MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].NumQKBlocks * MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].BlockElems)

num_stages

comptime num_stages = config.num_kv_stages

NumQKBlocks

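comptime NumQKBlocks = (config // config)

Note: the rendered expression has lost its field accessors. `config` is an `MLA_SM100_Decode_Config` value, so the dividend and divisor are presumably two fields of `config` (which fields is not shown on this page — consult the struct's source).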

NumVOBlocks

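comptime NumVOBlocks = (config // config)

Note: the rendered expression has lost its field accessors. `config` is an `MLA_SM100_Decode_Config` value, so the dividend and divisor are presumably two fields of `config` (which fields is not shown on this page — consult the struct's source).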

output_tile_width

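comptime output_tile_width = ((config // 2) * (4 // size_of[output_type]()))

Note: the first operand has lost its field accessor in rendering. `config` is an `MLA_SM100_Decode_Config` value, so `config // 2` is presumably a field of `config` divided by 2 (which field is not shown on this page — consult the struct's source).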

PStageElems

comptime PStageElems = MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].BlockElems

UMMAPVSS

comptime UMMAPVSS = DecodeSM100PVSS_FP8[DType.float8_e4m3fn, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType, config=config]

UMMAQKTSS

comptime UMMAQKTSS = DecodeSM100QKTSS_FP8[DType.float8_e4m3fn, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].AccumType, config=config]

Methods

kernel

static kernel(q_tma: TMATensorTile[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, 2, IndexList(config.BM, config.BK0, Tuple()), _default_desc_shape[2, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BM, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], k_tma: TMATensorTile[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, 3, _padded_shape[3, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode](), _ragged_shape[3, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], o_tma: TMATensorTile[output_type, 2, IndexList(config.out_rows, config.BN, Tuple()), _default_desc_shape[2, output_type, IndexList(config.out_rows, config.BN, Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, mla_decode_pack: MLA_Decode_Pack[ValidLengthType, MaskType, SplitAccumType], scales_ptr: UnsafePointer[Float32, MutAnyOrigin], scalar_args: LayoutTensor[DType.int64, Layout.row_major(4), MutAnyOrigin])

load

static load(q_tma: TMATensorTile[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, 2, IndexList(config.BM, config.BK0, Tuple()), _default_desc_shape[2, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BM, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], k_tma: TMATensorTile[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, 3, _padded_shape[3, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode](), _ragged_shape[3, MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].kv_type, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], kv_lut: KVLUTType, q_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])

mmaQK

static mmaQK(tmem_addr: UInt32, q_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], s_bars: DecodeSM100MiscMBars[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].num_stages, 1, WARPGROUP_SIZE], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])

mmaPV

static mmaPV(tmem_addr: UInt32, kv_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], p_smem: UnsafePointer[Float8_e4m3fn, MutAnyOrigin, address_space=AddressSpace.SHARED], p_bars: DecodeSM100MiscMBars[MLA_SM100_Decode_QKV_FP8[q_type, KVLUTType, output_type, SplitAccumType, MaskType, config, ValidLengthType, _is_cache_length_accurate, ragged].num_stages, WARPGROUP_SIZE, 1], o_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], kv_pipeline: KVPipelineGeneric[config.num_kv_stages, 1, 1, 2], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType, config.decoding_warp_split_k])

Was this page helpful?