For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
SM100MHA2Q
struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: FA4Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_type
comptime accum_type = DType.float32
BM
comptime BM = config.BM
BM_eff
comptime BM_eff = config.BM_eff()
BM_mask
comptime BM_mask = config.PairBM_eff()
BN
comptime BN = config.BN
cta_group
comptime cta_group = Int(2) if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].pair_cta else Int(1)
depth
comptime depth = config.qk_depth
fuse_gqa
comptime fuse_gqa = config.fuse_gqa
group
comptime group = config.group
HalfBM
comptime HalfBM = (config.BM // Int(2))
k_bytes
comptime k_bytes = (SIMD(size_of[KVLUTType.dtype]()) * SIMD(Int((mul (config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()), config.BN))))
k_elements
comptime k_elements = SIMD(Int((mul (config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()), config.BN)))
KTMAOpType
comptime KTMAOpType = TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()]
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = config.MMA_M
num_m_mmas
comptime num_m_mmas = 2
num_pv_stages
comptime num_pv_stages = config.num_pv_stages
num_q_heads
comptime num_q_heads = config.num_q_heads
num_qk_stages
comptime num_qk_stages = config.num_qk_stages
OTMAStoreType
comptime OTMAStoreType = RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)]
PackType
comptime PackType = Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType]
padded_depth
comptime padded_depth = config.padded_qk_depth
page_size
comptime page_size = KVLUTType.page_size
pair_cta
comptime pair_cta = config.pair_cta
PositionType
comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]
qkv_dt_size
comptime qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
qkv_type
comptime qkv_type = KVLUTType.dtype
qo_bytes
comptime qo_bytes = SIMD(Int((mul (config // Int(2)), size_of[KVLUTType.dtype](), config.padded_qk_depth)))
qo_elements
comptime qo_elements = (config * (config // Int(2)))
QTMAOpType
comptime QTMAOpType = TMATensorTile[KVLUTType.dtype, Int(4) if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else Int(3), _padded_shape[Int(4) if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[Int(4) if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()]
ragged
comptime ragged = not ValidLengthType.is_null.__bool__()
simd_size
comptime simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
SmemType
comptime SmemType = SM100AttentionSMem[config]
swizzle_granularity
comptime swizzle_granularity = (config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]())
UMMA0Type
comptime UMMA0Type = SM100TensorAccumulator[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, config.MMA_M, config.BN, align_up(config.qk_depth, Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> #kgen.get_witness<:trait<@nn::@attention::@mha_operand::@MHAOperand> KVLUTType, "nn::attention::mha_operand::MHAOperand", "dtype">, "_mlir_value">>, 80) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> #kgen.get_witness<:trait<@nn::@attention::@mha_operand::@MHAOperand> KVLUTType, "nn::attention::mha_operand::MHAOperand", "dtype">, "_mlir_value">>, 80) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> #kgen.get_witness<:trait<@nn::@attention::@mha_operand::@MHAOperand> KVLUTType, "nn::attention::mha_operand::MHAOperand", "dtype">, "_mlir_value">>, 79) else Int(32)), a_tmem=False, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, cta_group=Int(2) if config.pair_cta else Int(1), num_stages=config.num_qk_stages]
UMMA1Type
comptime UMMA1Type = SM100TensorAccumulator[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, config.MMA_M, config.padded_ov_depth, config.BN, a_tmem=True, swizzle_b=config.swizzle_mode, transpose_b=False, cta_group=Int(2) if config.pair_cta else Int(1), num_stages=config.num_pv_stages]
v_bytes_per_mma
comptime v_bytes_per_mma = SIMD(Int((mul size_of[KVLUTType.dtype](), config.padded_ov_depth, 16)))
VTMAOpType
comptime VTMAOpType = TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()]
Methods
kernel
static def kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, Int(4) if Self.fuse_gqa else Int(3), _padded_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])
mask_status
static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns:
TileMaskStatus
descriptor_q
static def descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns:
MMASmemDescriptorPair
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!