Skip to main content

Mojo struct

SM100TensorAccumulatorTS

struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1]

Fields

  • mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to a SharedMemBarrier located in shared memory.
  • phase (UInt32): Phase value (presumably tracks the current barrier phase).

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

a_frag_size

comptime a_frag_size = ((MMA_M * 16) // num_softmax_threads)

a_t

comptime a_t = TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads]

ab_t

comptime ab_t = UMMADescriptorTS[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, MMA_M=(BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]

accum_t

comptime accum_t = accum_type

b_offset

comptime b_offset = MMAOperandOffsetFn()

b_t

comptime b_t = MMASmemDescriptor

c_frag_size

comptime c_frag_size = ((MMA_M * MMA_N) // num_softmax_threads)

c_t

comptime c_t = TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads]

idesc

comptime idesc = UMMAInsDescriptor.create[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K

comptime MMA_K = 16

num_k_mmas

comptime num_k_mmas = (BK // 16)

num_m_blocks_per_warp

comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)

num_m_mmas

comptime num_m_mmas = (BM // MMA_M)

num_n_mmas

comptime num_n_mmas = (BN // MMA_N)

operand_t

comptime operand_t = operand_type

smem_ptr_t

comptime smem_ptr_t = UnsafePointer[Scalar[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t], MutAnyOrigin, address_space=AddressSpace.SHARED]

Methods

__init__

__init__(smem: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

check_constraints

static check_constraints()

init

init(self)

a_mma_descriptor

static a_mma_descriptor(a_tmem: UInt32) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t

Returns:

SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t

b_mma_descriptor

static b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[Scalar[dtype_b], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t

Returns:

SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t

mma

mma(self, a: TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads], c_scale: UInt32)

wait

wait(mut self, idx: UInt32)

wait_for_mma

wait_for_mma(mut self)

Wait for the MMA operation to complete.

wait_for_tmem

wait_for_tmem(mut self)

Wait for the output and the A-operand tensor memory (tmem) to be ready.

tmem_arrive

tmem_arrive(self)

Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.