Mojo struct

TmemStage

struct TmemStage[opc: OutputPipelineConfig]

A pipeline stage within TMEM for accumulator buffering.

Used by OutputTilePipeline to manage MMA→Epilogue synchronization. MMA writes to one stage while epilogue reads from another.
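The overlap between MMA and epilogue can be pictured with a small plain-Python model (it does not use the Mojo API). It assumes a hypothetical round-robin schedule in which tile k is accumulated in stage k % num_stages and drained one step later. The function name `stage_schedule` is illustrative, and the real barrier protocol inside OutputTilePipeline is not modeled.

```python
# Plain-Python model of multi-stage accumulator buffering (illustrative only).
# Assumption: round-robin stage assignment; real code synchronizes with barriers.

def stage_schedule(num_tiles: int, num_stages: int = 2):
    """Return (step, mma_stage, epilogue_stage) tuples; None means idle."""
    schedule = []
    for step in range(num_tiles + 1):
        # MMA fills the stage for the current tile while tiles remain.
        mma_stage = step % num_stages if step < num_tiles else None
        # The epilogue drains the tile that MMA produced one step earlier.
        epi_stage = (step - 1) % num_stages if step >= 1 else None
        schedule.append((step, mma_stage, epi_stage))
    return schedule


for step, mma, epi in stage_schedule(num_tiles=4):
    print(f"step {step}: MMA -> stage {mma}, epilogue <- stage {epi}")
```

With two stages, the writer and the reader never touch the same stage in the same step, which is the property the pipeline relies on.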

Wraps TmemAddress with stage-specific offset calculation:

  • offset(): Column address for this stage (base + index * stride)
  • address(): TmemAddress for this stage (for load/store ops)
  • tensor(): Get typed TmemTensor view
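The offset formula above (base + index * stride) can be checked with a plain-Python model. This is a sketch of the arithmetic only, not the Mojo struct. `StageModel` and the stride value of 128 columns are assumptions made for illustration; in the real struct the stride comes from opc.stage_stride_cols.

```python
from dataclasses import dataclass

# Hypothetical stride in TMEM columns; stands in for opc.stage_stride_cols.
STAGE_STRIDE = 128


@dataclass(frozen=True)
class StageModel:
    """Arithmetic-only stand-in for TmemStage (illustrative, not the real API)."""
    base_addr: int
    index: int

    def offset(self) -> int:
        # Column address for this stage: base + index * stride.
        return self.base_addr + self.index * STAGE_STRIDE


# Two stages sharing one base address land one stride apart.
stages = [StageModel(base_addr=0, index=i) for i in range(2)]
offsets = [s.offset() for s in stages]
print(offsets)
```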

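Parameters

  • opc (OutputPipelineConfig): Output pipeline configuration; supplies cta_group, num_stages, and stage_stride_cols.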

Fields

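  • base_addr (Int): Base TMEM column address, before the per-stage offset is applied.
  • index (Int): Pipeline stage index, used for the offset calculation and for barrier signaling.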

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

cta_group

comptime cta_group = opc.cta_group

num_stages

comptime num_stages = opc.num_stages

stage_stride

comptime stage_stride = opc.stage_stride_cols

Methods

__init__

__init__(base_addr: Int, index: Int) -> Self

__init__(addr: TmemAddress, index: Int) -> Self

Create stage from TmemAddress and stage index.

__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols], index: Int) -> Self

Create stage from TmemAllocation and stage index.

from_offset

static from_offset(offset: Int, index: Int) -> Self

Create stage from pre-computed offset (for legacy pipeline compatibility).

Use this when the caller has already computed the TMEM offset (e.g., base + stage * stride) and just needs to wrap it.

The index is preserved for barrier signaling. base_addr is back-calculated as offset - index * stride, so that offset() (base_addr + index * stride) returns the given offset.

Args:

  • offset (Int): Pre-computed TMEM column offset for this stage.
  • index (Int): Pipeline stage index (for barrier signaling).

Returns:

Self: TmemStage with offset() returning the given value.
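The back-calculation can be verified with a plain-Python round trip. As before, this models only the arithmetic. `StageFromOffsetModel`, the stride of 128 columns, and the sample base of 32 are hypothetical values chosen for illustration.

```python
from dataclasses import dataclass

# Hypothetical stride in TMEM columns; stands in for opc.stage_stride_cols.
STAGE_STRIDE = 128


@dataclass(frozen=True)
class StageFromOffsetModel:
    """Arithmetic-only stand-in for TmemStage (illustrative, not the real API)."""
    base_addr: int
    index: int

    @classmethod
    def from_offset(cls, offset: int, index: int) -> "StageFromOffsetModel":
        # Recover the base so that offset() reproduces the caller's value.
        return cls(base_addr=offset - index * STAGE_STRIDE, index=index)

    def offset(self) -> int:
        return self.base_addr + self.index * STAGE_STRIDE


# A legacy caller already computed base + stage * stride for stage 1.
legacy_offset = 32 + 1 * STAGE_STRIDE
stage = StageFromOffsetModel.from_offset(legacy_offset, index=1)
print(stage.base_addr, stage.index, stage.offset())
```

The recovered base_addr matches the base the caller started from, and the stage index stays available for signaling.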

offset

offset(self) -> Int

TMEM column address for this stage.

Returns:

Int

address

address(self) -> TmemAddress

Get TmemAddress for this stage's offset.

Returns:

TmemAddress

tensor

tensor[accum_dtype: DType, accum_layout: Layout](self) -> TmemTensor[accum_dtype, accum_layout, cta_group=TmemStage[opc].cta_group]

Get typed TmemTensor view of this stage's accumulator.

Parameters:

  • accum_dtype (DType): Accumulator data type.
  • accum_layout (Layout): Logical accumulator layout (M × N).

Returns:

TmemTensor: TmemTensor providing typed access to the accumulator.

load_upper

load_upper[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> InlineArray[Scalar[dtype], frag_size]

Load upper accumulator fragment (rows 0-15).

Returns:

InlineArray

load_lower

load_lower[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> InlineArray[Scalar[dtype], frag_size]

Load lower accumulator fragment (rows 16-31).

Returns:

InlineArray

wait_load

static wait_load()

Wait for TMEM load operations to complete.
