Mojo struct
DecodeSM100QKTSS
@register_passable("trivial")
struct DecodeSM100QKTSS[operand_type: DType, accum_type: DType, *, config: MLA_SM100_Decode_Config]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ALayout
comptime ALayout = tile_layout_k_major[operand_type, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.BM, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.BN, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.a_swizzle]()
BLayout
comptime BLayout = tile_layout_k_major[operand_type, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.BM, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.BN, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.b_swizzle]()
CTileType
comptime CTileType = TMemTile[accum_type, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.S_M, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.S_N]
TensorAccumulatorSS
comptime TensorAccumulatorSS = DecodeSM100TensorAccumulatorSS[operand_type, accum_type, config=config]
UMMAInstDesc
comptime UMMAInstDesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, accum_type, operand_type, operand_type, Index[dtype=DType.uint32](DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.MMA_M, DecodeSM100QKTSS[operand_type, accum_type, config=config].TensorAccumulatorSS.MMA_N)]()
Methods
descriptor_q_block
static descriptor_q_block(q_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns: MMASmemDescriptorPair
descriptor_k_block
static descriptor_k_block(kv_smem: UnsafePointer[Scalar[operand_type], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair
Returns: MMASmemDescriptorPair
mma
static mma[*, stage_idx: Int = 0](a: MMASmemDescriptorPair, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!