Skip to main content

Mojo struct

SM100MHA2Q

struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: FA4Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type

comptime accum_type = DType.float32

BM

comptime BM = config.BM

BN

comptime BN = config.BN

cta_group

comptime cta_group = 1

depth

comptime depth = config.qk_depth

group

comptime group = config.group

HalfBM

comptime HalfBM = (SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM // 2)

k_bytes

comptime k_bytes = (SIMD(SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size) * UInt32(((config.swizzle_mode.bytes() // size_of[KVLUTType.dtype]()) * config.BN)))

k_elements

comptime k_elements = SIMD((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].swizzle_granularity * config.BN))

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering, use_fused_kv=config.use_fused_kv]

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = (config.BM // 2)

num_m_mmas

comptime num_m_mmas = 2

num_pv_stages

comptime num_pv_stages = config.num_pv_stages

num_q_heads

comptime num_q_heads = config.num_q_heads

num_qk_stages

comptime num_qk_stages = config.num_qk_stages

padded_depth

comptime padded_depth = config.padded_qk_depth

page_size

comptime page_size = KVLUTType.page_size

PositionType

comptime PositionType = MHAPosition[config.BM, config.BN, config.qk_depth, config.padded_qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_dt_size

comptime qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()

qkv_type

comptime qkv_type = KVLUTType.dtype

qo_bytes

comptime qo_bytes = SIMD((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qo_elements))

qo_elements

comptime qo_elements = (SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].padded_depth * SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M)

ragged

comptime ragged = not ValidLengthType.is_null.__bool__()

simd_size

comptime simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()

SmemType

comptime SmemType = SM100AttentionSMem[config]

swizzle_granularity

comptime swizzle_granularity = (config.swizzle_mode.bytes() // SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size)

TmemAllocType

comptime TmemAllocType = TmemAllocation[1]

UMMA0Type

comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, align_up(SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].depth, 16), swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages]

UMMA1Type

comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, config.padded_ov_depth, SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, config.swizzle_mode, transpose_b=False, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages]

v_bytes_per_mma

comptime v_bytes_per_mma = SIMD(((SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * 16) * config.padded_ov_depth))

Methods

kernel

static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.BK0), Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.BK0), Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.padded_ov_depth), Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.padded_ov_depth), Tuple()), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config.BM // 2), config.ov_depth], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

descriptor_q

static descriptor_q(q_smem: UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

Was this page helpful?