Mojo struct
SM100TensorAccumulatorTS
struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1]
Fields

- mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]):
- phase (UInt32):

Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members

a_frag_size
comptime a_frag_size = ((MMA_M * 16) // num_softmax_threads)
a_t
comptime a_t = TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads]
ab_t
comptime ab_t = UMMADescriptorTS[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, MMA_M=(BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]
accum_t
comptime accum_t = accum_type
b_offset
comptime b_offset = MMAOperandOffsetFn()
b_t
comptime b_t = MMASmemDescriptor
c_frag_size
comptime c_frag_size = ((MMA_M * MMA_N) // num_softmax_threads)
c_t
comptime c_t = TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads]
idesc
comptime idesc = UMMAInsDescriptor.create[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
comptime MMA_K = 16
num_k_mmas
comptime num_k_mmas = (BK // 16)
num_m_blocks_per_warp
comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)
num_m_mmas
comptime num_m_mmas = (BM // MMA_M)
num_n_mmas
comptime num_n_mmas = (BN // MMA_N)
operand_t
comptime operand_t = operand_type
smem_ptr_t
comptime smem_ptr_t = UnsafePointer[Scalar[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t], MutAnyOrigin, address_space=AddressSpace.SHARED]
Methods

__init__
__init__(smem: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
check_constraints
static check_constraints()
init
init(self)
a_mma_descriptor
static a_mma_descriptor(a_tmem: UInt32) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t
Returns:
SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t
b_mma_descriptor
static b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[Scalar[dtype_b], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t
Returns:
SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t
mma
mma(self, a: TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads], c_scale: UInt32)
wait
wait(mut self, idx: UInt32)
wait_for_mma
wait_for_mma(mut self)
Wait for the mma to be complete.
wait_for_tmem
wait_for_tmem(mut self)
Wait for the output and A tmem to be ready.
tmem_arrive
tmem_arrive(self)
Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!