Mojo struct
MLA_SM100_Decode
@register_passable(trivial)
struct MLA_SM100_Decode[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, config: MLA_SM100_Decode_Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, _use_valid_length: Bool = False, ragged: Bool = False]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
AccumType
comptime AccumType = get_accum_type[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type]()
BlockElems
comptime BlockElems = (config * config)
bytes_per_element
comptime bytes_per_element = size_of[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type]()
KVStageElems
comptime KVStageElems = (MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].NumQKBlocks * MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].BlockElems)
NumQKBlocks
comptime NumQKBlocks = (config // config)
NumVOBlocks
comptime NumVOBlocks = (config // config)
O_M
comptime O_M = (config * 2)
O_N
comptime O_N = (config // 2)
OTMemTile
comptime OTMemTile = TMemTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].O_M, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].O_N]
output_tile_width
comptime output_tile_width = ((config // 2) * (4 // size_of[output_type]()))
qkv_type
comptime qkv_type = KVLUTType.dtype
S_M
comptime S_M = (config * 2)
S_N
comptime S_N = (config // 2)
STMemTile
comptime STMemTile = TMemTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].S_M, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].S_N]
UMMAPVSS
comptime UMMAPVSS = DecodeSM100PVSS[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType, config=config]
UMMAQKTSS
comptime UMMAQKTSS = DecodeSM100QKTSS[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType, config=config]
Methods
kernel
static kernel(q_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, tile_layout_k_major[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config.BM, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, 2, IndexList[2, DType.int64](config.BM, config.BN, Tuple[]()), config.swizzle_mode]()], k_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, _split_last_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], o_tma: TMATensorTile[output_type, tile_layout_k_major[output_type, config.out_rows, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[output_type, 2, IndexList[2, DType.int64](config.out_rows, config.BN, Tuple[]()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, batch_size: Int, num_partitions: Int, max_cache_valid_length: Int, mla_decode_pack: MLA_Decode_Pack[ValidLengthType, MaskType, ScoreModType])
load_kv
static load_kv(tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, _split_last_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: UInt, row_start: UInt)
load_q
static load_q(tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, tile_layout_k_major[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config.BM, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, 2, IndexList[2, DType.int64](config.BM, config.BN, Tuple[]()), config.swizzle_mode]()], smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], col_start: UInt, row_start: UInt)
load
static load(q_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, tile_layout_k_major[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config.BM, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, 2, IndexList[2, DType.int64](config.BM, config.BN, Tuple[]()), config.swizzle_mode]()], k_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, _split_last_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], kv_lut: KVLUTType, q_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], mut kv_prod: DecodeKVProducer[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config], offset_position: OffsetPosition[config, KVLUTType, ragged, _use_valid_length, _is_cache_length_accurate, ValidLengthType])
mmaQK
static mmaQK(tmem_addr: UInt32, q_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, tile_layout_k_major[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config.BM, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, 2, IndexList[2, DType.int64](config.BM, config.BN, Tuple[]()), config.swizzle_mode]()], k_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, _split_last_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], kv_lut: KVLUTType, q_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], kv_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mbar_q: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], s_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], mut kv_cons: DecodeKVConsumer[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config], offset_position: OffsetPosition[config, KVLUTType, ragged, _use_valid_length, _is_cache_length_accurate, ValidLengthType])
mmaPV
static mmaPV(tmem_addr: UInt32, k_tma: TMATensorTile[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, _split_last_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], kv_lut: KVLUTType, kv_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], p_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, 1], o_bars: DecodeSM100MiscMBars[1, 1, WARPGROUP_SIZE], mut kv_cons: DecodeKVConsumer[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type, config], offset_position: OffsetPosition[config, KVLUTType, ragged, _use_valid_length, _is_cache_length_accurate, ValidLengthType])
clamped_index_coordinate
static clamped_index_coordinate(var prompt_idx: UInt32, var q_head_idx: UInt32, var q_idx_abs: UInt32, var col: UInt32, var tile_key_base: UInt32, var num_keys: Int, var cache_start_pos: UInt32) -> IndexList[4, element_type=DType.uint32]
Returns: IndexList[4, element_type=DType.uint32]
apply_mask
static apply_mask[half_load: Int, masked: Bool](tiles_done: Int, col0: Int, num_keys: Int, s_row: LayoutTensor[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType, Layout.row_major(half_load), MutAnyOrigin, address_space=AddressSpace.LOCAL], mask: MaskType, score_mod: ScoreModType, prompt_idx: UInt32, q_head_idx: UInt32, score_row: UInt32, max_seq_len: UInt32, start_pos: UInt32, cache_start_pos: UInt32) -> Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType]
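Returns: Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType]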
Softmax
static Softmax(tmem_addr: UInt32, s_bars: DecodeSM100MiscMBars[2, 1, WARPGROUP_SIZE], p_bars: DecodeSM100MiscMBars[2, WARPGROUP_SIZE, 1], kv_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED], max_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType], MutAnyOrigin, address_space=AddressSpace.SHARED], li_smem: UnsafePointer[Scalar[MLA_SM100_Decode[KVLUTType, output_type, MaskType, ScoreModType, config, use_score_mod, ValidLengthType, _is_cache_length_accurate, _use_valid_length, ragged].AccumType], MutAnyOrigin, address_space=AddressSpace.SHARED], c_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], li_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], num_k_tiles: Int, num_keys: Int, scale: Float32, mask: MaskType, score_mod: ScoreModType, prompt_idx: UInt32, max_seq_len: UInt32)
Correction
static Correction(tmem_addr: UInt32, o_bars: DecodeSM100MiscMBars[1, 1, WARPGROUP_SIZE], mut out_prod: DecodeOutProducer[output_type, config], c_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], li_bars: DecodeSM100MiscMBars[1, WARPGROUP_SIZE, WARPGROUP_SIZE], num_k_tiles: Int)
store
static store(mut out_cons: DecodeOutConsumer[output_type, config], o_tma: TMATensorTile[output_type, tile_layout_k_major[output_type, config.out_rows, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[output_type, 2, IndexList[2, DType.int64](config.out_rows, config.BN, Tuple[]()), config.swizzle_mode]()])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!