
@register_passable(trivial)
struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]

Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.

Provides a LayoutTensor-like abstraction for TMEM with:

  • Type safety: dtype and layout known at compile time
  • Fragment access: upper (rows 0-15) and lower (rows 16-31)
  • MMA integration: offset() returns raw address for MMA operations

The layout parameter captures the logical accumulator shape (M × N), enabling future extensions like custom tiling patterns or multi-tile accumulator management.

Example:

# Create a typed TMEM view with a (64, 128) accumulator layout
comptime layout = Layout.row_major(64, 128)
var tmem = TmemTensor[DType.float32, layout](col_addr)

# Use with MMA operations (offset() returns the raw TMEM column address)
mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)

# Load fragments for the epilogue
var upper = tmem.load_upper[repeat=4]()
var lower = tmem.load_lower[repeat=4]()
TmemTensor.wait_load()

Parameters

  • dtype (DType): Accumulator data type (typically float32).
  • layout (Layout): Logical layout of the accumulator tile (M × N).
  • cta_group (Int): CTA cooperation level (1 or 2).

Fields

  • col_addr (Int): TMEM column address of this tensor view.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

bits

comptime bits = 256

data_paths

comptime data_paths = 16

frag_size

comptime frag_size = 4
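
The value 4 is consistent with one load of 16 data paths by 256 bits being shared across a 32-thread warp that holds 32-bit elements. The plain-Python sketch below (not Mojo, and not part of this API) reproduces that arithmetic. `WARP_SIZE`, `elem_bits`, and the function name are assumptions introduced for illustration only.

```python
WARP_SIZE = 32  # assumption: threads per warp


def frag_size(data_paths: int, bits: int, elem_bits: int = 32) -> int:
    """Elements each thread holds for one load of shape data_paths x bits."""
    # Total elements covered by one load: rows (data paths) times elements per row.
    total_elems = data_paths * (bits // elem_bits)
    # The elements are assumed to be spread evenly over the warp's threads.
    return total_elems // WARP_SIZE


print(frag_size(data_paths=16, bits=256))  # 4
```

Under the same assumptions, the `(4 * repeat)` SIMD widths in the method signatures are this per-thread fragment size multiplied by the repeat count.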

Fragments

comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]

is_lower_required

comptime is_lower_required = not (cta_group == 1 and TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64)
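
One way to read the `is_lower_required` definition is as the negation of "`cta_group` is 1 and `tile_m` is 64". The plain-Python sketch below (a model, not the Mojo implementation) tabulates that reading for the two `cta_group` values listed under Parameters. The tile heights 64 and 128 are example inputs chosen for illustration, not an exhaustive list of supported shapes.

```python
def is_lower_required(tile_m: int, cta_group: int) -> bool:
    """Model of the predicate: lower fragment is skipped only for tile_m == 64 with cta_group == 1."""
    return not (cta_group == 1 and tile_m == 64)


# Tabulate the predicate for a few illustrative configurations.
for cta_group in (1, 2):
    for tile_m in (64, 128):
        print(f"cta_group={cta_group} tile_m={tile_m} -> {is_lower_required(tile_m, cta_group)}")
```

Under this reading, only the `cta_group=1`, `tile_m=64` row yields `False`, which matches the (64, 128) layout used in the example at the top of the page when `cta_group` keeps its default of 1.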

tile_m

comptime tile_m = layout.shape[0].value()

Methods

__init__

__init__(col_addr: Int) -> Self

Create TMEM tensor view at the given column address.

__init__(addr: TmemAddress) -> Self

Create TMEM tensor view from a TmemAddress.

offset

offset(self) -> Int

TMEM column address for this tensor.

Returns:

Int

address

address(self) -> TmemAddress

Get TmemAddress for low-level fragment operations.

Returns:

TmemAddress

load_upper

load_upper[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]

Load upper accumulator fragment (rows 0-15).

Parameters:

  • repeat (Int): Number of times to repeat the load pattern.

Returns:

SIMD: SIMD vector containing the upper fragment data.

load_lower

load_lower[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]

Load lower accumulator fragment (rows 16-31).

Parameters:

  • repeat (Int): Number of times to repeat the load pattern.

Returns:

SIMD: SIMD vector containing the lower fragment data.

store_upper

store_upper[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])

Store upper accumulator fragment (rows 0-15).

Parameters:

  • repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (SIMD): SIMD vector containing the data to store.

store_lower

store_lower[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])

Store lower accumulator fragment (rows 16-31).

Parameters:

  • repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (SIMD): SIMD vector containing the data to store.

load_fragments

load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]

Load both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • repeat (Int): Number of times to repeat the load pattern.

Returns:

TmemFragments: TmemFragments containing upper and (conditionally) lower data.
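
To make the return type concrete, the plain-Python sketch below (a size model only, not the Mojo API) computes how many elements per thread a `load_fragments` call would carry for a given `repeat`. Treating a skipped lower fragment as zero elements is a modeling assumption, as are the names `FRAG_SIZE` and `fragment_widths`.

```python
from typing import Tuple

FRAG_SIZE = 4  # mirrors the documented `comptime frag_size = 4`


def fragment_widths(repeat: int, lower_required: bool) -> Tuple[int, int]:
    """Return (upper_width, lower_width) in elements per thread."""
    upper = FRAG_SIZE * repeat
    # Assumption: when the lower fragment is not required, it carries no data.
    lower = FRAG_SIZE * repeat if lower_required else 0
    return upper, lower


print(fragment_widths(repeat=4, lower_required=True))   # (16, 16)
print(fragment_widths(repeat=4, lower_required=False))  # (16, 0)
```

With `repeat=4`, as in the example at the top of the page, each required fragment is a 16-element vector per thread.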

store_fragments

store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required])

Store both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • repeat (Int): Number of times to repeat the store pattern.

Args:

  • frags (TmemFragments): TmemFragments containing upper and (conditionally) lower data.

wait_load

static wait_load()

Wait for TMEM load operations to complete.

wait_store

static wait_store()

Wait for TMEM store operations to complete.
