Mojo struct
SM100MLA
@register_passable(trivial)
struct SM100MLA[KVLUTType: MHAOperand, KRopeType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: MLAConfig, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, _ndbuffer_mha_operand: Bool]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copy_ctor_is_trivial
comptime __copy_ctor_is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__move_ctor_is_trivial
comptime __move_ctor_is_trivial = True
accum_type
comptime accum_type = DType.float32
BM
comptime BM = config.BM
BN
comptime BN = config.BN
cache_depth
comptime cache_depth = config.cache_depth
cta_group
comptime cta_group = 1
depth
comptime depth = config.depth
group
comptime group = config.group
k_rope_depth
comptime k_rope_depth = config.k_rope_depth
kv_depth
comptime kv_depth = config.kv_depth
KVPipelineType
comptime KVPipelineType = StagedPipeline[config.num_kv_stages, config.num_qk_stages]
MiscMBarsType
comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, separate_kv=False]
mma_kind
comptime mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type.is_half_float() else UMMAKind.KIND_F8F6F4
MMA_M
comptime MMA_M = (config.BM // 2)
num_m_mmas
comptime num_m_mmas = 2
num_q_heads
comptime num_q_heads = config.num_q_heads
num_qk_stages
comptime num_qk_stages = config.num_qk_stages
padded_depth
comptime padded_depth = config.padded_depth
page_size
comptime page_size = KVLUTType.page_size
PositionType
comptime PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, False]
qkv_dt_size
comptime qkv_dt_size = size_of[SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type]()
qkv_type
comptime qkv_type = KVLUTType.dtype
simd_size
comptime simd_size = simd_width_of[SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type]()
UMMA0Type
comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type, DType.float32, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].depth, mma_kind=SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].mma_kind, swizzle_a=config.qkv_swizzle_mode, swizzle_b=config.qkv_swizzle_mode, num_stages=SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]
UMMA1Type
comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type, DType.float32, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].kv_depth, SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, config.qkv_swizzle_mode, mma_kind=SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].mma_kind, transpose_b=False, num_stages=SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]
Methods
softmax
static softmax(tmem_addr: UInt32, warp_idx: UInt32, mbars: FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, separate_kv=False], score_row: UInt32, seq_info: SeqInfo, mask: MaskType, num_keys: UInt32, scale: Float32, max_seq_len: UInt32, ragged_tma_store: RaggedTMA3DTile[output_type, config.output_swizzle_mode, (config // 2), SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].kv_depth], o_smem: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], correction_smem_arg: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED])
correction
static correction(tmem_addr: UInt32, mbars: FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, separate_kv=False], score_row: UInt32, num_keys: UInt32, mask: MaskType, correction_smem_arg: UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED])
mask_status
static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns:
TileMaskStatus
scale_write_output
static scale_write_output(local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, inv_row_sum: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], o_tmem_arg: TMemTile[DType.float32, (SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BM // 2), SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].kv_depth], ragged_tma_store: RaggedTMA3DTile[output_type, config.output_swizzle_mode, (config // 2), SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].kv_depth], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)
descriptor_q
static descriptor_q(q_smem: UnsafePointer[Scalar[SM100MLA[KVLUTType, KRopeType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns:
MMASmemDescriptorPair
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!