IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MSASM100TensorAccumulatorSS

struct MSASM100TensorAccumulatorSS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, compute_BK: Int, num_softmax_threads: Int, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, *, transpose_b: Bool = True, cta_group: Int = Int(1), pipeline_stages: Int = Int(1)]

Fields

  • mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]):
  • pipeline (PipelineState[pipeline_stages]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

a_offset

comptime a_offset = MMAOperandOffsetFn()

a_t

comptime a_t = MMASmemDescriptor

ab_t

comptime ab_t = UMMADescriptorSS[MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t]

accum_t

comptime accum_t = accum_type

b_offset

comptime b_offset = MMAOperandOffsetFn()

b_t

comptime b_t = MMASmemDescriptor

c_t

comptime c_t = TMemAccumulator[MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads]

idesc

comptime idesc = UMMAInsDescriptor.create[MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K

comptime MMA_K = Int(16) if MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t.is_half_float() else Int(32)

mma_kind

comptime mma_kind = UMMAKind.KIND_F8F6F4 if MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t.is_float8() else UMMAKind.KIND_F16

num_k_mmas

comptime num_k_mmas = (compute_BK // (Int(16) if operand_type.is_half_float() else Int(32)))

num_m_blocks_per_warp

comptime num_m_blocks_per_warp = ((Int(2) * BM) // num_softmax_threads)

num_m_mmas

comptime num_m_mmas = (BM // MMA_M)

num_n_mmas

comptime num_n_mmas = (BN // MMA_N)

operand_t

comptime operand_t = operand_type

smem_ptr_t

comptime smem_ptr_t = UnsafePointer[Scalar[MSASM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t], MutAnyOrigin, address_space=AddressSpace.SHARED]

Methods

__init__

def __init__(smem: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

check_constraints

static def check_constraints()

init

def init(self)

mma_descriptors

static def mma_descriptors[dtype_a: DType, dtype_b: DType](p_a: UnsafePointer[Scalar[dtype_a], MutAnyOrigin, address_space=AddressSpace.SHARED], p_b: UnsafePointer[Scalar[dtype_b], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self.ab_t

Returns:

Self.ab_t

mma

def mma(mut self, a: MMASmemDescriptor, b: MMASmemDescriptor, c_base: TMemAccumulator[Self.accum_t, (BM // Self.num_m_blocks_per_warp), MMA_N, Self.num_m_blocks_per_warp, Self.num_n_mmas, num_softmax_threads], scale_c: UInt32)

wait_for_tmem

def wait_for_tmem(self)

Wait for the accumulator tmem to finish being read.

wait_for_mma

def wait_for_mma(self, c_base: TMemAccumulator[Self.accum_t, (BM // Self.num_m_blocks_per_warp), MMA_N, Self.num_m_blocks_per_warp, Self.num_n_mmas, num_softmax_threads]) -> Self.c_t

Wait for the pending MMA into the accumulator `c_base` to complete, then return the accumulator (`Self.c_t`).

Returns:

Self.c_t

tmem_arrive_init

def tmem_arrive_init(self)

tmem_arrive

def tmem_arrive(mut self)

Indicate that the accumulator is ready to be updated.