To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

SM100TensorAccumulator

struct SM100TensorAccumulator[operand_type: DType, accum_dtype: DType, MMA_M: Int, MMA_N: Int, BK: Int, *, a_tmem: Bool, mma_kind: UMMAKind = UMMAKind.KIND_F16, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = Int(1), num_stages: Int = Int(1)]

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

a_bmn​

comptime a_bmn = align_up((MMA_M // cta_group), Int(8))

a_layout​

comptime a_layout = tile_layout_k_major[Self.operand_t, align_up((MMA_M // cta_group), Int(8)), Self.padded_BK, swizzle_a]()

accum_t​

comptime accum_t = accum_dtype

AInput​

comptime AInput = UInt32 if a_tmem else MMASmemDescriptorPair

AType​

comptime AType = TMemTile[operand_type, MMA_M, BK] if a_tmem else MMASmemDescriptorPair

b_bmn​

comptime b_bmn = MMA_N if a_tmem else (MMA_N // cta_group)

b_layout​

comptime b_layout = tile_layout_k_major[Self.operand_t, MMA_N if a_tmem else (MMA_N // cta_group), Self.padded_BK, swizzle_b]() if transpose_b else tile_layout_mn_major[Self.operand_t, MMA_N if a_tmem else (MMA_N // cta_group), Self.padded_BK, swizzle_b]()

BType​

comptime BType = MMASmemDescriptorPair

CType​

comptime CType = TMemTile[Self.accum_t, MMA_M, MMA_N]

idesc​

comptime idesc = UMMAInsDescriptor.create[Self.accum_t, Self.operand_t, Self.operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K​

comptime MMA_K = Int(16) if operand_type.is_half_float() else Int(32)

num_k_blocks​

comptime num_k_blocks = (Self.padded_BK // (Int(16) if operand_type.is_half_float() else Int(32)))

num_k_blocks_per_stage​

comptime num_k_blocks_per_stage = (((BK if a_tmem else align_up(BK, ((swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes())) // size_of[operand_type]()))) // (Int(16) if operand_type.is_half_float() else Int(32))) // (Int(4) if Self.use_3_then_1_split else num_stages))

num_k_mmas​

comptime num_k_mmas = (BK // (Int(16) if operand_type.is_half_float() else Int(32))) if a_tmem else ceildiv(BK, Int(16) if operand_type.is_half_float() else Int(32))

operand_size​

comptime operand_size = size_of[Self.operand_t]()

operand_t​

comptime operand_t = operand_type

padded_BK​

comptime padded_BK = BK if a_tmem else align_up(BK, ((swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes())) // size_of[operand_type]()))

swizzle_granularity​

comptime swizzle_granularity = ((swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes())) // size_of[operand_type]())

tcgen05_mma_type​

comptime tcgen05_mma_type = "tcgen05.mma.ws.cta_group::1."

use_3_then_1_split​

comptime use_3_then_1_split = a_tmem and (num_stages == Int(2)) and ((((BK if a_tmem else align_up(BK, ((swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes())) // size_of[operand_type]()))) // (Int(16) if operand_type.is_half_float() else Int(32))) % Int(4)) == Int(0))

use_ws​

comptime use_ws = (cta_group == Int(1)) and (MMA_M <= Int(64))

Methods​

mma​

static def mma[*, stage_idx: Int = Int(0)](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)

mma_maybe_partial_k​

static def mma_maybe_partial_k[*, stage_idx: Int = Int(0)](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32, valid_k_mmas: UInt32)