For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
SM100TensorAccumulator
struct SM100TensorAccumulator[operand_type: DType, accum_dtype: DType, MMA_M: Int, MMA_N: Int, BK: Int, *, a_tmem: Bool, mma_kind: UMMAKind = UMMAKind.KIND_F16, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1, num_stages: Int = 1]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
a_bmn
comptime a_bmn = align_up((MMA_M // cta_group), 8)
a_layout
comptime a_layout = tile_layout_k_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].a_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_a]()
accum_t
comptime accum_t = accum_dtype
AInput
comptime AInput = UInt32 if a_tmem else MMASmemDescriptorPair
AType
comptime AType = TMemTile[operand_type, MMA_M, BK] if a_tmem else MMASmemDescriptorPair
b_bmn
comptime b_bmn = MMA_N if a_tmem else (MMA_N // cta_group)
b_layout
comptime b_layout = tile_layout_k_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].b_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]() if transpose_b else tile_layout_mn_major[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].b_bmn, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]()
BType
comptime BType = MMASmemDescriptorPair
CType
comptime CType = TMemTile[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, MMA_M, MMA_N]
idesc
comptime idesc = UMMAInsDescriptor.create[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
comptime MMA_K = 16 if operand_type.is_half_float() else 32
num_k_blocks
comptime num_k_blocks = (SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK // SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K)
num_k_blocks_per_stage
comptime num_k_blocks_per_stage = (SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].num_k_blocks // 4 if SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].use_3_then_1_split else num_stages)
num_k_mmas
comptime num_k_mmas = (BK // SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K) if a_tmem else ceildiv(BK, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].MMA_K)
operand_size
comptime operand_size = size_of[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]()
operand_t
comptime operand_t = operand_type
padded_BK
comptime padded_BK = BK if a_tmem else align_up(BK, SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].swizzle_granularity)
swizzle_granularity
comptime swizzle_granularity = (swizzle_b.bytes() if a_tmem else max(swizzle_a.bytes(), swizzle_b.bytes()) // size_of[SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]())
tcgen05_mma_type
comptime tcgen05_mma_type = "tcgen05.mma.ws.cta_group::1."
use_3_then_1_split
comptime use_3_then_1_split = a_tmem and (num_stages == 2) and ((SM100TensorAccumulator[operand_type, accum_dtype, MMA_M, MMA_N, BK, a_tmem=a_tmem, mma_kind=mma_kind, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].num_k_blocks % 4) == 0)
use_ws
comptime use_ws = (MMA_M <= 64) if (cta_group == 1) else (cta_group == 1)
Methods
mma
static def mma[*, stage_idx: Int = 0](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)
mma_maybe_partial_k
static def mma_maybe_partial_k[*, stage_idx: Int = 0](a: Self.AInput, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32, valid_k_mmas: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!