For the complete documentation index, see llms.txt. A Markdown version of any page is available by appending `.md` to its URL (e.g. /max/get-started.md).

Mojo struct

SM100MHA2Q

struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: FA4Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type​

comptime accum_type = DType.float32

BM​

comptime BM = config.BM

BM_eff​

comptime BM_eff = config.BM_eff()

BM_mask​

comptime BM_mask = config.PairBM_eff()

BN​

comptime BN = config.BN

cta_group​

comptime cta_group = 2 if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].pair_cta else 1

depth​

comptime depth = config.qk_depth

fuse_gqa​

comptime fuse_gqa = config.fuse_gqa

group​

comptime group = config.group

HalfBM​

comptime HalfBM = (SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM // 2)

k_bytes​

comptime k_bytes = (SIMD(SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size) * UInt32(((config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()) * config.BN)))

k_elements​

comptime k_elements = SIMD((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].swizzle_granularity * config.BN))

KTMAOpType​

comptime KTMAOpType = TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()]

MMA_K​

comptime MMA_K = 16

MMA_M​

comptime MMA_M = config.MMA_M

num_m_mmas​

comptime num_m_mmas = 2

num_pv_stages​

comptime num_pv_stages = config.num_pv_stages

num_q_heads​

comptime num_q_heads = config.num_q_heads

num_qk_stages​

comptime num_qk_stages = config.num_qk_stages

OTMAStoreType​

comptime OTMAStoreType = RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 1]

PackType​

comptime PackType = Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType]

padded_depth​

comptime padded_depth = config.padded_qk_depth

page_size​

comptime page_size = KVLUTType.page_size

pair_cta​

comptime pair_cta = config.pair_cta

PositionType​

comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_dt_size​

comptime qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()

qkv_type​

comptime qkv_type = KVLUTType.dtype

qo_bytes​

comptime qo_bytes = SIMD((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qo_elements))

qo_elements​

comptime qo_elements = (SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].padded_depth * SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].HalfBM)

QTMAOpType​

comptime QTMAOpType = TMATensorTile[KVLUTType.dtype, 4 if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, _padded_shape[4 if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()]

ragged​

comptime ragged = not ValidLengthType.is_null.__bool__()

simd_size​

comptime simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()

SmemType​

comptime SmemType = SM100AttentionSMem[config]

swizzle_granularity​

comptime swizzle_granularity = (config.swizzle_mode.bytes() // SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size)

UMMA0Type​

comptime UMMA0Type = SM100TensorAccumulator[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, align_up(SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].depth, FA4Config[KVLUTType.dtype].MMA_K), a_tmem=False, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, cta_group=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].cta_group, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages]

UMMA1Type​

comptime UMMA1Type = SM100TensorAccumulator[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, config.padded_ov_depth, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, a_tmem=True, swizzle_b=config.swizzle_mode, transpose_b=False, cta_group=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].cta_group, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages]

v_bytes_per_mma​

comptime v_bytes_per_mma = SIMD(((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * 16) * config))

VTMAOpType​

comptime VTMAOpType = TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size), 1, config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size), 1, config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()]

Methods

kernel​

static def kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 4 if Self.fuse_gqa else 3, _padded_shape[4 if Self.fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if Self.fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // config), group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), Self.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.k_rows_per_cta(), Self.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, Self.page_size), 1, config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BN, Self.page_size), 1, config.v_cols_per_cta(), __list_literal__=NoneType(None)), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=(config // config), BN=config.ov_depth, group=config.group if Self.fuse_gqa else 1], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status​

static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

descriptor_q​

static def descriptor_q(q_smem: UnsafePointer[Scalar[Self.qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair