For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
SM100MLA
struct SM100MLA[KVLUTType: MHAOperand, KRopeType: MHAOperand, output_dtype: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: MLAConfig[config.qkv_dtype, rope_gmem_dtype=config.rope_gmem_dtype, rope_mma_dtype=config.rope_mma_dtype, scale_dtype=config.scale_dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, _ndbuffer_mha_operand: Bool]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_dtype
comptime accum_dtype = DType.float32
BK0
comptime BK0 = SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qk_depth if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].fused_umma0 else SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_depth
BM
comptime BM = config.BM
BN
comptime BN = config.BN
cache_depth
comptime cache_depth = config.cache_depth
cta_group
comptime cta_group = 1
fused_umma0
comptime fused_umma0 = not config.fa4_config.use_fused_kv.__bool__() if (SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype == SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype) else (SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype == SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype)
group
comptime group = config.group
MiscMBarsType
comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.fa4_config.num_qk_stages, num_pv_stages=config.fa4_config.num_pv_stages, num_kv_stages=config.fa4_config.num_kv_stages, use_order_barriers=EnableForcedOrdering, use_fused_kv=config.fa4_config.use_fused_kv, pair_cta=config.fa4_config.pair_cta, num_qo=config.fa4_config.num_qo]
MMA_M
comptime MMA_M = config.fa4_config.MMA_M
nope_depth
comptime nope_depth = config.nope_depth
nope_mma_kind
comptime nope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype.is_half_float() else UMMAKind.KIND_F8F6F4
num_pv_stages
comptime num_pv_stages = config.num_pv_stages
num_q_heads
comptime num_q_heads = config.num_q_heads
num_qk_stages
comptime num_qk_stages = config.num_qk_stages
padded_depth
comptime padded_depth = config.padded_qk_depth
page_size
comptime page_size = KVLUTType.page_size
PositionType
comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, False]
q_rope_byte_offset
comptime q_rope_byte_offset = (Int(config.fa4_config.MMA_M * config.fa4_config.padded_ov_depth) * size_of[KVLUTType.dtype]())
qk_depth
comptime qk_depth = config.qk_depth
qkv_dt_size
comptime qkv_dt_size = size_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()
qkv_dtype
comptime qkv_dtype = KVLUTType.dtype
rope_depth
comptime rope_depth = config.rope_depth
rope_gmem_dtype
comptime rope_gmem_dtype = KRopeType.dtype
rope_mma_dtype
comptime rope_mma_dtype = config.rope_mma_dtype
rope_mma_kind
comptime rope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype.is_half_float() else UMMAKind.KIND_F8F6F4
simd_size
comptime simd_size = simd_width_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()
UMMA0RopeType
comptime UMMA0RopeType = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype, DType.float32, config.fa4_config.MMA_M, config.BN, config.rope_depth, a_tmem=False, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_kind, swizzle_a=config.rope_mma_swizzle_mode, swizzle_b=config.rope_mma_swizzle_mode, num_stages=config.num_qk_stages]
UMMA0Type
comptime UMMA0Type = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, config.fa4_config.MMA_M, config.BN, config.qk_depth if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].fused_umma0 else config.nope_depth, a_tmem=False, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, swizzle_a=config.qkv_swizzle_mode, swizzle_b=config.qkv_swizzle_mode, num_stages=config.num_qk_stages]
UMMA1Type
comptime UMMA1Type = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, config.fa4_config.MMA_M, config.nope_depth, config.BN, a_tmem=True, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, swizzle_b=config.qkv_swizzle_mode, transpose_b=False, num_stages=config.num_pv_stages]
Methods
mask_status
static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns: TileMaskStatus
descriptor_q
static def descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns: MMASmemDescriptorPair
descriptor_q_rope
static def descriptor_q_rope(q_smem: UnsafePointer[Scalar[Self.rope_mma_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns: MMASmemDescriptorPair
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!