Mojo struct
SM100MHA2Q
@register_passable(trivial)
struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
accum_type
comptime accum_type = DType.float32
BM
comptime BM = config.BM
BN
comptime BN = config.BN
correction_offset
comptime correction_offset = ((((0 + SIMD[DType.int32, 1]((config * config))) + SIMD[DType.int32, 1]((((2 * config) * config) * config))) * SIMD[DType.int32, 1](size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]())) // SIMD[DType.int32, 1](size_of[DType.float32]()))
cta_group
comptime cta_group = 1
depth
comptime depth = config.depth
group
comptime group = config.group
HalfBM
comptime HalfBM = (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM // 2)
k_bytes
comptime k_bytes = (SIMD[DType.uint32, 1](SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size) * SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].swizzle_granularity * config)))
k_elements
comptime k_elements = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].swizzle_granularity * config))
kv_offset
comptime kv_offset = (0 + SIMD[DType.int32, 1]((config * config)))
mbar_offset
comptime mbar_offset = (((((((0 + SIMD[DType.int32, 1]((config * config))) + SIMD[DType.int32, 1]((((2 * config) * config) * config))) * SIMD[DType.int32, 1](size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]())) // SIMD[DType.int32, 1](size_of[DType.float32]())) + SIMD[DType.int32, 1](config)) * SIMD[DType.int32, 1](size_of[DType.float32]())) // SIMD[DType.int32, 1](size_of[SharedMemBarrier]()))
MiscMBarsType
comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering]
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = (config // 2)
num_m_mmas
comptime num_m_mmas = 2
num_pv_stages
comptime num_pv_stages = config.num_pv_stages
num_q_heads
comptime num_q_heads = config.num_q_heads
num_qk_stages
comptime num_qk_stages = config.num_qk_stages
padded_depth
comptime padded_depth = config.padded_depth
page_size
comptime page_size = KVLUTType.page_size
PositionType
comptime PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]
q_offset
comptime q_offset = 0
qkv_dt_size
comptime qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
qkv_type
comptime qkv_type = KVLUTType.dtype
qo_bytes
comptime qo_bytes = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qo_elements))
qo_elements
comptime qo_elements = (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].padded_depth * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M)
ragged
comptime ragged = ValidLengthType.is_null.__bool__().__invert__()
simd_size
comptime simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type]()
swizzle_granularity
comptime swizzle_granularity = (config.swizzle_mode.bytes() // SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size)
UMMA0Type
comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, align_up(SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].depth, 16), swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages]
UMMA1Type
comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type, DType.float32, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].MMA_M, config.padded_depth, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN, config.swizzle_mode, transpose_b=False, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages]
v_bytes_per_mma
comptime v_bytes_per_mma = SIMD[DType.uint32, 1](((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_dt_size * 16) * config))
Methods
get_tmem_ptr
static get_tmem_ptr(misc_mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering]) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Returns:
UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
get_q_smem
static get_q_smem(misc_mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering]) -> UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
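Returns:
UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]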
get_kv_smem
static get_kv_smem(misc_mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering]) -> UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
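Returns:
UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]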
get_correction_smem
static get_correction_smem(misc_mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering]) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Returns:
UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
kernel
static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.BK0, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.BK0, Tuple[]()), config)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.padded_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.padded_depth, Tuple[]()), config)], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config // 2), config.depth], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])
mask_status
static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns:
TileMaskStatus
scale_write_output
static scale_write_output(local_row: UInt32, local_warp_idx: UInt32, warp_group_idx: UInt32, inv_row_sum: Float32, o_smem_arg: UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED], o_tmem: TMemTile[DType.float32, (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM // 2), SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].padded_depth], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config // 2), config.depth], consumer_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], num_output_rows: Int32, out_head_idx: UInt32, out_row_idx: UInt32)
softmax
static softmax(mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering], score_row: UInt32, seq_info: SeqInfo, mask: MaskType, num_keys: UInt32, scale: Float32, score_mod: ScoreModType, max_seq_len: UInt32, ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, (config // 2), config.depth], sink_weights: SinkType)
correction
static correction(mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering], score_row: UInt32, num_keys: UInt32, mask: MaskType)
load
static load(mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.BK0, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.BK0, Tuple[]()), config)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.padded_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.padded_depth, Tuple[]()), config)], kv_lut: KVLUTType)
descriptor_q
static descriptor_q(q_smem: UnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].qkv_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns:
MMASmemDescriptorPair
mma
static mma(mbars: FA4MiscMBars[num_qk_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_qk_stages, num_pv_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering], score_row: UInt32, num_keys: UInt32, mask: MaskType)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!