Mojo struct
TmemStage
@register_passable("trivial")
struct TmemStage[num_stages: Int, stage_stride: Int, cta_group: Int]
A pipeline stage within TMEM for accumulator buffering.
Used by OutputTilePipeline to manage MMA→Epilogue synchronization: the MMA writes to one stage while the epilogue reads from another.
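The double-buffering idea can be illustrated with a small Python model. This is not the Mojo API. It assumes that output tiles are assigned to stages round-robin (tile `k` uses stage `k % num_stages`), and the helper names `stage_for_tile` and `schedule` are hypothetical. With two or more stages, the stage being filled by the MMA at a given step is never the stage being drained by the epilogue.

```python
# Hypothetical model (not the Mojo API): assumes tile k is assigned to
# stage k % num_stages, so the MMA fills one stage while the epilogue
# drains the stage written on the previous step.
from __future__ import annotations


def stage_for_tile(tile: int, num_stages: int) -> int:
    """Return the TMEM stage index used for the given output tile."""
    if num_stages <= 0:
        raise ValueError("num_stages must be positive")
    return tile % num_stages


def schedule(num_tiles: int, num_stages: int) -> list[tuple[int, int | None]]:
    """For each step, pair the stage the MMA writes with the stage the
    epilogue reads (the previous tile's stage, or None at step 0)."""
    steps = []
    for t in range(num_tiles):
        mma_stage = stage_for_tile(t, num_stages)
        epi_stage = stage_for_tile(t - 1, num_stages) if t > 0 else None
        steps.append((mma_stage, epi_stage))
    return steps


if __name__ == "__main__":
    for step, (mma, epi) in enumerate(schedule(4, 2)):
        print(f"step {step}: MMA -> stage {mma}, epilogue <- stage {epi}")
```

With `num_stages = 1` the two would collide on every step, which is why a pipeline like this needs at least two stages.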
Wraps TmemAddress with stage-specific offset calculation:
- offset(): Column address for this stage (base + index * stride)
- address(): TmemAddress for this stage (for load/store ops)
- tensor[accum_dtype, accum_layout](): Typed TmemTensor view of this stage's accumulator
Parameters
- num_stages (Int): Pipeline stages (typically 2-4).
- stage_stride (Int): Columns per stage (512 / num_stages).
- cta_group (Int): Cooperating CTAs (1 or 2).
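The offset arithmetic described above can be checked with a short Python model. It assumes a 512-column allocation split evenly across stages, so `stage_stride = 512 / num_stages` and each stage owns the columns `[offset, offset + stage_stride)`. The function names are hypothetical and only mirror `offset()`. They are not part of the Mojo API.

```python
# Model of TmemStage offset arithmetic (not the Mojo API), assuming a
# 512-column TMEM allocation split evenly across stages.
from __future__ import annotations

TMEM_COLUMNS = 512  # total columns, taken from "512 / num_stages"


def stage_stride(num_stages: int, total_cols: int = TMEM_COLUMNS) -> int:
    """Columns per stage; requires num_stages to divide total_cols evenly."""
    if num_stages <= 0 or total_cols % num_stages != 0:
        raise ValueError("num_stages must evenly divide the column count")
    return total_cols // num_stages


def stage_offset(base_addr: int, index: int, stride: int) -> int:
    """Column address for a stage: base + index * stride (mirrors offset())."""
    return base_addr + index * stride


def stage_columns(base_addr: int, num_stages: int) -> list[range]:
    """Half-open column range [offset, offset + stride) owned by each stage."""
    stride = stage_stride(num_stages)
    return [
        range(stage_offset(base_addr, i, stride),
              stage_offset(base_addr, i, stride) + stride)
        for i in range(num_stages)
    ]


if __name__ == "__main__":
    for i, cols in enumerate(stage_columns(0, 2)):
        print(f"stage {i}: columns {cols.start}..{cols.stop - 1}")
```

Because the ranges are contiguous and non-overlapping, the last stage always ends exactly at column 512 relative to the base.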
Fields
- base_addr (Int): Base TMEM address that stage offsets are computed from.
- index (Int): Pipeline stage index.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterType,
TrivialRegisterType
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
__init__(base_addr: Int, index: Int) -> Self
Create stage from a base address and stage index.
__init__(addr: TmemAddress, index: Int) -> Self
Create stage from TmemAddress and stage index.
__init__[cta: Int, max_cols: Int](alloc: TmemAllocation[cta, max_cols], index: Int) -> Self
Create stage from TmemAllocation and stage index.
from_offset
static from_offset(offset: Int, index: Int) -> Self
Create stage from a pre-computed offset (for legacy pipeline compatibility).
Use this when the caller has already computed the TMEM offset (e.g., base + stage * stride) and only needs to wrap it.
The index is preserved for barrier signaling, and base_addr is back-calculated as offset - index * stage_stride, so that offset() = base_addr + index * stage_stride = offset.
Args:
- offset (Int): Pre-computed TMEM column offset for this stage.
- index (Int): Pipeline stage index (for barrier signaling).
Returns:
Self: TmemStage with offset() returning the given value.
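The back-calculation in from_offset can be shown as a round trip in a Python model. `StageModel` is a hypothetical stand-in. In the Mojo struct, `stage_stride` is a compile-time parameter, and the model stores it as a field only for convenience. The check is that `offset()` returns exactly the value passed to `from_offset`, while `index` is kept unchanged for barrier signaling.

```python
# Hypothetical model of from_offset (not the Mojo API): back-calculate
# base_addr so that base_addr + index * stride reproduces the given offset.
from dataclasses import dataclass


@dataclass(frozen=True)
class StageModel:
    base_addr: int
    index: int
    stride: int  # a compile-time parameter (stage_stride) in the Mojo struct

    @classmethod
    def from_offset(cls, offset: int, index: int, stride: int) -> "StageModel":
        # base_addr = offset - index * stride, so offset() == offset.
        return cls(base_addr=offset - index * stride, index=index, stride=stride)

    def offset(self) -> int:
        # Same formula as offset(): base + index * stride.
        return self.base_addr + self.index * self.stride


if __name__ == "__main__":
    stage = StageModel.from_offset(offset=384, index=1, stride=256)
    print(stage.base_addr, stage.index, stage.offset())
```

The design keeps a single source of truth for the offset formula: legacy callers supply the final offset, and the stage still reports a consistent base_addr and index.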
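offset
Column address for this stage: base_addr + index * stage_stride.
address
TmemAddress for this stage (for load/store operations).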
tensor
tensor[accum_dtype: DType, accum_layout: Layout](self) -> TmemTensor[accum_dtype, accum_layout, cta_group=cta_group]
Get typed TmemTensor view of this stage's accumulator.
Parameters:
- accum_dtype (DType): Accumulator data type.
- accum_layout (Layout): Logical accumulator layout (M × N).
Returns:
TmemTensor: TmemTensor providing typed access to the accumulator.
load_upper
load_upper[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]
Load upper accumulator fragment (rows 0-15).
Returns:
SIMD: Fragment of frag_size dtype values loaded from this stage.
load_lower
load_lower[dtype: DType, frag_size: Int, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4](self) -> SIMD[dtype, frag_size]
Load lower accumulator fragment (rows 16-31).
Returns:
SIMD: Fragment of frag_size dtype values loaded from this stage.
wait_load
static wait_load()
Wait for TMEM load operations to complete.