
Mojo struct

SM100MLA

@register_passable(trivial) struct SM100MLA[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, _ndbuffer_mha_operand: Bool]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

accum_type

comptime accum_type = DType.float32

BM

comptime BM = config.BM

BN

comptime BN = config.BN

cache_depth

comptime cache_depth = 576

cta_group

comptime cta_group = 1

depth

comptime depth = config.depth

group

comptime group = config.group

k_bytes

comptime k_bytes = Self.qkv_dt_size * Self.k_elements

k_elements

comptime k_elements = Self.swizzle_granularity * config.BN

k_rope_depth

comptime k_rope_depth = 64

kv_depth

comptime kv_depth = (config.depth - 64)
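
For orientation, the depth members encode the MLA head split: the 576-element cached row (cache_depth) is the 64-element rope slice (k_rope_depth) plus the latent/value portion (kv_depth). A minimal arithmetic sketch, assuming config.depth is the full 576-element MLA head dimension (an illustrative assumption, not a value stated on this page):

```mojo
def main():
    # Assumed MLA head layout; depth = 576 is an illustrative value.
    var depth = 576        # full K head dimension (latent + rope)
    var k_rope_depth = 64  # rotary slice, matches Self.k_rope_depth
    var kv_depth = depth - k_rope_depth
    print(kv_depth)                  # 512: latent/value portion
    print(kv_depth + k_rope_depth)   # 576: matches cache_depth
```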

KVPipelineType

comptime KVPipelineType = StagedPipeline[config.num_kv_stages, config.num_qk_stages]

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, separate_kv=False]

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = (config.BM // 2)

num_m_mmas

comptime num_m_mmas = 2

num_q_heads

comptime num_q_heads = config.num_q_heads

num_qk_stages

comptime num_qk_stages = config.num_qk_stages

padded_depth

comptime padded_depth = config.padded_depth

page_size

comptime page_size = KVLUTType.page_size

PositionType

comptime PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_dt_size

comptime qkv_dt_size = size_of[Self.qkv_type]()

qkv_type

comptime qkv_type = KVLUTType.dtype

qo_bytes

comptime qo_bytes = Self.qkv_dt_size * Self.qo_elements

qo_elements

comptime qo_elements = Self.padded_depth * Self.MMA_M
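
qo_bytes is simply qkv_dt_size * qo_elements, and qo_elements is the Q/O half-tile footprint padded_depth * MMA_M. A worked example under assumed values (bfloat16 qkv_type, BM = 128 so MMA_M = 64, padded_depth = 576; these numbers are illustrative, not read from a concrete FA4Config):

```mojo
def main():
    # Illustrative parameters only; not taken from a real FA4Config.
    var qkv_dt_size = 2       # bytes per bfloat16 element
    var bm = 128
    var mma_m = bm // 2       # MMA_M = config.BM // 2
    var padded_depth = 576
    var qo_elements = padded_depth * mma_m
    var qo_bytes = qkv_dt_size * qo_elements
    print(qo_elements)  # 36864 elements per Q/O half-tile
    print(qo_bytes)     # 73728 bytes
```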

simd_size

comptime simd_size = simd_width_of[Self.qkv_type]()

swizzle_granularity

comptime swizzle_granularity = (config.swizzle_mode.bytes() // Self.qkv_dt_size)
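
swizzle_granularity is the number of elements spanned by one swizzle period: the swizzle width in bytes divided by the element size. For example, assuming a 128-byte swizzle mode and a 2-byte (bfloat16) qkv_type, both of which are assumptions rather than values read from this page:

```mojo
def main():
    # Assumed: TensorMapSwizzle.SWIZZLE_128B and a bfloat16 KV cache.
    var swizzle_bytes = 128
    var qkv_dt_size = 2
    var swizzle_granularity = swizzle_bytes // qkv_dt_size
    print(swizzle_granularity)  # 64 elements per swizzle chunk
```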

UMMA0Type

comptime UMMA0Type = SM100TensorAccumulatorSS[Self.qkv_type, DType.float32, Self.MMA_M, Self.BN, Self.depth, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=Self.num_qk_stages]

UMMA1Type

comptime UMMA1Type = SM100TensorAccumulatorTS[Self.qkv_type, DType.float32, Self.MMA_M, Self.kv_depth, Self.BN, config.swizzle_mode, transpose_b=False, num_stages=Self.num_qk_stages]

v_bytes_per_mma

comptime v_bytes_per_mma = (Self.qkv_dt_size * 16) * config.padded_depth

Methods

mla_prefill_kernel

static mla_prefill_kernel[KRopeType: MHAOperand](q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.BK0, decoding=False](), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.BK0, decoding=False](), config.swizzle_mode)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode)], k_rope_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, 64, Tuple()), TensorMapSwizzle.SWIZZLE_128B, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, 64, Tuple()), TensorMapSwizzle.SWIZZLE_128B)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode)], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config.BM // 2), Self.kv_depth], kv_lut: KVLUTType, k_rope_lut: KRopeType, scale: Float32, batch_size: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])

correction

static correction(tmem_addr: UInt32, mbars: Self.MiscMBarsType, score_row: UInt32, num_keys: UInt32, mask: MaskType, correction_smem_arg: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED])
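
In flash-attention kernels of this shape, the correction step rescales the output already accumulated for a row when a later KV tile raises that row's running maximum; the factor is exp(m_old - m_new). A minimal scalar sketch of that factor (generic flash-attention math, not this kernel's TMEM path):

```mojo
from math import exp


def main():
    # Running row maximum before and after a new KV tile (hypothetical values).
    var m_old = Float32(3.0)
    var m_new = Float32(5.0)
    # Scale applied to the previously accumulated output row.
    var correction = exp(m_old - m_new)
    print(correction)  # exp(-2) ~= 0.1353
```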

softmax

static softmax(tmem_addr: UInt32, warp_idx: UInt32, mbars: Self.MiscMBarsType, score_row: UInt32, seq_info: SeqInfo, mask: MaskType, num_keys: UInt32, scale: Float32, score_mod: ScoreModType, max_seq_len: UInt32, ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config.BM // 2), Self.kv_depth], o_smem: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], correction_smem_arg: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED])

scale_write_output

static scale_write_output(local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, inv_row_sum: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], o_tmem: TMemTile[DType.float32, (Self.BM // 2), Self.kv_depth], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config.BM // 2), Self.kv_depth], consumer_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)
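
scale_write_output applies inv_row_sum (by its name, the reciprocal of the softmax row sum) to the accumulated float32 output before it is converted to output_type and stored through the ragged TMA tile. A minimal host-side sketch of that final normalization (hypothetical row values; the kernel operates on TMEM/shared-memory tiles rather than a List):

```mojo
def main():
    # One accumulated output row in float32 (hypothetical values).
    var acc_row = List[Float32](2.0, 4.0, 6.0)
    var row_sum = Float32(8.0)            # sum of exp(score - row_max)
    var inv_row_sum = Float32(1.0) / row_sum
    for i in range(len(acc_row)):
        acc_row[i] = acc_row[i] * inv_row_sum  # normalize before the store
    print(acc_row[0], acc_row[1], acc_row[2])  # 0.25 0.5 0.75
```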

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

load

static load[KRopeType: MHAOperand](mbars: Self.MiscMBarsType, score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.BK0, decoding=False](), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.BK0, decoding=False](), config.swizzle_mode)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode)], k_rope_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, 64, Tuple()), TensorMapSwizzle.SWIZZLE_128B, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, 64, Tuple()), TensorMapSwizzle.SWIZZLE_128B)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode, True), _ragged_desc_layout[KVLUTType.dtype](IndexList(config.BN, 1, Self.kv_depth, Tuple()), config.swizzle_mode)], kv_lut: KVLUTType, k_rope_lut: KRopeType, q_smem: UnsafePointer[Scalar[KVLUTType.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

descriptor_q

static descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

mma

static mma(tmem_addr: UInt32, mbars: Self.MiscMBarsType, score_row: UInt32, num_keys: UInt32, mask: MaskType, q_smem: UnsafePointer[Scalar[KVLUTType.dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
