Mojo struct
TmemStage
struct TmemStage[opc: OutputPipelineConfig]
A pipeline stage within TMEM for accumulator buffering.
Used by OutputTilePipeline to manage MMA→Epilogue synchronization. MMA writes to one stage while epilogue reads from another.
Wraps TmemAddress with stage-specific offset calculation:
- offset(): Column address for this stage (base + index * stride)
- address(): TmemAddress for this stage (for load/store ops)
- tensor[accum_dtype, accum_layout](): Typed TmemTensor view of this stage's accumulator
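The offset rule above (base + index * stride) can be illustrated with a small model. This is a plain-Python sketch of the arithmetic only, not the Mojo API: `StageModel` is a hypothetical name, and the stride of 128 columns is an assumed value standing in for `opc.stage_stride_cols`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageModel:
    """Hypothetical stand-in for TmemStage's offset arithmetic."""

    base_addr: int     # TMEM base column address
    index: int         # pipeline stage index
    stage_stride: int  # assumed stand-in for opc.stage_stride_cols

    def offset(self) -> int:
        # Column address for this stage: base + index * stride
        return self.base_addr + self.index * self.stage_stride


# Two stages sharing one base address, with an assumed stride of 128 columns.
stages = [StageModel(base_addr=0, index=i, stage_stride=128) for i in range(2)]
offsets = [s.offset() for s in stages]
print(offsets)
```

Because each stage has a distinct column offset, one stage can be written while another is read without the two ranges overlapping, provided the stride is at least as wide as one accumulator.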
Parameters
- opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
Fields
- base_addr (Int):
- index (Int):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
cta_group
comptime cta_group = opc.cta_group
num_stages
comptime num_stages = opc.num_stages
stage_stride
comptime stage_stride = opc.stage_stride_cols
Methods
__init__
__init__(base_addr: Int, index: Int) -> Self
__init__(addr: TmemAddress, index: Int) -> Self
Create stage from TmemAddress and stage index.
__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols], index: Int) -> Self
Create stage from TmemAllocation and stage index.
from_offset
static from_offset(offset: Int, index: Int) -> Self
Create stage from pre-computed offset (for legacy pipeline compatibility).
Use this when the caller has already computed the TMEM offset
(e.g., base + stage * stride) and just needs to wrap it.
The index is preserved for barrier signaling, and base_addr is back-calculated as offset - index * stride, so that offset() = base_addr + index * stride returns the given offset.
Args:
- offset (Int): Pre-computed TMEM column offset for this stage.
- index (Int): Pipeline stage index (for barrier signaling).
Returns:
Self: TmemStage with offset() returning the given value.
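The back-calculation described for from_offset can be checked with a short model. This is a plain-Python sketch of the arithmetic, not the Mojo implementation: `from_offset_model` and `stage_offset` are hypothetical helpers, and the stride of 128 columns is an assumed value.

```python
from typing import Tuple


def from_offset_model(offset: int, index: int, stage_stride: int) -> Tuple[int, int]:
    """Hypothetical model: recover base_addr from a pre-computed stage offset."""
    # Invert offset = base_addr + index * stride to get the base address.
    base_addr = offset - index * stage_stride
    return base_addr, index


def stage_offset(base_addr: int, index: int, stage_stride: int) -> int:
    """Forward rule: base + index * stride."""
    return base_addr + index * stage_stride


# A caller has already computed offset 256 for stage 2 (assumed stride 128).
base_addr, index = from_offset_model(offset=256, index=2, stage_stride=128)
recovered = stage_offset(base_addr, index, stage_stride=128)
print(base_addr, index, recovered)
```

The round trip returns the original offset while keeping the stage index available for barrier signaling.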
offset
address
tensor
tensor[accum_dtype: DType, accum_layout: Layout](self) -> TmemTensor[accum_dtype, accum_layout, cta_group=TmemStage[opc].cta_group]
Get typed TmemTensor view of this stage's accumulator.
Parameters:
- accum_dtype (DType): Accumulator data type.
- accum_layout (Layout): Logical accumulator layout (M × N).
Returns:
TmemTensor: TmemTensor providing typed access to the accumulator.
load_upper
load_upper[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> InlineArray[Scalar[dtype], frag_size]
Load upper accumulator fragment (rows 0-15).
Returns:
InlineArray: Fragment of frag_size elements of type dtype.
load_lower
load_lower[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> InlineArray[Scalar[dtype], frag_size]
Load lower accumulator fragment (rows 16-31).
Returns:
InlineArray: Fragment of frag_size elements of type dtype.
wait_load
static wait_load()
Wait for TMEM load operations to complete.