
struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]

Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.

Provides a typed abstraction for TMEM with:

  • Type safety: dtype and layout known at compile time
  • Fragment access: upper (rows 0-15) and lower (rows 16-31)
  • MMA integration: offset() returns raw address for MMA operations
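The upper/lower split in the list above can be pictured with a small host-side Python model. This is only a sketch of the documented row ranges, not the Mojo API: `FRAGMENT_ROWS` and `fragment_for_row` are hypothetical names, and the 32-row block is inferred from the ranges 0-15 and 16-31.

```python
FRAGMENT_ROWS = 16  # rows per fragment, taken from the documented ranges 0-15 and 16-31


def fragment_for_row(row: int) -> tuple[str, int]:
    """Map a row in a 32-row block to (fragment name, row inside that fragment)."""
    if not 0 <= row < 2 * FRAGMENT_ROWS:
        raise ValueError(f"row {row} is outside the 32-row block")
    name = "upper" if row < FRAGMENT_ROWS else "lower"
    return name, row % FRAGMENT_ROWS


print(fragment_for_row(15))  # last row of the upper fragment
print(fragment_for_row(16))  # first row of the lower fragment
```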

The layout parameter captures the logical accumulator shape (M × N), enabling future extensions like custom tiling patterns or multi-tile accumulator management.

Example:

# Create a typed TMEM view with a (64, 128) accumulator layout
comptime layout = Layout.row_major(64, 128)
var tmem = TmemTensor[DType.float32, layout](col_addr)

# Use with MMA operations (offset() returns the raw TMEM column address)
mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)

# Load fragments for the epilogue
var upper = tmem.load_upper[repeat=4]()
var lower = tmem.load_lower[repeat=4]()
TmemTensor.wait_load()

Parameters​

  • ​dtype (DType): Accumulator data type (typically float32).
  • layout (Layout): Logical layout of the accumulator tile (M × N).
  • ​cta_group (Int): CTA cooperation level (1 or 2).

Fields​

  • col_addr (Int): TMEM column address of this tensor view.

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

bits​

comptime bits = 256

data_paths​

comptime data_paths = 16

frag_size​

comptime frag_size = 4
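
One way to read these three constants together, offered as an assumption rather than something this page states: if a single load moves `data_paths` rows of `bits` bits each, and that data is spread evenly across a 32-thread warp in 32-bit registers, each thread ends up holding `frag_size` values. The warp size, the register width, and the helper name `values_per_thread` are assumptions made for illustration.

```python
BITS = 256        # comptime bits, from this page
DATA_PATHS = 16   # comptime data_paths, from this page
FRAG_SIZE = 4     # comptime frag_size, from this page

WARP_THREADS = 32   # assumption: threads per warp
REGISTER_BITS = 32  # assumption: one 32-bit register per value (e.g. float32)


def values_per_thread(data_paths: int, bits: int) -> int:
    """Values each thread holds if the loaded bits are split evenly over the warp."""
    total_bits = data_paths * bits
    return total_bits // (WARP_THREADS * REGISTER_BITS)


# 16 * 256 = 4096 bits, shared by 32 threads with 32-bit registers
print(values_per_thread(DATA_PATHS, BITS))
```

Under these assumptions the result matches `frag_size`, which is also the base length of the arrays returned by the load methods.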

Fragments​

comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]

is_lower_required​

comptime is_lower_required = not (TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64) if (cta_group == 1) else (cta_group == 1).__bool__()

tile_m​

comptime tile_m = layout.shape[0].value()

Methods​

__init__​

def __init__(col_addr: Int) -> Self

Create TMEM tensor view at the given column address.

def __init__(addr: TmemAddress) -> Self

Create TMEM tensor view from a TmemAddress.

offset​

def offset(self) -> Int

TMEM column address for this tensor.

Returns:

Int

address​

def address(self) -> TmemAddress

Get TmemAddress for low-level fragment operations.

Returns:

TmemAddress

load_upper​

def load_upper[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]

Load upper accumulator fragment (rows 0-15).

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the upper fragment data.

load_lower​

def load_lower[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]

Load lower accumulator fragment (rows 16-31).

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the lower fragment data.
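
The return types of `load_upper` and `load_lower` both have length `4 * repeat`. A rough Python sketch of that scaling follows. `fragment_len` is a hypothetical helper, and the set of `repeat` values the hardware accepts is not stated on this page, so the loop values are only examples.

```python
FRAG_SIZE = 4  # comptime frag_size from this page


def fragment_len(repeat: int = 1) -> int:
    """Length of one fragment array, mirroring the documented (4 * repeat)."""
    if repeat < 1:
        raise ValueError("repeat must be positive")
    return FRAG_SIZE * repeat


for repeat in (1, 2, 4):
    per_fragment = fragment_len(repeat)
    # Loading both upper and lower doubles the per-thread element count.
    print(repeat, per_fragment, 2 * per_fragment)
```

With `repeat=4`, each fragment array holds 16 elements, so a thread that loads both upper and lower holds 32.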

store_upper​

def store_upper[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])

Store upper accumulator fragment (rows 0-15).

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (InlineArray[Scalar[dtype], (4 * repeat)]): Upper fragment data to store.

store_lower​

def store_lower[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])

Store lower accumulator fragment (rows 16-31).

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

Args:

  • data (InlineArray[Scalar[dtype], (4 * repeat)]): Lower fragment data to store.

load_fragments​

def load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=Self.is_lower_required]

Load both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Returns:

TmemFragments[dtype, (4 * repeat), is_lower_required=Self.is_lower_required]: TmemFragments containing upper and (conditionally) lower data.
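
The "conditionally lower" behavior can be modeled on the host in Python. This is a sketch under stated assumptions, not the Mojo implementation: `FragmentsModel` and `load_fragments_model` are hypothetical names, and representing a skipped lower fragment as `None` is an illustrative choice, since this page does not say how `TmemFragments` represents it.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class FragmentsModel:
    """Hypothetical stand-in for TmemFragments: upper always, lower only if required."""
    upper: List[float]
    lower: Optional[List[float]] = None


def load_fragments_model(
    load_upper: Callable[[], List[float]],
    load_lower: Callable[[], List[float]],
    is_lower_required: bool,
) -> FragmentsModel:
    """Always load the upper fragment; load the lower one only when required."""
    upper = load_upper()
    lower = load_lower() if is_lower_required else None
    return FragmentsModel(upper=upper, lower=lower)


both = load_fragments_model(lambda: [1.0] * 4, lambda: [2.0] * 4, True)
upper_only = load_fragments_model(lambda: [1.0] * 4, lambda: [2.0] * 4, False)
print(both.lower is not None, upper_only.lower is None)
```

In the model the caller never branches on `is_lower_required`; the flag is consumed in one place, which is the convenience the combined call is described as providing.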

store_fragments​

def store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=Self.is_lower_required])

Store both upper and lower fragments in one call.

Handles is_lower_required automatically based on layout.

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

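Args:

  • frags (TmemFragments[dtype, (4 * repeat), is_lower_required=Self.is_lower_required]): Upper and (conditionally) lower fragment data to store.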

wait_load​

static def wait_load()

Wait for TMEM load operations to complete.

wait_store​

static def wait_store()

Wait for TMEM store operations to complete.