Skip to main content

Mojo struct

SM100MHA2Q

@register_passable(trivial) struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

accum_type

alias accum_type = get_accum_type[KVLUTType.dtype]()

BM

alias BM = config.BM

BN

alias BN = config.BN

cta_group

alias cta_group = 1

depth

alias depth = config.depth

group

alias group = config.group

k_bytes

alias k_bytes = SIMD[DType.uint32, 1](((config.swizzle_mode.bytes() // KVLUTType.dtype.size_of()) * config.BN)).__rmul__[DType.uint32, 1](SIMD[DType.uint32, 1](KVLUTType.dtype.size_of()))

k_elements

alias k_elements = SIMD[DType.uint32, 1](((config.swizzle_mode.bytes() // KVLUTType.dtype.size_of()) * config.BN))

KVPipelineType

alias KVPipelineType = KVPipeline[config.num_kv_stages, config.num_mma_stages]

MMA_K

alias MMA_K = 16

MMA_M

alias MMA_M = (config.BM // 2)

num_m_mmas

alias num_m_mmas = 2

num_mma_stages

alias num_mma_stages = config.num_mma_stages

num_q_heads

alias num_q_heads = config.num_q_heads

OPipelineType

alias OPipelineType = MBarPipeline[2]

padded_depth

alias padded_depth = config.padded_depth

PositionType

alias PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_dt_size

alias qkv_dt_size = KVLUTType.dtype.size_of()

qkv_type

alias qkv_type = KVLUTType.dtype

qo_bytes

alias qo_bytes = SIMD[DType.uint32, 1]((KVLUTType.dtype.size_of() * (config.padded_depth * (config.BM // 2))))

qo_elements

alias qo_elements = (config.padded_depth * (config.BM // 2))

ragged

alias ragged = ValidLengthType.is_null.__bool__().__invert__()

simd_size

alias simd_size = simd_width_of[KVLUTType.dtype]()

swizzle_granularity

alias swizzle_granularity = (config.swizzle_mode.bytes() // KVLUTType.dtype.size_of())

UMMA0Type

alias UMMA0Type = SM100TensorAccumulatorSS[KVLUTType.dtype, get_accum_type[KVLUTType.dtype](), (config.BM // 2), config.BN, config.depth, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=config.num_mma_stages]

UMMA1Type

alias UMMA1Type = SM100TensorAccumulatorTS[KVLUTType.dtype, get_accum_type[KVLUTType.dtype](), (config.BM // 2), config.padded_depth, config.BN, config.swizzle_mode, transpose_b=False, num_stages=config.num_mma_stages]

v_bytes_per_mma

alias v_bytes_per_mma = SIMD[DType.uint32, 1](((KVLUTType.dtype.size_of() * 16) * config.padded_depth))

Methods

kernel

static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, (config.BM // 2), config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64]((config.BM // 2), config.BK0, Tuple[]()), swizzle_mode=config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, config.BN, config.padded_depth, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), swizzle_mode=config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_mn_major[KVLUTType.dtype, config.padded_depth, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), False, config.swizzle_mode](), False], o_ptr_arg: UnsafePointer[Scalar[output_type]], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])

softmax

static softmax(tmem_addr: UInt32, warp_group_idx: UInt32, mbars: FA4MiscMBars, o_mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], tid: UInt32, mask: MaskType, kv_tile_start_row: UInt32, end: UInt32, scale: Scalar[get_accum_type[KVLUTType.dtype]()], score_mod: ScoreModType, max_seq_len: UInt32, o_ptr_arg: UnsafePointer[Scalar[output_type]], o_smem: UnsafePointer[Scalar[output_type], address_space=AddressSpace(3)], sink_weights: SinkType)

correction

static correction(tmem_addr: UInt32, mbars: FA4MiscMBars, o_mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType)

load

static load(mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, (config.BM // 2), config.BK0, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64]((config.BM // 2), config.BK0, Tuple[]()), swizzle_mode=config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, config.BN, config.padded_depth, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), swizzle_mode=config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, tile_layout_mn_major[KVLUTType.dtype, config.padded_depth, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BN, config.padded_depth, Tuple[]()), False, config.swizzle_mode](), False], kv_lut: KVLUTType, q_smem: UnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace(3)])

descriptor_q

static descriptor_q(q_smem: UnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace(3)]) -> MMASmemDescriptor

Returns:

MMASmemDescriptor

mma

static mma(tmem_addr: UInt32, mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], o_mbar: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3)], position: MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()], kv_tile_start_row: UInt32, end: UInt32, mask: MaskType, q_smem: UnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace(3)])

Was this page helpful?