IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

SM100TensorAccumulator

struct SM100TensorAccumulator[operand_type: DType, accum_dtype: DType, MMA_M: Int, MMA_N: Int, BK: Int, *, a_tmem: Bool, mma_kind: UMMAKind = UMMAKind.KIND_F16, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1, num_stages: Int = 1]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

a_bmn

comptime a_bmn = align_up((MMA_M // cta_group), 8)

a_layout

comptime a_layout = tile_layout_k_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].a_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_a]()

accum_t

comptime accum_t = accum_dtype

AInput

comptime AInput = UInt32 if a_tmem else MMASmemDescriptorPair

AType

comptime AType = TMemTile[operand_type, MMA_M, BK] if a_tmem else MMASmemDescriptorPair

b_bmn

comptime b_bmn = MMA_N if a_tmem else (MMA_N // cta_group)

b_layout

comptime b_layout = tile_layout_k_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].b_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]() if transpose_b else tile_layout_mn_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].b_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]()

BType

comptime BType = MMASmemDescriptorPair

CType

comptime CType = TMemTile[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, MMA_M, MMA_N]

idesc

comptime idesc = UMMAInsDescriptor.create[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K

comptime MMA_K = 16 if operand_type.is_half_float() else 32

num_k_blocks

comptime num_k_blocks = (SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK // SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K)

num_k_blocks_per_stage

comptime num_k_blocks_per_stage = (SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].num_k_blocks // 4 if SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].use_3_then_1_split else num_stages)

num_k_mmas

comptime num_k_mmas = (BK // SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K) if a_tmem else ceildiv(BK, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K)

operand_size

comptime operand_size = size_of[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]()

operand_t

comptime operand_t = operand_type

padded_BK

comptime padded_BK = BK if a_tmem else align_up(BK, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].swizzle_granularity)

swizzle_granularity

comptime swizzle_granularity = (swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes()) // size_of[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]())

tcgen05_mma_type

comptime tcgen05_mma_type = "tcgen05.mma.ws.cta_group::1."

use_3_then_1_split

comptime use_3_then_1_split = a_tmem and (num_stages == 2) and ((SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].num_k_blocks % 4) == 0)

use_ws

comptime use_ws = (MMA_M <= 64) if (cta_group == 1) else (cta_group == 1)

Methods

mma

static def mma[*, stage_idx: Int = 0](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)

mma_maybe_partial_k

static def mma_maybe_partial_k[*, stage_idx: Int = 0](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32, valid_k_mmas: UInt32)