Mojo struct
SM100TensorAccumulatorTS
@register_passable(trivial)
struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1]
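The compile-time parameters fix the MMA and block tile shapes. A minimal sketch of binding them to a concrete configuration, using hypothetical tile sizes chosen only for illustration (real kernels derive these from their block and warp layout; the alias name is not from the library):

```mojo
# Hypothetical parameterization for illustration only.
comptime QKAccumulator = SM100TensorAccumulatorTS[
    DType.bfloat16,   # operand_type
    DType.float32,    # accum_type
    64,               # MMA_M
    128,              # MMA_N
    128,              # BM
    128,              # BN
    64,               # BK
    128,              # num_softmax_threads
]
```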
Fields
- mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]):
- phase (UInt32):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
a_frag_size
comptime a_frag_size = ((MMA_M * 16) // num_softmax_threads)
a_t
comptime a_t = TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads]
ab_t
comptime ab_t = UMMADescriptorTS[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, MMA_M=(BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]
accum_t
comptime accum_t = accum_type
b_offset
comptime b_offset = MMAOperandOffsetFn[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, BN, BK, swizzle_b, transpose_b, MMA_N, 16]()
b_t
comptime b_t = MMASmemDescriptor
c_frag_size
comptime c_frag_size = ((MMA_M * MMA_N) // num_softmax_threads)
c_t
comptime c_t = TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads]
idesc
comptime idesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
comptime MMA_K = 16
num_k_mmas
comptime num_k_mmas = (BK // 16)
num_m_blocks_per_warp
comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)
num_m_mmas
comptime num_m_mmas = (BM // MMA_M)
num_n_mmas
comptime num_n_mmas = (BN // MMA_N)
operand_t
comptime operand_t = operand_type
smem_ptr_t
comptime smem_ptr_t = LegacyUnsafePointer[Scalar[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t], address_space=AddressSpace.SHARED]
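For concreteness, the derived aliases work out as follows under the hypothetical parameters used in the sketch above (MMA_M=64, MMA_N=128, BM=128, BN=128, BK=64, num_softmax_threads=128); this is illustrative arithmetic only, not library output:

```mojo
# num_m_mmas            = BM // MMA_M                          = 128 // 64   = 2
# num_n_mmas            = BN // MMA_N                          = 128 // 128  = 1
# num_k_mmas            = BK // MMA_K                          = 64 // 16    = 4
# num_m_blocks_per_warp = (2 * BM) // num_softmax_threads      = 256 // 128  = 2
# a_frag_size           = (MMA_M * 16) // num_softmax_threads      = 1024 // 128 = 8
# c_frag_size           = (MMA_M * MMA_N) // num_softmax_threads   = 8192 // 128 = 64
```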
Methods
__init__
__init__(smem: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self
check_constraints
static check_constraints()
init
init(self)
a_mma_descriptor
static a_mma_descriptor(a_tmem: UInt32) -> TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads]
Returns:
b_mma_descriptor
static b_mma_descriptor[dtype_b: DType](p_b: LegacyUnsafePointer[Scalar[dtype_b], address_space=AddressSpace.SHARED]) -> MMASmemDescriptor
Returns:
mma
mma(self, a: TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads], c_scale: UInt32)
wait
wait(mut self, idx: UInt32)
wait_for_mma
wait_for_mma(mut self)
Wait for the MMA to complete.
wait_for_tmem
wait_for_tmem(mut self)
Wait for the output and A tmem to be ready.
tmem_arrive
tmem_arrive(self)
Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.
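A minimal end-to-end sketch of how these methods are typically sequenced. The shared-memory barrier, the tensor-memory addresses, the accumulator value, and the split of calls across softmax and MMA-issuing threads are all assumptions made for illustration, not a verified kernel:

```mojo
# Sketch only: `mbar_smem`, `a_tmem`, `b_smem`, and `c_tmem` are hypothetical
# values the surrounding kernel would provide. QKAccumulator is the
# hypothetical parameterization alias from the sketch above.
var acc = QKAccumulator(mbar_smem)            # construct from a shared-memory barrier
acc.init()

var a = QKAccumulator.a_mma_descriptor(a_tmem)  # A operand held in tensor memory
var b = QKAccumulator.b_mma_descriptor(b_smem)  # B operand held in shared memory

# Softmax threads: signal that the accumulator and A tmem are ready.
acc.tmem_arrive()

# MMA-issuing thread: wait for tmem, then issue the MMA.
acc.wait_for_tmem()
acc.mma(a, b, c_tmem, c_scale=0)              # c_scale assumed: 0 overwrites C, 1 accumulates

# Threads that consume the result: wait for the MMA to complete.
acc.wait_for_mma()
```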