Mojo struct
OutputRingBuffer
@register_passable(trivial)
struct OutputRingBuffer[num_stages: Int, stage_stride_cols: Int, cta_group: Int]
Ring buffer for the MMA→Epilogue output pipeline.
Manages TMEM accumulator stage synchronization between MMA warps (producer) and Epilogue warps (consumer). Unlike RingBuffer, which manages SMEM tiles, this struct manages stage indices and computes TMEM offsets.
The TMEM itself is allocated separately via tcgen05_alloc; this struct only coordinates access to different stages within that allocation.
Template Parameters:
- num_stages: Number of accumulator pipeline stages.
- stage_stride_cols: TMEM column stride between stages.
- cta_group: CTA group size (1 or 2) for multicast signaling.
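To make the role of stage_stride_cols concrete, here is a minimal host-side model in Python. It is not the Mojo API. It assumes the per-stage TMEM offset is the base address plus the stage index times the column stride; that formula is inferred from the parameter descriptions, not confirmed by the implementation. The function name and the numeric values are hypothetical.

```python
# Host-side model only; assumes a linear layout of accumulator stages in TMEM.
def stage_tmem_offset(tmem_base_addr, stage_index, stage_stride_cols, num_stages):
    """Return the assumed TMEM column offset of one accumulator stage."""
    if not 0 <= stage_index < num_stages:
        raise ValueError("stage_index out of range")
    return tmem_base_addr + stage_index * stage_stride_cols


# Hypothetical configuration: 2 stages placed 128 columns apart, base address 0.
offsets = [stage_tmem_offset(0, i, 128, 2) for i in range(2)]
print(offsets)
```

Under this assumption, each stage owns a disjoint column range as long as stage_stride_cols is at least the accumulator width.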
Usage:
    # Initialize barriers once (elect_one_warp/elect_one_thread):
    OutputRingBuffer[...].init_barriers(storage_ptr, prod_cnt, cons_cnt)

    # Create ring buffer (each warp creates its own):
    var output_rb = OutputRingBuffer[...](storage_ptr, tmem_addr, mask)

    # MMA warp (producer):
    with output_rb.producer() as stage:
        # ... perform MMA into stage.tmem_offset ...

    # Epilogue warp (consumer):
    with output_rb.consumer() as stage:
        # ... read from stage.tmem_offset, write to GMEM ...
Fields
- pipeline (OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Pipeline)
- tmem_base_addr (UInt32)
- mma_complete_mask (UInt16)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Pipeline
comptime Pipeline = ProducerConsumerPipeline[num_stages]
Stage
comptime Stage = OutputStage[num_stages]
Methods
__init__
__init__(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_base_addr: UInt32, mma_complete_mask: UInt16) -> Self
Initialize output ring buffer.
Creates the pipeline internally from the storage pointer. Barriers must be initialized via init_barriers() before first use.
Args:
- storage_ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage.
- tmem_base_addr (UInt32): Base TMEM address for accumulators.
- mma_complete_mask (UInt16): Multicast mask for 2-SM MMA completion signaling.
init_barriers
static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize pipeline barriers. Call this once, from a single elected thread (elect_one).
Args:
- storage_ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage.
- producer_arv_count (Int32): Expected arrival count for producer barriers.
- consumer_arv_count (Int32): Expected arrival count for consumer barriers.
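The arrival counts determine how many arrivals complete one barrier phase. The Python model below sketches that behavior. It assumes one producer-signalled ("full") and one consumer-signalled ("empty") barrier per stage, and a parity bit that flips when the expected count is reached. ModelBarrier and model_init_barriers are hypothetical names; this is not how the Mojo pipeline is necessarily laid out in shared memory.

```python
class ModelBarrier:
    """Toy stand-in for a shared-memory barrier with an expected arrival count."""

    def __init__(self, expected_arrivals):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0  # parity bit; flips each time the barrier completes

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            # All expected arrivals seen: complete the phase and re-arm.
            self.phase ^= 1
            self.pending = self.expected


def model_init_barriers(num_stages, producer_arv_count, consumer_arv_count):
    # Assumption: one "full" and one "empty" barrier per pipeline stage.
    full = [ModelBarrier(producer_arv_count) for _ in range(num_stages)]
    empty = [ModelBarrier(consumer_arv_count) for _ in range(num_stages)]
    return full, empty


# Hypothetical counts: 1 MMA arrival, 4 epilogue-warp arrivals per stage.
full, empty = model_init_barriers(2, producer_arv_count=1, consumer_arv_count=4)
for _ in range(3):
    empty[0].arrive()
phase_after_three = empty[0].phase  # phase not complete yet
empty[0].arrive()
phase_after_four = empty[0].phase   # fourth arrival completes the phase
```

In this model, a consumer count that does not match the number of arriving epilogue warps would leave the phase incomplete, and the stage would never be handed back to the MMA warp.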
acquire_for_mma
acquire_for_mma(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Stage
Acquire a stage for MMA computation.
Waits for the epilogue to finish with this stage, then returns the stage info with computed TMEM offset and pipeline reference.
Returns:
OutputStage: The stage index, TMEM offset, and pipeline for signaling.
release_from_mma
release_from_mma(mut self, stage: OutputStage[num_stages])
Signal MMA completion and advance to next stage.
Signals the epilogue that accumulator data is ready, using either mma_arrive (1-SM) or mma_arrive_multicast (2-SM).
Args:
- stage (OutputStage): The stage being released (from acquire_for_mma).
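The choice between the 1-SM and 2-SM signaling paths can be pictured with the following Python model. It assumes that bit r of the 16-bit mma_complete_mask selects CTA rank r within the cluster, and that the 1-SM path signals only the local CTA. The function mma_complete_targets and its local_cta_rank parameter are hypothetical and exist only for illustration.

```python
# Model only: which CTAs' barriers an MMA-complete signal would reach.
def mma_complete_targets(cta_group, local_cta_rank, mma_complete_mask):
    if cta_group == 1:
        # 1-SM path (mma_arrive): signal the local CTA's barrier only.
        return [local_cta_rank]
    if cta_group == 2:
        # 2-SM path (mma_arrive_multicast): assumed one mask bit per CTA rank.
        return [r for r in range(16) if (mma_complete_mask >> r) & 1]
    raise ValueError("cta_group must be 1 or 2")


single = mma_complete_targets(1, 0, 0)     # mask unused on the 1-SM path
pair = mma_complete_targets(2, 0, 0b11)    # both CTAs of a 2-CTA pair
print(single, pair)
```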
acquire_for_epilogue
acquire_for_epilogue(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Stage
Acquire a stage for epilogue processing.
Waits for MMA to complete this stage, then returns the stage info.
Returns:
OutputStage: The stage index, TMEM offset, and pipeline for signaling.
release_from_epilogue
release_from_epilogue(mut self)
Signal epilogue completion and advance to next stage.
Signals MMA that this accumulator stage is free for reuse.
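Taken together, the four acquire/release methods form a bounded handshake: the MMA side can run at most num_stages ahead of the epilogue. The single-threaded Python model below illustrates this. The real acquire methods block on barriers; here the hypothetical try_ variants return None where the real call would wait. ModelOutputRing is an illustrative name, not part of the library.

```python
class ModelOutputRing:
    """Single-threaded model of the MMA/epilogue stage handshake."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.filled = [False] * num_stages  # True: MMA done, epilogue pending
        self.mma_stage = 0
        self.epi_stage = 0

    def try_acquire_for_mma(self):
        s = self.mma_stage
        return None if self.filled[s] else s  # None: would wait for epilogue

    def release_from_mma(self, stage):
        self.filled[stage] = True
        self.mma_stage = (stage + 1) % self.num_stages

    def try_acquire_for_epilogue(self):
        s = self.epi_stage
        return s if self.filled[s] else None  # None: would wait for MMA

    def release_from_epilogue(self):
        self.filled[self.epi_stage] = False
        self.epi_stage = (self.epi_stage + 1) % self.num_stages


rb = ModelOutputRing(num_stages=2)
trace = []
for _ in range(3):
    s = rb.try_acquire_for_mma()
    trace.append(s)
    if s is not None:
        rb.release_from_mma(s)

epi = rb.try_acquire_for_epilogue()   # oldest filled stage
rb.release_from_epilogue()            # frees that stage for the MMA side
retry = rb.try_acquire_for_mma()      # the stalled acquire now succeeds
```

With two stages, the third MMA acquire stalls until the epilogue releases stage 0, which is the back-pressure the ring buffer is meant to provide.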
producer
producer[origin: MutOrigin, //](ref [origin] self) -> OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group]
Get a producer context for MMA warp.
Usage:
    with output_rb.producer() as stage:
        # MMA into stage.tmem_offset
    # release_from_mma called automatically
Returns:
OutputProducerContext
consumer
consumer[origin: MutOrigin, //](ref [origin] self) -> OutputConsumerContext[origin, num_stages, stage_stride_cols, cta_group]
Get a consumer context for epilogue warp.
Usage:
    with output_rb.consumer() as stage:
        # Read from stage.tmem_offset, write to GMEM
    # release_from_epilogue called automatically
Returns:
OutputConsumerContext
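The producer and consumer contexts pair each acquire with its release. A rough Python analogue, using contextlib rather than the Mojo context types, is sketched below. ModelRing and the events list are hypothetical; the model releases on normal exit from the with block and does not claim anything about how the Mojo contexts are implemented.

```python
from contextlib import contextmanager

events = []  # records call order for illustration


class ModelRing:
    def acquire_for_mma(self):
        events.append("acquire_for_mma")
        return 0  # stage index in this toy model

    def release_from_mma(self, stage):
        events.append(f"release_from_mma({stage})")

    def acquire_for_epilogue(self):
        events.append("acquire_for_epilogue")
        return 0

    def release_from_epilogue(self):
        events.append("release_from_epilogue")

    @contextmanager
    def producer(self):
        stage = self.acquire_for_mma()
        yield stage
        self.release_from_mma(stage)  # runs when the with block exits normally

    @contextmanager
    def consumer(self):
        stage = self.acquire_for_epilogue()
        yield stage
        self.release_from_epilogue()  # takes no stage, matching the signature above


rb = ModelRing()
with rb.producer() as stage:
    events.append("mma")
with rb.consumer() as stage:
    events.append("epilogue")
print(events)
```

The benefit of the context form is that a release cannot be forgotten or issued for the wrong stage, since the context carries the acquired stage itself.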
get_pipeline
get_pipeline(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Pipeline
Get the underlying pipeline for barrier initialization.
Note: Because OutputStage now carries the pipeline, most code no longer needs this method. It is retained for init_barriers(), which needs the raw pipeline before any OutputStage instances exist.
Returns:
ProducerConsumerPipeline: The underlying pipeline.