IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

SM100MLA

struct SM100MLA[KVLUTType: MHAOperand, KRopeType: MHAOperand, output_dtype: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: MLAConfig[config.qkv_dtype, rope_gmem_dtype=config.rope_gmem_dtype, rope_mma_dtype=config.rope_mma_dtype, scale_dtype=config.scale_dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, _ndbuffer_mha_operand: Bool]

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

accum_dtype​

comptime accum_dtype = DType.float32

BK0​

comptime BK0 = SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qk_depth if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].fused_umma0 else SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_depth

BM​

comptime BM = config.BM

BN​

comptime BN = config.BN

cache_depth​

comptime cache_depth = config.cache_depth

cta_group​

comptime cta_group = 1

fused_umma0​

comptime fused_umma0 = not config.fa4_config.use_fused_kv.__bool__() if (SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype == SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype) else (SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype == SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype)

group​

comptime group = config.group

MiscMBarsType​

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.fa4_config.num_qk_stages, num_pv_stages=config.fa4_config.num_pv_stages, num_kv_stages=config.fa4_config.num_kv_stages, use_order_barriers=EnableForcedOrdering, use_fused_kv=config.fa4_config.use_fused_kv, pair_cta=config.fa4_config.pair_cta, num_qo=config.fa4_config.num_qo]

MMA_M​

comptime MMA_M = (config.BM // 2)

nope_depth​

comptime nope_depth = config.nope_depth

nope_mma_kind​

comptime nope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype.is_half_float() else UMMAKind.KIND_F8F6F4

num_m_mmas​

comptime num_m_mmas = 2

num_pv_stages​

comptime num_pv_stages = config.num_pv_stages

num_q_heads​

comptime num_q_heads = config.num_q_heads

num_qk_stages​

comptime num_qk_stages = config.num_qk_stages

padded_depth​

comptime padded_depth = config.padded_qk_depth

page_size​

comptime page_size = KVLUTType.page_size

PositionType​

comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, False]

q_rope_byte_offset​

comptime q_rope_byte_offset = ((SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M * config.fa4_config) * size_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]())

qk_depth​

comptime qk_depth = config.qk_depth

qkv_dt_size​

comptime qkv_dt_size = size_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()

qkv_dtype​

comptime qkv_dtype = KVLUTType.dtype

rope_depth​

comptime rope_depth = config.rope_depth

rope_gmem_dtype​

comptime rope_gmem_dtype = KRopeType.dtype

rope_mma_dtype​

comptime rope_mma_dtype = config.rope_mma_dtype

rope_mma_kind​

comptime rope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype.is_half_float() else UMMAKind.KIND_F8F6F4

simd_size​

comptime simd_size = simd_width_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()

UMMA0RopeType​

comptime UMMA0RopeType = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_depth, a_tmem=False, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_kind, swizzle_a=config.rope_mma_swizzle_mode, swizzle_b=config.rope_mma_swizzle_mode, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]

UMMA0Type​

comptime UMMA0Type = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BK0, a_tmem=False, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, swizzle_a=config.qkv_swizzle_mode, swizzle_b=config.qkv_swizzle_mode, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]

UMMA1Type​

comptime UMMA1Type = SM100TensorAccumulator[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_depth, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, a_tmem=True, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, swizzle_b=config.qkv_swizzle_mode, transpose_b=False, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_pv_stages]

Methods​

mask_status​

static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

descriptor_q​

static def descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

descriptor_q_rope​

static def descriptor_q_rope(q_smem: UnsafePointer[Scalar[Self.rope_mma_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair