Skip to main content

Mojo struct

SM100MLA

struct SM100MLA[KVLUTType: MHAOperand, KRopeType: MHAOperand, output_dtype: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: MLAConfig[config.qkv_dtype, rope_gmem_dtype=config.rope_gmem_dtype, rope_mma_dtype=config.rope_mma_dtype, scale_dtype=config.scale_dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, _ndbuffer_mha_operand: Bool]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_dtype

comptime accum_dtype = DType.float32

BK0

comptime BK0 = SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qk_depth if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].fused_umma0 else SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_depth

BM

comptime BM = config.BM

BN

comptime BN = config.BN

cache_depth

comptime cache_depth = config.cache_depth

cta_group

comptime cta_group = 1

fused_umma0

comptime fused_umma0 = (SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype == SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype) and not config.fa4_config.use_fused_kv.__bool__()

group

comptime group = config.group

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.fa4_config.num_qk_stages, num_pv_stages=config.fa4_config.num_pv_stages, num_kv_stages=config.fa4_config.num_kv_stages, use_order_barriers=EnableForcedOrdering, use_fused_kv=config.fa4_config.use_fused_kv]

MMA_M

comptime MMA_M = (config.BM // 2)

nope_depth

comptime nope_depth = config.nope_depth

nope_mma_kind

comptime nope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype.is_half_float() else UMMAKind.KIND_F8F6F4

num_m_mmas

comptime num_m_mmas = 2

num_pv_stages

comptime num_pv_stages = config.num_pv_stages

num_q_heads

comptime num_q_heads = config.num_q_heads

num_qk_stages

comptime num_qk_stages = config.num_qk_stages

padded_depth

comptime padded_depth = config.padded_qk_depth

page_size

comptime page_size = KVLUTType.page_size

PositionType

comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, False]

q_rope_byte_offset

comptime q_rope_byte_offset = ((SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M * config.fa4_config) * size_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]())

qk_depth

comptime qk_depth = config.qk_depth

qkv_dt_size

comptime qkv_dt_size = size_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()

qkv_dtype

comptime qkv_dtype = KVLUTType.dtype

rope_depth

comptime rope_depth = config.rope_depth

rope_gmem_dtype

comptime rope_gmem_dtype = KRopeType.dtype

rope_mma_dtype

comptime rope_mma_dtype

rope_mma_kind

comptime rope_mma_kind = UMMAKind.KIND_F16 if SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype.is_half_float() else UMMAKind.KIND_F8F6F4

simd_size

comptime simd_size = simd_width_of[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype]()

UMMA0RopeType

comptime UMMA0RopeType = SM100TensorAccumulatorSS[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_depth, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_kind, swizzle_a=config.rope_mma_swizzle_mode, swizzle_b=config.rope_mma_swizzle_mode, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]

UMMA0Type

comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BK0, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, swizzle_a=config.qkv_swizzle_mode, swizzle_b=config.qkv_swizzle_mode, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_qk_stages]

UMMA1Type

comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype, DType.float32, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].MMA_M, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_depth, SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].BN, config.qkv_swizzle_mode, mma_kind=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].nope_mma_kind, transpose_b=False, num_stages=SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].num_pv_stages]

Methods

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

descriptor_q

static descriptor_q(q_smem: UnsafePointer[Scalar[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

descriptor_q_rope

static descriptor_q_rope(q_smem: UnsafePointer[Scalar[SM100MLA[KVLUTType, KRopeType, output_dtype, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType, _ndbuffer_mha_operand].rope_mma_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

Was this page helpful?