IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

SM100MHA2Q

struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: FA4Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type​

comptime accum_type = DType.float32

BM​

comptime BM = config.BM

BM_eff​

comptime BM_eff = config.BM_eff()

BM_mask​

comptime BM_mask = config.PairBM_eff()

BN​

comptime BN = config.BN

cta_group​

comptime cta_group = Int(2) if Self.pair_cta else Int(1)

depth​

comptime depth = config.qk_depth

fuse_gqa​

comptime fuse_gqa = config.fuse_gqa

group​

comptime group = config.group

HalfBM​

comptime HalfBM = (config // Int(2))

k_bytes​

comptime k_bytes = (SIMD(size_of[KVLUTType.dtype]()) * SIMD(Int((config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()) * config.BN)))

k_elements​

comptime k_elements = SIMD(Int((config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()) * config.BN))

KTMAOpType​

comptime KTMAOpType = TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()]

MMA_K​

comptime MMA_K = 16

MMA_M​

comptime MMA_M = config.MMA_M

num_m_mmas​

comptime num_m_mmas = 2

num_pv_stages​

comptime num_pv_stages = config.num_pv_stages

num_q_heads​

comptime num_q_heads = config.num_q_heads

num_qk_stages​

comptime num_qk_stages = config.num_qk_stages

OTMAStoreType​

comptime OTMAStoreType = RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)]

PackType​

comptime PackType = Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType]

padded_depth​

comptime padded_depth = config.padded_qk_depth

page_size​

comptime page_size = KVLUTType.page_size

pair_cta​

comptime pair_cta = config.pair_cta

PositionType​

comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_dt_size​

comptime qkv_dt_size = size_of[Self.qkv_type]()

qkv_type​

comptime qkv_type = KVLUTType.dtype

qo_bytes​

comptime qo_bytes = SIMD(Int((config // Int(2)) * size_of[KVLUTType.dtype]() * config.padded_qk_depth))

qo_elements​

comptime qo_elements = (config * (config // Int(2)))

QTMAOpType​

comptime QTMAOpType = TMATensorTile[KVLUTType.dtype, Int(4) if Self.fuse_gqa else Int(3), _padded_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()]

ragged​

comptime ragged = not ValidLengthType.is_null.__bool__()

simd_size​

comptime simd_size = simd_width_of[Self.qkv_type]()

SmemType​

comptime SmemType = SM100AttentionSMem[config]

swizzle_granularity​

comptime swizzle_granularity = (config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]())

UMMA0Type​

comptime UMMA0Type = SM100TensorAccumulator[Self.qkv_type, DType.float32, config.MMA_M, config.BN, align_up(config.qk_depth, Int(16) if ((dtype_to_ui8(KVLUTType.dtype) == 80) or (dtype_to_ui8(KVLUTType.dtype) == 79)) else Int(32)), a_tmem=False, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, cta_group=Int(2) if config.pair_cta else Int(1), num_stages=config.num_qk_stages]

UMMA1Type​

comptime UMMA1Type = SM100TensorAccumulator[Self.qkv_type, DType.float32, config.MMA_M, config.padded_ov_depth, config.BN, a_tmem=True, swizzle_b=config.swizzle_mode, transpose_b=False, cta_group=Int(2) if config.pair_cta else Int(1), num_stages=config.num_pv_stages]

v_bytes_per_mma​

comptime v_bytes_per_mma = SIMD(Int(size_of[KVLUTType.dtype]() * config.padded_ov_depth * 16))

VTMAOpType​

comptime VTMAOpType = TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()]

Methods

kernel​

static def kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, Int(4) if Self.fuse_gqa else Int(3), _padded_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, KVLUTType.page_size), Int(1), config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status​

static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

descriptor_q​

static def descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair