Skip to main content

Mojo struct

SM100TensorAccumulatorTS

struct SM100TensorAccumulatorTS[operand_type: DType, accum_dtype: DType, MMA_M: Int, MMA_N: Int, BK: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, *, mma_kind: UMMAKind = UMMAKind.KIND_F16, transpose_b: Bool = True, cta_group: Int = 1, num_stages: Int = 1, padded_BK: Int = BK]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_t

comptime accum_t = accum_dtype

AType

comptime AType = TMemTile[operand_type, MMA_M, BK]

b_layout

comptime b_layout = tile_layout_k_major[Self.operand_t, MMA_N, BK, swizzle_b]() if transpose_b else tile_layout_mn_major[Self.operand_t, MMA_N, BK, swizzle_b]()

BType

comptime BType = MMASmemDescriptorPair

CType

comptime CType = TMemTile[Self.accum_t, MMA_M, MMA_N]

idesc

comptime idesc = UMMAInsDescriptor.create[Self.accum_t, Self.operand_t, Self.operand_t, Index[Int, Int, dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()

MMA_K

comptime MMA_K = 16 if operand_type.is_half_float() else 32

num_k_blocks

comptime num_k_blocks = (padded_BK // Self.MMA_K)

num_k_blocks_per_stage

comptime num_k_blocks_per_stage = (Self.num_k_blocks // 4 if Self.use_3_then_1_split else num_stages)

num_k_mmas

comptime num_k_mmas = (BK // Self.MMA_K)

operand_size

comptime operand_size = size_of[operand_type]()

operand_t

comptime operand_t = operand_type

swizzle_granularity

comptime swizzle_granularity = (swizzle_b.bytes() // Self.operand_size)

use_3_then_1_split

comptime use_3_then_1_split = (num_stages == 2) and ((Self.num_k_blocks % 4) == 0)

Methods

descriptor_a

static descriptor_a(a_tmem: UInt32) -> Self.AType

Returns:

Self.AType

mma

static mma[*, stage_idx: Int = 0](a: UInt32, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)