Mojo struct

DecodeSM100QKTSS_Content_FP8

struct DecodeSM100QKTSS_Content_FP8[operand_type: DType, accum_type: DType, *, config: MLA_SM100_Decode_Config]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

BK

comptime BK = config.padded_depth

MMA_K

comptime MMA_K = 32

MMA_M

comptime MMA_M = config.MMA_M

MMA_N

comptime MMA_N = config.MMA_QK_N

num_k_mmas

comptime num_k_mmas = (DecodeSM100QKTSS_Content_FP8[operand_type, accum_type, config=config].BK // 32)
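As a worked example of the arithmetic behind num_k_mmas: the padded depth BK is split into K-slices of MMA_K = 32 elements, so the count is a plain floor division. The sketch below is a minimal Mojo illustration, not part of this API; the helper name num_k_mmas_for and the depth of 576 are hypothetical values chosen for the example.

```mojo
# Hypothetical helper mirroring `num_k_mmas = BK // 32`; not part of the API.
fn num_k_mmas_for(bk: Int, mma_k: Int = 32) -> Int:
    # Floor division: if BK were not a multiple of MMA_K, the remainder
    # would be dropped, which is presumably why BK is the *padded* depth.
    return bk // mma_k


def main():
    # 576 is an illustrative padded depth, not a value read from a real config.
    print(num_k_mmas_for(576))  # 18
```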

operand_size

comptime operand_size = size_of[operand_type]()
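operand_size is the width of one operand element in bytes, so it converts element counts into shared-memory byte extents. A minimal sketch, assuming an FP8 operand type (1 byte per element) and a hypothetical padded depth of 576; the helper names k_slice_bytes and row_bytes are illustrative only and do not exist in the library.

```mojo
# Hypothetical helpers showing how operand_size turns element counts into bytes.
fn k_slice_bytes(mma_k: Int, operand_size: Int) -> Int:
    # Bytes spanned by one MMA_K-wide K-slice of a single row.
    return mma_k * operand_size


fn row_bytes(bk: Int, operand_size: Int) -> Int:
    # Bytes in one row of BK (padded depth) elements.
    return bk * operand_size


def main():
    var operand_size = 1  # assumed FP8 element: 1 byte
    print(k_slice_bytes(32, operand_size))  # 32
    print(row_bytes(576, operand_size))  # 576
```

With a 2-byte operand type such as bfloat16, both figures would double.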

UMMAInstDesc

comptime UMMAInstDesc = UMMAInsDescriptor.create[accum_type, operand_type, operand_type, Index[Int, Int, dtype=DType.uint32](DecodeSM100QKTSS_Content_FP8[operand_type, accum_type, config=config].MMA_M, DecodeSM100QKTSS_Content_FP8[operand_type, accum_type, config=config].MMA_N)]()

Methods

descriptor_q_block

static descriptor_q_block(q_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

descriptor_k_block

static descriptor_k_block(kv_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

mma

static mma[*, stage_idx: Int = 0](a: MMASmemDescriptorPair, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)
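The signature suggests that mma issues the QK^T product from the two shared-memory descriptor pairs a and b into the accumulator at c (presumably a tensor-memory address), with c_scale controlling whether existing accumulator contents are kept and elect selecting the issuing thread. The sketch below models one plausible K-loop schedule with plain Mojo integers; it is a hypothetical illustration, not the actual implementation. It assumes (1) one MMA per K-slice, num_k_mmas in total; (2) an unswizzled, contiguous K-major layout, so each step advances by MMA_K * operand_size bytes; and (3) that c_scale applies only to the first step (0 = overwrite, 1 = accumulate) while later steps always accumulate. Real SM100 shared-memory layouts are usually swizzled, so the real descriptor arithmetic may differ. The helper names k_step_offset and k_step_scale are invented for this example.

```mojo
# Hypothetical model of a K-loop over num_k_mmas MMA instructions.
# Assumes an unswizzled, contiguous K-major layout (not the real SM100 layout).
fn k_step_offset(k: Int, mma_k: Int, operand_size: Int) -> Int:
    # Byte offset of the k-th MMA_K-wide slice within a row.
    return k * mma_k * operand_size


fn k_step_scale(k: Int, c_scale: Int) -> Int:
    # Assumption: only the first step honours the caller's c_scale
    # (0 = overwrite accumulator, 1 = accumulate); later steps accumulate.
    return c_scale if k == 0 else 1


def main():
    var bk = 128  # illustrative padded depth
    var mma_k = 32  # matches MMA_K
    var operand_size = 1  # assumed FP8 element: 1 byte
    # Print the (offset, scale) pair each hypothetical MMA step would use,
    # starting from a fresh accumulator (c_scale = 0).
    for k in range(bk // mma_k):
        print(
            "k =", k,
            "offset =", k_step_offset(k, mma_k, operand_size),
            "c_scale =", k_step_scale(k, 0),
        )
```

Keeping the first-step scale as a parameter would let a caller continue accumulating into the same result across calls by passing c_scale = 1, under the same assumptions.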