IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MSASM100TensorAccumulatorTS

struct MSASM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = Int(1)]

Fields​

  • mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED])
  • phase (UInt32)

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

a_frag_size​

comptime a_frag_size = (((Int(16) if operand_type.is_half_float() else Int(32)) * MMA_M) // num_softmax_threads)

a_t​

comptime a_t = TMemOperand[MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, (Int(16) if operand_type.is_half_float() else Int(32)), num_softmax_threads]

ab_t​

comptime ab_t = UMMADescriptorTS[MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, MMA_M=(BM // MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N=BK, MMA_K=(Int(16) if operand_type.is_half_float() else Int(32)), consumer_group_size=num_softmax_threads]

accum_t​

comptime accum_t = accum_type

b_offset​

comptime b_offset = MMAOperandOffsetFn()

b_t​

comptime b_t = MMASmemDescriptor

c_frag_size​

comptime c_frag_size = ((MMA_M * MMA_N) // num_softmax_threads)

c_t​

comptime c_t = TMemAccumulator[MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads]

idesc​

comptime idesc = UMMAInsDescriptor.create[MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K​

comptime MMA_K = Int(16) if MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t.is_half_float() else Int(32)

mma_kind​

comptime mma_kind = UMMAKind.KIND_F8F6F4 if MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t.is_float8() else UMMAKind.KIND_F16

num_k_mmas​

comptime num_k_mmas = (BK // (Int(16) if operand_type.is_half_float() else Int(32)))

num_m_blocks_per_warp​

comptime num_m_blocks_per_warp = ((Int(2) * BM) // num_softmax_threads)

num_m_mmas​

comptime num_m_mmas = (BM // MMA_M)

num_n_mmas​

comptime num_n_mmas = (BN // MMA_N)

operand_t​

comptime operand_t = operand_type

smem_ptr_t​

comptime smem_ptr_t = UnsafePointer[Scalar[MSASM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t], MutAnyOrigin, address_space=AddressSpace.SHARED]

Methods​

__init__​

def __init__(smem: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

check_constraints​

static def check_constraints()

init​

def init(self)

a_mma_descriptor​

static def a_mma_descriptor(a_tmem: UInt32) -> Self.ab_t.a_t

Returns:

Self.ab_t.a_t

b_mma_descriptor​

static def b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[Scalar[dtype_b], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self.ab_t.b_t

Returns:

Self.ab_t.b_t

mma​

def mma(self, a: TMemOperand[Self.operand_t, Self.num_m_blocks_per_warp, Self.num_n_mmas, (BM // Self.num_m_blocks_per_warp), BK, (Int(16) if operand_type.is_half_float() else Int(32)), num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[Self.accum_t, (BM // Self.num_m_blocks_per_warp), MMA_N, Self.num_m_blocks_per_warp, Self.num_n_mmas, num_softmax_threads], c_scale: UInt32)

wait​

def wait(mut self, idx: UInt32)

wait_for_mma​

def wait_for_mma(mut self)

Wait for the MMA to complete.

wait_for_tmem​

def wait_for_tmem(mut self)

Wait for the output and A tmem to be ready.

tmem_arrive​

def tmem_arrive(self)

Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.