Skip to main content

Mojo struct

MLA_SM100_Decode

@register_passable(trivial) struct MLA_SM100_Decode[q_type: DType, KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, config: MLA_SM100_Decode_Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, TrivialRegisterType

Comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

AccumType

comptime AccumType = get_accum_type[q_type]()

BlockElems

comptime BlockElems = (config * config)

bytes_per_element

comptime bytes_per_element = size_of[q_type]()

Common_MLA_Op

comptime Common_MLA_Op = MLA_SM100_Decode_Common[q_type, KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, ragged]

kv_type

comptime kv_type = KVLUTType.dtype

KVStageElems

comptime KVStageElems = (Self.NumQKBlocks * Self.BlockElems)

NumQKBlocks

comptime NumQKBlocks = (config // config)

NumVOBlocks

comptime NumVOBlocks = (config // config)

O_M

comptime O_M = (config * 2)

O_N

comptime O_N = (config // 2)

OTMemTile

comptime OTMemTile = TMemTile[Self.AccumType, Self.O_M, Self.O_N]

output_tile_width

comptime output_tile_width = ((config // 2) * (4 // size_of[output_type]()))

S_M

comptime S_M = (config * 2)

S_N

comptime S_N = (config // 2)

STMemTile

comptime STMemTile = TMemTile[Self.AccumType, Self.S_M, Self.S_N]

UMMAPVSS

comptime UMMAPVSS = DecodeSM100PVSS[q_type, Self.AccumType, config=config]

UMMAQKTSS

comptime UMMAQKTSS = DecodeSM100QKTSS[q_type, Self.AccumType, config=config]

Methods

kernel

static kernel(q_tma: TMATensorTile[q_type, tile_layout_k_major[q_type, config.BM, config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[q_type, 2, IndexList[2, DType.int64](config.BM, config.BK0, Tuple[]()), config.swizzle_mode]()], k_tma: TMATensorTile[Self.kv_type, _split_last_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B, True), _ragged_desc_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B)], o_tma: TMATensorTile[output_type, tile_layout_k_major[output_type, config.out_rows, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[output_type, 2, IndexList[2, DType.int64](config.out_rows, config.BN, Tuple[]()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, batch_size: Int, q_max_seq_len: Int, num_partitions: Int, max_cache_valid_length: Int, mla_decode_pack: MLA_Decode_Pack[ValidLengthType, MaskType, ScoreModType])

load_kv

static load_kv(tma: TMATensorTile[Self.kv_type, _split_last_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B, True), _ragged_desc_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B)], smem: UnsafePointer[Scalar[Self.kv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: Scalar[DType.uint], row_start: Scalar[DType.uint])

load_q

static load_q(tma: TMATensorTile[q_type, tile_layout_k_major[q_type, config.BM, config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[q_type, 2, IndexList[2, DType.int64](config.BM, config.BK0, Tuple[]()), config.swizzle_mode]()], smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: Scalar[DType.uint], row_start: Scalar[DType.uint])

load

static load(q_tma: TMATensorTile[q_type, tile_layout_k_major[q_type, config.BM, config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[q_type, 2, IndexList[2, DType.int64](config.BM, config.BK0, Tuple[]()), config.swizzle_mode]()], k_tma_fp8: TMATensorTile[Self.kv_type, _split_last_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B, True), _ragged_desc_layout[Self.kv_type](IndexList[3, DType.int64](config.BK1, 1, config.BK0, Tuple[]()), TensorMapSwizzle.SWIZZLE_64B)], kv_lut: KVLUTType, q_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem_fp8: UnsafePointer[Scalar[Self.kv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, 1, (WARPGROUP_SIZE + 2)], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType])

convertFP8ToBF16

static convertFP8ToBF16(kv_smem_fp8: UnsafePointer[Scalar[Self.kv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem_bf16: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, 1, (WARPGROUP_SIZE + 2)], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], num_k_tiles: Int)

mmaQK

static mmaQK(tmem_addr: UInt32, q_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], s_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, 1, (WARPGROUP_SIZE + 2)], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType])

mmaPV

static mmaPV(tmem_addr: UInt32, kv_smem: UnsafePointer[Scalar[q_type], MutAnyOrigin, address_space=AddressSpace.SHARED], p_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, 1], o_bars: DecodeSM100MiscMBars[1, 1, WARPGROUP_SIZE], kv_cvt2mma_pipe: KVPipelineGeneric[config.num_kv_stages, 1, WARPGROUP_SIZE, 2], kv_load2cvt_pipe: KVPipelineGeneric[config.num_kv_stages, 1, 1, (WARPGROUP_SIZE + 2)], offset_position: OffsetPosition[config, KVLUTType, ragged, _is_cache_length_accurate, ValidLengthType])

Was this page helpful?