Mojo struct
EpiLoadPipeline
struct EpiLoadPipeline[num_stages: Int]
Pipeline that synchronizes the epilogue load warp with the epilogue store warps.
Producer: Epilogue load warp (1 warp, 32 threads)
Consumer: Epilogue store warps (4 warps, 128 threads)
This pipeline synchronizes source tensor C loading with epilogue consumption for residual add operations (D = accum + beta * C).
Barrier Layout in SMEM
The pipeline uses 2 × num_stages barriers:
- Full barriers [0..num_stages): Producer signals data ready
- Empty barriers [num_stages..2*num_stages): Consumer signals stage free
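The layout above can be sketched as plain index arithmetic. This is a minimal host-side illustration, not the struct's implementation: the helper names `full_barrier_index` and `empty_barrier_index` are hypothetical and do not appear in the API.

```mojo
# Hypothetical helpers that mirror the documented barrier layout.
fn full_barrier_index(stage: Int) -> Int:
    # Full barriers occupy slots [0, num_stages).
    return stage


fn empty_barrier_index(stage: Int, num_stages: Int) -> Int:
    # Empty barriers occupy slots [num_stages, 2 * num_stages).
    return num_stages + stage


def main():
    var num_stages = 2  # double-buffering
    print("barrier slots needed:", 2 * num_stages)
    for stage in range(num_stages):
        print("stage", stage, "full slot", full_barrier_index(stage))
        print("stage", stage, "empty slot", empty_barrier_index(stage, num_stages))
```

With num_stages = 2, the full barriers land in slots 0 and 1 and the empty barriers in slots 2 and 3, for 4 slots in total.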
Arrive Counts
- Producer arrive count: 1 (single TMA transaction per stage)
- Consumer arrive count: 128 (all epilogue threads)
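The consumer count is the product of the epilogue store warp count and the warp size. A small sketch of that arithmetic, using a hypothetical helper `consumer_arrive_count` that is not part of the API:

```mojo
# Hypothetical helper: one arrival per epilogue store thread.
fn consumer_arrive_count(num_warps: Int, threads_per_warp: Int) -> Int:
    return num_warps * threads_per_warp


def main():
    # 4 epilogue store warps of 32 threads each give the documented count.
    print(consumer_arrive_count(4, 32))  # 4 * 32 = 128
```

A kernel with a different number of epilogue store warps would presumably need to pass a matching consumer_arv_count to init_barriers instead of relying on the default.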
Parameters
- num_stages (Int): Number of pipeline stages (typically 2 for double-buffering).
Fields
- pipeline (ProducerConsumerPipeline[num_stages]):
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
Methods
__init__
__init__(out self, ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])
Initialize the epilogue load pipeline.
Args:
- ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage. Requires 2 × num_stages SharedMemBarrier slots.
init_barriers
init_barriers(self, producer_arv_count: Int32 = 1, consumer_arv_count: Int32 = 128)
Initialize the pipeline barriers.
Should be called by a single thread (elect_one_thread) during kernel initialization.
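Args:
- producer_arv_count (Int32): Producer arrive count. Defaults to 1.
- consumer_arv_count (Int32): Consumer arrive count. Defaults to 128.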
produce
produce[origin: MutOrigin, //](ref[origin] self) -> ProduceContext[origin_of(origin._mlir_origin.pipeline), num_stages]
Produce one pipeline stage with encapsulated barriers.
Usage:
    with epi_load_pipeline.produce() as stage:
        tma_load(c_tile, stage.mbar(), ...)
    # exit advances producer stage
Returns:
ProduceContext: Context that waits for consumer on enter, advances on exit.
acquire_producer
acquire_producer[origin: MutOrigin, //](ref[origin] self) -> ProducerStage[origin_of(origin._mlir_origin.pipeline), num_stages]
Acquire a producer stage handle.
Returns:
ProducerStage: Handle that must be released.
wait_consumer
wait_consumer(self)
Wait for consumer to free the current stage.
producer_mbar
producer_mbar(self) -> MbarPtr
Get the producer barrier for the current stage.
Returns:
MbarPtr: Barrier pointer for TMA arrive.
producer_step
producer_step(mut self)
Advance producer to next stage.
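As an illustration of what advancing through the stages means, the sketch below assumes a round-robin advance that wraps back to stage 0 after the last stage. That wrap-around is an assumption about ProducerConsumerPipeline, which may track additional state; `next_stage` is a hypothetical helper, not an API function.

```mojo
# Hypothetical model of stage advancement (assumed round-robin).
fn next_stage(stage: Int, num_stages: Int) -> Int:
    # Wrap back to stage 0 after the last stage.
    return (stage + 1) % num_stages


def main():
    var num_stages = 2
    var stage = 0
    # Five successive C tiles alternate between the two stages.
    for tile in range(5):
        print("tile", tile, "uses stage", stage)
        stage = next_stage(stage, num_stages)
```

Under this assumption, with two stages the tiles use stages 0, 1, 0, 1, 0, so the producer can fill one buffer while the consumer drains the other.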
consume
consume[origin: MutOrigin, //](ref[origin] self) -> ConsumeContext[origin_of(origin._mlir_origin.pipeline), num_stages]
Consume one pipeline stage with encapsulated barriers.
Usage:
    with epi_load_pipeline.consume() as stage:
        c_tile = smem.src_tiles()[stage.index()]
        # Use C tile for residual add
    # exit signals consumption and advances
Returns:
ConsumeContext: Context that waits for producer on enter, signals+advances on exit.
consume_explicit
consume_explicit[origin: MutOrigin, //](ref[origin] self) -> ExplicitConsumeContext[origin_of(origin._mlir_origin.pipeline), num_stages]
Consume with explicit barrier arrive.
Use for lane-guarded signaling patterns.
Returns:
ExplicitConsumeContext: Context that waits for the producer on enter and only advances on exit; the barrier arrive is left to the caller.
acquire_consumer
acquire_consumer[origin: MutOrigin, //](ref[origin] self) -> ConsumerStage[origin_of(origin._mlir_origin.pipeline), num_stages]
Acquire a consumer stage handle.
Returns:
ConsumerStage: Handle that must be released.
wait_producer
wait_producer(self)
Wait for producer to fill the current stage.
consumer_stage
consumer_step
consumer_step(mut self)
Advance consumer to next stage.