
Mojo struct

TmemStage

@register_passable(trivial) struct TmemStage[num_stages: Int, stage_stride: Int, cta_group: Int]

A pipeline stage within TMEM for accumulator buffering.

Used by OutputTilePipeline to synchronize the MMA and epilogue phases: the MMA writes to one stage while the epilogue reads from another.

Wraps TmemAddress with stage-specific offset calculation:

  • offset(): Column address for this stage (base + index * stride)
  • address(): TmemAddress for this stage (for load/store ops)
  • tensor[accum_dtype, accum_layout](): Get typed TmemTensor view
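The offset rule in the list above can be modelled in a few lines of plain Python. This is a sketch of the arithmetic only, not the Mojo API: `StageModel` is a hypothetical stand-in, and `stage_stride` is shown as an ordinary field even though it is a compile-time parameter on the real struct.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageModel:
    """Hypothetical plain-Python stand-in for TmemStage (arithmetic only)."""

    base_addr: int     # base TMEM column of the accumulator region
    index: int         # pipeline stage index
    stage_stride: int  # columns per stage (a compile-time parameter in Mojo)

    def offset(self) -> int:
        # Column address for this stage: base + index * stride
        return self.base_addr + self.index * self.stage_stride


# Two stages of 256 columns each, starting at column 0 (illustrative values).
stages = [StageModel(base_addr=0, index=i, stage_stride=256) for i in range(2)]
stage_offsets = [s.offset() for s in stages]
print(stage_offsets)
```

With these illustrative values, the two stages start at different columns, which is what lets one stage be written while the other is read.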

Parameters

  • num_stages (Int): Pipeline stages (typically 2-4).
  • stage_stride (Int): Columns per stage (512 / num_stages).
  • cta_group (Int): Cooperating CTAs (1 or 2).
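As a rough illustration of how `stage_stride` relates to `num_stages`, the sketch below divides a 512-column budget evenly between stages. The contiguous layout starting at column 0 and the helper name `stage_column_ranges` are assumptions made for the example, not something this page specifies.

```python
TMEM_COLUMNS = 512  # column budget implied by "512 / num_stages"


def stage_column_ranges(num_stages: int) -> list[tuple[int, int]]:
    """Return (first_column, last_column) per stage, assuming base column 0."""
    stage_stride = TMEM_COLUMNS // num_stages
    return [
        (i * stage_stride, (i + 1) * stage_stride - 1)
        for i in range(num_stages)
    ]


two_stage = stage_column_ranges(2)
four_stage = stage_column_ranges(4)
print(two_stage)
print(four_stage)
```

More stages allow deeper buffering but leave fewer columns for each accumulator.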

Fields

  • base_addr (Int): Base TMEM column address; offset() adds index * stage_stride to it.
  • index (Int): Pipeline stage index.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterType, TrivialRegisterType

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(base_addr: Int, index: Int) -> Self

__init__(addr: TmemAddress, index: Int) -> Self

Create stage from TmemAddress and stage index.

__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols], index: Int) -> Self

Create stage from TmemAllocation and stage index.

from_offset

static from_offset(offset: Int, index: Int) -> Self

Create stage from pre-computed offset (for legacy pipeline compatibility).

Use this when the caller has already computed the TMEM offset (e.g., base + stage * stride) and just needs to wrap it.

The index is preserved for barrier signaling. base_addr is back-calculated as offset - index * stage_stride, so that offset() (base_addr + index * stage_stride) returns the given offset.

Args:

  • offset (Int): Pre-computed TMEM column offset for this stage.
  • index (Int): Pipeline stage index (for barrier signaling).

Returns:

Self: TmemStage with offset() returning the given value.
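The back-calculation described for from_offset can be checked with a small plain-Python model. The function names here are hypothetical and only mirror the arithmetic; they are not part of the Mojo API.

```python
def from_offset_model(offset: int, index: int, stage_stride: int) -> tuple[int, int]:
    """Hypothetical model: recover (base_addr, index) from a pre-computed offset."""
    base_addr = offset - index * stage_stride
    return base_addr, index


def offset_model(base_addr: int, index: int, stage_stride: int) -> int:
    """Model of offset(): base + index * stride."""
    return base_addr + index * stage_stride


# A caller already computed column 384 for stage 1 with a 256-column stride.
base_addr, index = from_offset_model(offset=384, index=1, stage_stride=256)
recovered = offset_model(base_addr, index, stage_stride=256)
print(base_addr, index, recovered)
```

The round trip returns the original offset, while the stage index stays available for barrier signaling.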

offset

offset(self) -> Int

TMEM column address for this stage.

Returns:

Int

address

address(self) -> TmemAddress

Get TmemAddress for this stage's offset.

Returns:

TmemAddress

tensor

tensor[accum_dtype: DType, accum_layout: Layout](self) -> TmemTensor[accum_dtype, accum_layout, cta_group=cta_group]

Get typed TmemTensor view of this stage's accumulator.

Parameters:

  • accum_dtype (DType): Accumulator data type.
  • accum_layout (Layout): Logical accumulator layout (M × N).

Returns:

TmemTensor: TmemTensor providing typed access to the accumulator.

load_upper

load_upper[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]

Load upper accumulator fragment (rows 0-15).

Returns:

SIMD

load_lower

load_lower[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]

Load lower accumulator fragment (rows 16-31).

Returns:

SIMD

wait_load

static wait_load()

Wait for TMEM load operations to complete.
