Mojo struct
SM100MHA2Q
@register_passable(trivial)
struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
accum_type
alias accum_type = get_accum_type[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
BM
alias BM = config.BM
BN
alias BN = config.BN
cta_group
alias cta_group = 1
depth
alias depth = config.depth
group
alias group = config.group
k_bytes
alias k_bytes = SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].k_elements.__rmul__[DType.uint32, 1](SIMD[DType.uint32, 1](SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size))
k_elements
alias k_elements = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].swizzle_granularity * config))
KVPipelineType
alias KVPipelineType = KVPipeline[config.num_kv_stages, config.num_mma_stages]
MMA_K
alias MMA_K = 16
MMA_M
alias MMA_M = (config.BM // 2)
num_m_mmas
alias num_m_mmas = 2
num_mma_stages
alias num_mma_stages = config.num_mma_stages
num_q_heads
alias num_q_heads = config.num_q_heads
OPipelineType
alias OPipelineType = MBarPipeline[2]
padded_depth
alias padded_depth = config.padded_depth
PositionType
alias PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]
qkv_dt_size
alias qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
qkv_type
alias qkv_type = KVLUTType.dtype
qo_bytes
alias qo_bytes = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qo_elements))
qo_elements
alias qo_elements = (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].padded_depth * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M)
ragged
alias ragged = ValidLengthType.is_null.__bool__().__invert__()
simd_size
alias simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
swizzle_granularity
alias swizzle_granularity = (config.swizzle_mode.bytes() // SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size)
UMMA0Type
alias UMMA0Type = SM100TensorAccumulatorSS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].accum_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].depth, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_mma_stages]
UMMA1Type
alias UMMA1Type = SM100TensorAccumulatorTS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].accum_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, config.padded_depth, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, config.swizzle_mode, transpose_b=False, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_mma_stages]
v_bytes_per_mma
alias v_bytes_per_mma = SIMD[DType.uint32, 1](((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * 16) * config))
Methods
kernel
static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, (config // 2), config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64]((config // 2), config.BK0, Tuple[]()), swizzle_mode=config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, config.BN, config.padded_depth, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), swizzle_mode=config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_mn_major[KVLUTType.dtype, config.padded_depth, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), False, config.swizzle_mode](), False], o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])
softmax
static softmax(tmem_addr: UInt32, warp_group_idx: UInt32, mbars: FA4MiscMBars, o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], tid: UInt32, mask: MaskType, kv_tile_start_row: UInt32, end: UInt32, scale: Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].accum_type], score_mod: ScoreModType, max_seq_len: UInt32, o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], o_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], sink_weights: SinkType)
correction
static correction(tmem_addr: UInt32, mbars: FA4MiscMBars, o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType)
load
static load(mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, (config // 2), config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64]((config // 2), config.BK0, Tuple[]()), swizzle_mode=config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, config.BN, config.padded_depth, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), swizzle_mode=config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_mn_major[KVLUTType.dtype, config.padded_depth, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), False, config.swizzle_mode](), False], kv_lut: KVLUTType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace.SHARED])
descriptor_q
static descriptor_q(q_smem: LegacyUnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns: MMASmemDescriptorPair
mma
static mma(tmem_addr: UInt32, mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace.SHARED])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!