Mojo struct
SM100TensorAccumulatorSS
struct SM100TensorAccumulatorSS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, compute_BK: Int, num_softmax_threads: Int, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, *, transpose_b: Bool = True, cta_group: Int = 1, pipeline_stages: Int = 1]
Fields
- mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED])
- pipeline (PipelineState[pipeline_stages])
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
a_offset
comptime a_offset = MMAOperandOffsetFn()
a_t
comptime a_t = MMASmemDescriptor
ab_t
comptime ab_t = UMMADescriptorSS[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t]
accum_t
comptime accum_t = accum_type
b_offset
comptime b_offset = MMAOperandOffsetFn()
b_t
comptime b_t = MMASmemDescriptor
c_t
comptime c_t = TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads]
idesc
comptime idesc = UMMAInsDescriptor.create[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
comptime MMA_K = 16
num_k_mmas
comptime num_k_mmas = (compute_BK // 16)
num_m_blocks_per_warp
comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)
num_m_mmas
comptime num_m_mmas = (BM // MMA_M)
num_n_mmas
comptime num_n_mmas = (BN // MMA_N)
operand_t
comptime operand_t = operand_type
smem_ptr_t
comptime smem_ptr_t = UnsafePointer[Scalar[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t], MutAnyOrigin, address_space=AddressSpace.SHARED]
Methods
__init__
__init__(smem: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
check_constraints
static check_constraints()
init
init(self)
mma_descriptors
static mma_descriptors[dtype_a: DType, dtype_b: DType](p_a: UnsafePointer[Scalar[dtype_a], MutAnyOrigin, address_space=AddressSpace.SHARED], p_b: UnsafePointer[Scalar[dtype_b], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].ab_t
Returns:
SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].ab_t
mma
mma(mut self, a: MMASmemDescriptor, b: MMASmemDescriptor, c_base: TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads], scale_c: UInt32)
wait_for_tmem
wait_for_tmem(self)
Wait for the accumulator tmem to finish being read.
wait_for_mma
wait_for_mma(self, c_base: TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads]) -> SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].c_t
Wait for the MMA that writes the accumulator tmem to complete.
Returns:
SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].c_t
tmem_arrive_init
tmem_arrive_init(self)
tmem_arrive
tmem_arrive(mut self)
Indicate that the accumulator is ready to be updated.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!