
Mojo struct

DecodeSM100QKTSS

@register_passable(trivial) struct DecodeSM100QKTSS[operand_type: DType, accum_type: DType, *, config: MLA_SM100_Decode_Config]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, TrivialRegisterType

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

ALayout

comptime ALayout = tile_layout_k_major[operand_type, config.BM, DecodeSM100QKTSS[operand_type, accum_type, config=config].BK, config.swizzle_mode]()

BK

comptime BK = config.BK0

BLayout

comptime BLayout = tile_layout_k_major[operand_type, config.BN, DecodeSM100QKTSS[operand_type, accum_type, config=config].BK, config.kv_swizzle_mode]()
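Here is a hedged reading of `ALayout` and `BLayout` together. It is inferred from the struct name (`QKT`, read as Q times K transposed) and is not stated on this page. On that reading, the A tile holds BM × BK elements of Q and the B tile holds BN × BK elements of K. Both are K-major, meaning K is the contiguous dimension. Under that assumption, one K-block contributes to the score tile as follows:

```latex
S_{m,n} \mathrel{+}= \sum_{k=0}^{BK-1} Q_{m,k}\, K_{n,k}, \qquad 0 \le m < BM,\; 0 \le n < BN
```

Both operands are indexed by the same K, which matches both layouts being built with `tile_layout_k_major`. The two layouts differ only in row count (`config.BM` vs `config.BN`) and swizzle mode (`config.swizzle_mode` vs `config.kv_swizzle_mode`).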

MMA_K

comptime MMA_K = MLA_SM100_Decode_Config.MMA_K

MMA_M

comptime MMA_M = config.MMA_M

MMA_N

comptime MMA_N = config.MMA_QK_N

num_k_mmas

comptime num_k_mmas = (DecodeSM100QKTSS[operand_type, accum_type, config=config].BK // 16)
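`num_k_mmas` splits one K-block of BK elements into 16-element slices. The divisor 16 matches the K depth of a single `tcgen05.mma` instruction of kind `f16`, which is the kind selected in `UMMAInstDesc`. In Mojo, `//` is floor division. The example below uses a hypothetical `config.BK0 = 64`; the real value comes from `MLA_SM100_Decode_Config`:

```latex
\text{num\_k\_mmas} = \left\lfloor \frac{BK}{16} \right\rfloor, \qquad BK = 64 \;\Rightarrow\; \text{num\_k\_mmas} = 4
```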

operand_size

comptime operand_size = size_of[operand_type]()
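`operand_size` is the element width in bytes, so one K-major row of a K-block occupies BK × operand_size bytes. The example below assumes a `bfloat16` operand (2 bytes) and the same hypothetical BK = 64:

```latex
\text{row\_bytes} = BK \times \text{operand\_size} = 64 \times 2 = 128 \text{ bytes}
```

A 128-byte row spans exactly one 128-byte swizzle pattern. Whether this struct uses that swizzle depends on `config.swizzle_mode` and `config.kv_swizzle_mode`, which this page does not state.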

UMMAInstDesc

comptime UMMAInstDesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, accum_type, operand_type, operand_type, Index[dtype=DType.uint32](DecodeSM100QKTSS[operand_type, accum_type, config=config].MMA_M, DecodeSM100QKTSS[operand_type, accum_type, config=config].MMA_N)]()
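`UMMAInstDesc` encodes four things: the instruction kind, the accumulator type, both operand types, and the MMA_M × MMA_N shape. The kind is `KIND_F16`, which covers 16-bit inputs such as `float16` and `bfloat16`. The sketch below shows what one instruction computes, assuming the K depth of 16 used by `num_k_mmas`:

```latex
D_{\,\text{MMA\_M} \times \text{MMA\_N}} \leftarrow A_{\,\text{MMA\_M} \times 16}\; B_{\,\text{MMA\_N} \times 16}^{\top} \;(+\; D)
```

A and B are read as `operand_type` and D is held as `accum_type`. The optional `+ D` term corresponds to the instruction's accumulate flag.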

Methods

descriptor_q_block

static descriptor_q_block(q_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

descriptor_k_block

static descriptor_k_block(kv_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair

Returns:

MMASmemDescriptorPair

mma

static mma[*, stage_idx: Int = 0](a: MMASmemDescriptorPair, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)
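The following sketches what one call plausibly computes. It rests on three assumptions that this page does not document:

- `mma` issues `num_k_mmas` instructions over one K-block.
- `c` addresses the accumulator.
- `c_scale` acts like the `tcgen05.mma` accumulate flag (0 overwrites, 1 accumulates) on the first instruction, and later instructions always accumulate.

Under those assumptions:

```latex
C \leftarrow c_{\text{scale}}\, C + \sum_{j=0}^{\text{num\_k\_mmas}-1} A_{[:,\,16j\,:\,16j+16]}\; B_{[:,\,16j\,:\,16j+16]}^{\top}, \qquad c_{\text{scale}} \in \{0, 1\}
```

The roles of `stage_idx` and `elect` are not described here, so the sketch leaves them out.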
