Mojo struct
EpiLoadPipeline
struct EpiLoadPipeline[num_stages: Int]
Pipeline from the epilogue load warp to the epilogue store warps.
- Producer: Epilogue load warp (1 warp, 32 threads)
- Consumer: Epilogue store warps (4 warps, 128 threads)
This pipeline synchronizes source tensor C loading with epilogue consumption for residual add operations (D = accum + beta * C).
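As a small worked example of the residual add, the following Python sketch applies D = accum + beta * C element-wise to a 2x2 tile. The function name `residual_add` and the list-of-lists tile representation are illustrative assumptions, not part of the Mojo API.

```python
# Host-side illustration only: the real epilogue runs on the GPU in Mojo.
def residual_add(accum, c, beta):
    """Return D = accum + beta * C, computed element-wise on 2D lists."""
    return [
        [a + beta * x for a, x in zip(accum_row, c_row)]
        for accum_row, c_row in zip(accum, c)
    ]


accum = [[1.0, 2.0], [3.0, 4.0]]      # accumulator tile (hypothetical values)
c = [[10.0, 20.0], [30.0, 40.0]]      # source tensor C tile (hypothetical values)
d = residual_add(accum, c, beta=0.5)
print(d)  # [[6.0, 12.0], [18.0, 24.0]]
```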
Barrier Layout in SMEM
The pipeline uses 2 × num_stages barriers:
- Full barriers [0..num_stages): Producer signals data ready
- Empty barriers [num_stages..2*num_stages): Consumer signals stage free
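The index ranges above can be written as a small mapping from stage number to barrier slot. This Python sketch only restates the documented layout; the helper names are hypothetical and do not exist in the Mojo library.

```python
def full_barrier_index(stage: int, num_stages: int) -> int:
    """Slot of the full barrier for `stage`, in [0, num_stages)."""
    if not 0 <= stage < num_stages:
        raise IndexError("stage out of range")
    return stage


def empty_barrier_index(stage: int, num_stages: int) -> int:
    """Slot of the empty barrier for `stage`, in [num_stages, 2*num_stages)."""
    if not 0 <= stage < num_stages:
        raise IndexError("stage out of range")
    return num_stages + stage


def barrier_slots(num_stages: int) -> int:
    """Total number of barrier slots the pipeline needs."""
    return 2 * num_stages


num_stages = 2  # double-buffering
layout = [
    (s, full_barrier_index(s, num_stages), empty_barrier_index(s, num_stages))
    for s in range(num_stages)
]
print(layout)                     # [(0, 0, 2), (1, 1, 3)]
print(barrier_slots(num_stages))  # 4
```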
Arrive Counts
- Producer arrive count: 1 (single TMA transaction per stage)
- Consumer arrive count: 128 (all epilogue threads)
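To illustrate what an arrive count means, here is a toy Python model of a barrier that completes a phase once the expected number of arrivals has been seen. `ModelBarrier` is a hypothetical stand-in, not `SharedMemBarrier`; it does not model TMA transaction byte counts or any other hardware detail.

```python
class ModelBarrier:
    """Toy barrier: flips its phase bit after `arrive_count` arrivals."""

    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed in this phase
        self.phase = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1                   # phase completes
            self.pending = self.arrive_count  # re-arm for the next phase


full = ModelBarrier(arrive_count=1)     # producer side: one arrival per stage
empty = ModelBarrier(arrive_count=128)  # consumer side: all epilogue threads

full.arrive()                 # a single arrival completes the full barrier
for _ in range(127):
    empty.arrive()
phase_before_last = empty.phase  # still 0: one thread has not arrived yet
empty.arrive()                   # the 128th arrival completes the phase

print(full.phase, phase_before_last, empty.phase)  # 1 0 1
```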
Parameters
- num_stages (Int): Number of pipeline stages (typically 2 for double-buffering).
Fields
- pipeline (ProducerConsumerPipeline[num_stages])
Implemented traits
AnyType, ImplicitlyDestructible
Methods
__init__
__init__(out self, ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED])
Initialize the epilogue load pipeline.
Args:
- ptr (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to shared memory barrier storage. Requires 2 × num_stages SharedMemBarrier slots.
init_barriers
init_barriers(self, producer_arv_count: Int32 = Int32(1), consumer_arv_count: Int32 = Int32(128))
Initialize the pipeline barriers.
Should be called by a single thread (elect_one_thread) during kernel initialization.
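Args:
- producer_arv_count (Int32): Producer arrive count. Defaults to 1 (single TMA transaction per stage).
- consumer_arv_count (Int32): Consumer arrive count. Defaults to 128 (all epilogue threads).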
produce
produce[origin: MutOrigin, //](ref[num_stages] self) -> ProduceContext[origin_of(origin.pipeline), num_stages]
Produce one pipeline stage with encapsulated barriers.
Usage:
    with epi_load_pipeline.produce() as stage:
        tma_load(c_tile, stage.mbar(), ...)
    # exit advances producer stage
Returns:
ProduceContext[origin_of(origin.pipeline), num_stages]: Context that waits for consumer on enter, advances on exit.
acquire_producer
acquire_producer[origin: MutOrigin, //](ref[num_stages] self) -> ProducerStage[origin_of(origin.pipeline), num_stages]
Acquire a producer stage handle.
Returns:
ProducerStage[origin_of(origin.pipeline), num_stages]: ProducerStage handle that must be released.
wait_consumer
wait_consumer(self)
Wait for consumer to free the current stage.
producer_mbar
producer_mbar(self) -> MbarPtr
Get the producer barrier for the current stage.
Returns:
MbarPtr: Barrier pointer for TMA arrive.
producer_step
producer_step(mut self)
Advance producer to next stage.
consume
consume[origin: MutOrigin, //](ref[num_stages] self) -> ConsumeContext[origin_of(origin.pipeline), num_stages]
Consume one pipeline stage with encapsulated barriers.
Usage:
    with epi_load_pipeline.consume() as stage:
        c_tile = smem.src_tiles()[stage.index()]
        # Use C tile for residual add
    # exit signals consumption and advances
Returns:
ConsumeContext[origin_of(origin.pipeline), num_stages]: Context that waits for producer on enter, signals+advances on exit.
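The enter/exit behavior of produce() and consume() can be mimicked on the host with Python context managers. This is a single-threaded sketch under stated assumptions: `ModelEpiLoadPipeline` is a hypothetical class, a boolean per stage replaces the full/empty barrier pair, and a wait that would block on the GPU raises `RuntimeError` here instead.

```python
from contextlib import contextmanager


class ModelEpiLoadPipeline:
    """Toy model: one 'is full' flag per stage instead of real barriers."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.full = [False] * num_stages
        self.producer_stage = 0
        self.consumer_stage = 0

    @contextmanager
    def produce(self):
        stage = self.producer_stage
        if self.full[stage]:  # enter: wait for the consumer to free the stage
            raise RuntimeError("stage not yet freed by consumer")
        yield stage
        self.full[stage] = True  # stands in for the load completing
        self.producer_stage = (stage + 1) % self.num_stages  # exit: advance

    @contextmanager
    def consume(self):
        stage = self.consumer_stage
        if not self.full[stage]:  # enter: wait for the producer to fill it
            raise RuntimeError("stage not yet filled by producer")
        yield stage
        self.full[stage] = False  # exit: signal that the stage is free
        self.consumer_stage = (stage + 1) % self.num_stages  # and advance


pipe = ModelEpiLoadPipeline(num_stages=2)
smem = [None, None]                # one C tile slot per stage
tiles = ["C0", "C1", "C2", "C3"]   # hypothetical tile labels
results = []

for tile in tiles[:2]:             # prefill both stages
    with pipe.produce() as s:
        smem[s] = tile
for tile in tiles[2:]:             # steady state: consume one, refill one
    with pipe.consume() as s:
        results.append(smem[s])
    with pipe.produce() as s:
        smem[s] = tile
for _ in range(2):                 # drain the remaining stages
    with pipe.consume() as s:
        results.append(smem[s])

print(results)  # ['C0', 'C1', 'C2', 'C3']
```

With two stages the producer can run at most two tiles ahead of the consumer, which is the double-buffering behavior the num_stages parameter describes.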
consume_explicit
consume_explicit[origin: MutOrigin, //](ref[num_stages] self) -> ExplicitConsumeContext[origin_of(origin.pipeline), num_stages]
Consume with explicit barrier arrive.
Use for lane-guarded signaling patterns.
Returns:
ExplicitConsumeContext[origin_of(origin.pipeline), num_stages]: Context that waits on enter, advances only on exit.
acquire_consumer
acquire_consumer[origin: MutOrigin, //](ref[num_stages] self) -> ConsumerStage[origin_of(origin.pipeline), num_stages]
Acquire a consumer stage handle.
Returns:
ConsumerStage[origin_of(origin.pipeline), num_stages]: ConsumerStage handle that must be released.
wait_producer
wait_producer(self)
Wait for producer to fill the current stage.
consumer_stage
consumer_step
consumer_step(mut self)
Advance consumer to next stage.
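A common way to implement stage advancement in barrier-based pipelines is a cursor that wraps modulo num_stages and flips a phase bit on each wrap. The Python sketch below shows that convention; whether producer_step and consumer_step track a phase bit this way is an assumption, and `StageCursor` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass
class StageCursor:
    """Assumed model of a pipeline cursor: stage index plus phase bit."""

    num_stages: int
    index: int = 0
    phase: int = 0

    def step(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0   # wrap back to the first stage
            self.phase ^= 1  # assumed: phase flips on every wrap


cursor = StageCursor(num_stages=2)
trace = []
for _ in range(5):
    trace.append((cursor.index, cursor.phase))
    cursor.step()

print(trace)  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
```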