struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]

Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.

Provides a typed abstraction for TMEM with:

  • Type safety: dtype and layout known at compile time
  • Fragment access: upper (rows 0-15) and lower (rows 16-31)
  • MMA integration: offset() returns raw address for MMA operations

The layout parameter captures the logical accumulator shape (M × N), enabling future extensions such as custom tiling patterns or multi-tile accumulator management.

Example:

    # Create a typed TMEM view with a (64, 128) accumulator layout
    comptime layout = Layout.row_major(64, 128)
    var tmem = TmemTensor[DType.float32, layout](col_addr)

    # Use with MMA operations (offset() returns the raw TMEM column address)
    mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)

    # Load fragments for the epilogue
    var upper = tmem.load_upper[repeat=4]()
    var lower = tmem.load_lower[repeat=4]()
    TmemTensor.wait_load()

Parameters​

  • ​dtype (DType): Accumulator data type (typically float32).
  • layout (Layout): Logical layout of the accumulator tile (M × N).
  • ​cta_group (Int): CTA cooperation level (1 or 2).

Fields​

  • col_addr (Int): TMEM column address of this tensor view.

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

bits​

comptime bits = 256

data_paths​

comptime data_paths = 16

frag_size​

comptime frag_size = 4
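
The reference does not say how bits, data_paths, and frag_size relate. One plausible reading is sketched below in plain Python (a model, not the Mojo API): a load pattern covering 16 data paths of 256 bits each, split evenly across a warp. The 32-lane warp and the 32-bit element width are assumptions that do not appear on this page, and the helper names are hypothetical.

```python
# Assumptions (not stated in the reference): 32 lanes per warp and
# 32-bit accumulator elements such as float32.
WARP_LANES = 32
ELEMENT_BITS = 32


def frag_size(data_paths: int, bits: int) -> int:
    """Elements each lane holds for one load pattern (hypothetical helper)."""
    total_bits = data_paths * bits
    return total_bits // (WARP_LANES * ELEMENT_BITS)


def fragment_length(repeat: int, data_paths: int = 16, bits: int = 256) -> int:
    """Length of the array returned for a given `repeat`, i.e. 4 * repeat."""
    return frag_size(data_paths, bits) * repeat


print(frag_size(16, 256))   # 4
print(fragment_length(4))   # 16
```

Under these assumptions the result matches the documented frag_size of 4 and the (4 * repeat) array length in the load and store signatures.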

Fragments​

comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]

is_lower_required​

comptime is_lower_required = not (TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64) if (cta_group == 1) else (cta_group == 1).__bool__()

tile_m​

comptime tile_m = layout.shape[0].value()

Methods​

__init__​

__init__(col_addr: Int) -> Self

Create TMEM tensor view at the given column address.

__init__(addr: TmemAddress) -> Self

Create TMEM tensor view from a TmemAddress.

offset​

offset(self) -> Int

TMEM column address for this tensor.

Returns:

Int

address​

address(self) -> TmemAddress

Get TmemAddress for low-level fragment operations.

Returns:

TmemAddress

load_upper​

load_upper[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]

Load upper accumulator fragment (rows 0-15).

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the upper fragment data.

load_lower​

load_lower[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]

Load lower accumulator fragment (rows 16-31).

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the lower fragment data.

store_upper​

store_upper[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])

Store upper accumulator fragment (rows 0-15).

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (InlineArray[Scalar[dtype], (4 * repeat)]): Upper fragment data to store.

store_lower​

store_lower[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])

Store lower accumulator fragment (rows 16-31).

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (InlineArray[Scalar[dtype], (4 * repeat)]): Lower fragment data to store.

load_fragments​

load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]

Load both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]: TmemFragments containing upper and (conditionally) lower data.
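
As a rough illustration of what "conditionally" means for per-thread storage, the Python sketch below (a model, not the Mojo API) counts the elements a thread would hold after load_fragments. It assumes that no lower elements are held when is_lower_required is false, which the reference implies but does not state; the function name is hypothetical.

```python
def elements_per_thread(repeat: int, is_lower_required: bool, frag_size: int = 4) -> int:
    """Count the elements held per thread after loading fragments (hypothetical model)."""
    upper = frag_size * repeat
    # Assumption: the lower fragment contributes nothing when it is not required.
    lower = frag_size * repeat if is_lower_required else 0
    return upper + lower


print(elements_per_thread(repeat=4, is_lower_required=False))  # 16
print(elements_per_thread(repeat=4, is_lower_required=True))   # 32
```

This kind of estimate can help when choosing repeat, since larger values increase the number of values each thread keeps live during the epilogue.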

store_fragments​

store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required])

Store both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

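Args:

  • frags (TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]): Fragments to store, containing upper and (conditionally) lower data.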

wait_load​

static wait_load()

Wait for TMEM load operations to complete.

wait_store​

static wait_store()

Wait for TMEM store operations to complete.