
Mojo struct

EpiLoadPipeline

struct EpiLoadPipeline[num_stages: Int]

Pipeline that passes source tiles from the epilogue load warp to the epilogue store warps.

  • Producer: Epilogue load warp (1 warp, 32 threads)
  • Consumer: Epilogue store warps (4 warps, 128 threads)

This pipeline synchronizes the loading of the source tensor C with its consumption by the epilogue for residual-add operations (D = accum + beta * C).
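For reference, this is the arithmetic the consumer applies to each C tile it receives. The sketch below is a plain-Python CPU model of D = accum + beta * C on small 2-D tiles. The function name `residual_add` and the tile values are hypothetical. The real epilogue works on register fragments and shared-memory tiles, not Python lists.

```python
def residual_add(accum, c_tile, beta):
    """Return D = accum + beta * C for two equally shaped 2-D tiles (lists of rows)."""
    if len(accum) != len(c_tile) or any(len(a) != len(c) for a, c in zip(accum, c_tile)):
        raise ValueError("accum and C tiles must have the same shape")
    return [
        [a + beta * c for a, c in zip(acc_row, c_row)]
        for acc_row, c_row in zip(accum, c_tile)
    ]


accum = [[1.0, 2.0], [3.0, 4.0]]        # accumulator tile (hypothetical values)
c_tile = [[10.0, 20.0], [30.0, 40.0]]   # source tile C delivered through the pipeline
print(residual_add(accum, c_tile, beta=0.5))  # [[6.0, 12.0], [18.0, 24.0]]
```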

Barrier Layout in SMEM

The pipeline uses 2 × num_stages barriers:

  • Full barriers [0..num_stages): Producer signals data ready
  • Empty barriers [num_stages..2*num_stages): Consumer signals stage free
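Under this layout, the full barrier for stage i sits at slot i, and its matching empty barrier sits at slot num_stages + i. The hypothetical helpers below only spell out that index arithmetic in Python. They are not part of the Mojo API.

```python
def full_barrier_slot(stage: int, num_stages: int) -> int:
    """Slot of the 'data ready' barrier for a stage: [0, num_stages)."""
    if not 0 <= stage < num_stages:
        raise ValueError("stage out of range")
    return stage


def empty_barrier_slot(stage: int, num_stages: int) -> int:
    """Slot of the 'stage free' barrier for a stage: [num_stages, 2 * num_stages)."""
    if not 0 <= stage < num_stages:
        raise ValueError("stage out of range")
    return num_stages + stage


num_stages = 2
for s in range(num_stages):
    # stage, full slot, empty slot
    print(s, full_barrier_slot(s, num_stages), empty_barrier_slot(s, num_stages))
# 0 0 2
# 1 1 3
```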

Arrive Counts

  • Producer arrive count: 1 (single TMA transaction per stage)
  • Consumer arrive count: 128 (all epilogue threads)
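An arrive count is the number of arrivals a barrier waits for before its current phase completes. The Python class below is a simplified model that assumes mbarrier-style semantics: once the expected number of arrivals is reached, the pending count re-arms and a phase bit flips. The class name `CountingBarrier` is hypothetical, and the model deliberately ignores the transaction-byte tracking that TMA loads use.

```python
class CountingBarrier:
    """Simplified model of a barrier that completes a phase after `expected` arrivals."""

    def __init__(self, expected: int):
        self.expected = expected
        self.pending = expected
        self.phase = 0  # flips each time a phase completes

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.pending = self.expected  # re-arm for the next phase
            self.phase ^= 1


full = CountingBarrier(expected=1)        # producer side: one arrival per stage
empty = CountingBarrier(expected=4 * 32)  # consumer side: 4 warps x 32 threads

full.arrive()
print(full.phase)   # 1: a single producer arrival completes the phase

for _ in range(127):
    empty.arrive()
print(empty.phase)  # 0: still waiting for the last epilogue thread
empty.arrive()
print(empty.phase)  # 1
```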

Parameters

  • num_stages (Int): Number of pipeline stages (typically 2 for double-buffering).
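Two stages give double-buffering: the producer can fill one stage while the consumer reads the other. The plain-Python model below shows this with a producer thread and a consumer thread sharing a ring of num_stages buffers. One threading.Semaphore per stage stands in for each full barrier and each empty barrier. This is a CPU analogy only. Real mbarriers use phase parity rather than counting semaphores, and the function name `run_pipeline` is hypothetical. In the model, the producer runs at most num_stages tiles ahead of the consumer, and tiles are consumed in load order.

```python
import threading


def run_pipeline(num_stages: int, tiles: list) -> list:
    buffers = [None] * num_stages
    # full[i]: released by the producer when stage i holds data.
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    # empty[i]: released by the consumer when stage i may be overwritten (initially free).
    empty = [threading.Semaphore(1) for _ in range(num_stages)]
    consumed = []

    def producer():
        for k, tile in enumerate(tiles):
            stage = k % num_stages
            empty[stage].acquire()   # wait for the consumer to free the stage
            buffers[stage] = tile    # stands in for the TMA load of a C tile
            full[stage].release()    # signal "data ready"

    def consumer():
        for k in range(len(tiles)):
            stage = k % num_stages
            full[stage].acquire()            # wait for the producer to fill the stage
            consumed.append(buffers[stage])  # stands in for the residual add
            empty[stage].release()           # signal "stage free"

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed


print(run_pipeline(2, ["C0", "C1", "C2", "C3", "C4"]))  # ['C0', 'C1', 'C2', 'C3', 'C4']
```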

Fields

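  • pipeline (ProducerConsumerPipeline[num_stages]): The underlying producer-consumer pipeline that this struct wraps; it holds the barriers and the producer and consumer stage state.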

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

Methods

__init__

__init__(out self, ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])

Initialize the epilogue load pipeline.

Args:

  • ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage. Requires 2 × num_stages SharedMemBarrier slots.

init_barriers

init_barriers(self, producer_arv_count: Int32 = 1, consumer_arv_count: Int32 = 128)

Initialize the pipeline barriers.

Should be called by a single thread (elect_one_thread) during kernel initialization.

Args:

  • producer_arv_count (Int32): Arrive count for producer (default 1 for TMA).
  • consumer_arv_count (Int32): Arrive count for consumer (default 128 for 4 epilogue warps × 32 threads).
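Combining the barrier layout with these defaults, init_barriers leaves each full barrier expecting producer_arv_count arrivals and each empty barrier expecting consumer_arv_count arrivals. The hypothetical Python helper `expected_counts` below tabulates that assignment per slot. It assumes the slot layout from "Barrier Layout in SMEM" and models only the result, not the Mojo call.

```python
def expected_counts(num_stages: int, producer_arv_count: int = 1,
                    consumer_arv_count: int = 128) -> list:
    """Expected arrive count for each of the 2 * num_stages barrier slots."""
    full = [producer_arv_count] * num_stages   # slots [0, num_stages)
    empty = [consumer_arv_count] * num_stages  # slots [num_stages, 2 * num_stages)
    return full + empty


print(expected_counts(2))  # [1, 1, 128, 128]
```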

produce

produce[origin: MutOrigin, //](ref[origin] self) -> ProduceContext[origin_of(origin._mlir_origin.pipeline), num_stages]

Produce one pipeline stage with encapsulated barriers.

Usage:

    with epi_load_pipeline.produce() as stage:
        tma_load(c_tile, stage.mbar(), ...)
    # exit advances producer stage

Returns:

ProduceContext: Context that, on enter, waits for the consumer to free the current stage and, on exit, advances the producer stage.
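The enter/exit behavior of ProduceContext can be mirrored with a Python context manager. In this sketch, the `ProducerModel` class and its `wait_empty` callback are hypothetical stand-ins. On enter, the model waits for the consumer to free the current stage and yields the stage index. On exit, it advances to the next stage, wrapping at num_stages.

```python
from contextlib import contextmanager


class ProducerModel:
    def __init__(self, num_stages: int, wait_empty):
        self.num_stages = num_stages
        self.stage = 0
        self.wait_empty = wait_empty  # stand-in for blocking until the stage is free

    def step(self) -> None:
        self.stage = (self.stage + 1) % self.num_stages

    @contextmanager
    def produce(self):
        self.wait_empty(self.stage)  # enter: wait for the consumer
        yield self.stage             # body issues the load into this stage
        self.step()                  # exit: advance the producer stage


waited = []
producer = ProducerModel(num_stages=2, wait_empty=waited.append)
for tile in ["C0", "C1", "C2"]:
    with producer.produce() as stage:
        print(f"load {tile} into stage {stage}")
print(waited)  # [0, 1, 0]
```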

acquire_producer

acquire_producer[origin: MutOrigin, //](ref[origin] self) -> ProducerStage[origin_of(origin._mlir_origin.pipeline), num_stages]

Acquire a producer stage handle.

Returns:

ProducerStage: ProducerStage handle that must be released.

wait_consumer

wait_consumer(self)

Wait for consumer to free the current stage.

producer_mbar

producer_mbar(self) -> MbarPtr

Get the producer barrier for the current stage.

Returns:

MbarPtr: Barrier pointer for TMA arrive.

producer_step

producer_step(mut self)

Advance producer to next stage.
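The explicit methods split the same producer work into separate calls: wait_consumer, producer_mbar, then producer_step. The Python model below runs those calls in the order a producer loop would use them and records what happens at each step. The class `ExplicitProducerModel` and its `log` are hypothetical. The model also assumes that producer_mbar returns the full barrier of the current stage, so its slot index equals the stage.

```python
class ExplicitProducerModel:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.stage = 0
        self.log = []

    def wait_consumer(self) -> None:
        # Waiting happens on the empty barrier of the current stage.
        self.log.append(("wait_empty", self.num_stages + self.stage))

    def producer_mbar(self) -> int:
        return self.stage  # full barrier slot for the current stage

    def producer_step(self) -> None:
        self.stage = (self.stage + 1) % self.num_stages


p = ExplicitProducerModel(num_stages=2)
for tile in ["C0", "C1", "C2"]:
    p.wait_consumer()                                 # 1. wait until the stage is free
    p.log.append(("load", tile, p.producer_mbar()))   # 2. load signals the full barrier
    p.producer_step()                                 # 3. advance to the next stage
print(p.log)
```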

consume

consume[origin: MutOrigin, //](ref[origin] self) -> ConsumeContext[origin_of(origin._mlir_origin.pipeline), num_stages]

Consume one pipeline stage with encapsulated barriers.

Usage:

    with epi_load_pipeline.consume() as stage:
        c_tile = smem.src_tiles()[stage.index()]
        # Use C tile for residual add
    # exit signals consumption and advances

Returns:

ConsumeContext: Context that, on enter, waits for the producer to fill the current stage and, on exit, signals that the stage has been consumed and advances the consumer stage.
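The consumer side mirrors the producer side, with one difference: on exit, the consumer both signals the empty barrier and advances. In the Python context-manager model below, the `ConsumerModel` class and its `events` log are hypothetical, and the barrier slots follow the layout described above.

```python
from contextlib import contextmanager


class ConsumerModel:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.stage = 0
        self.events = []

    @contextmanager
    def consume(self):
        self.events.append(("wait_full", self.stage))                       # enter: wait for data
        yield self.stage                                                     # body reads src tile [stage]
        self.events.append(("arrive_empty", self.num_stages + self.stage))  # exit: signal stage free
        self.stage = (self.stage + 1) % self.num_stages                     # exit: advance


c = ConsumerModel(num_stages=2)
for _ in range(2):
    with c.consume() as stage:
        c.events.append(("read", stage))
print(c.events)
```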

consume_explicit

consume_explicit[origin: MutOrigin, //](ref[origin] self) -> ExplicitConsumeContext[origin_of(origin._mlir_origin.pipeline), num_stages]

Consume with an explicit barrier arrive.

Use for lane-guarded signaling patterns, in which the caller performs the consumer arrive itself from the threads it chooses.

Returns:

ExplicitConsumeContext: Context that waits for the producer on enter; on exit it advances the consumer stage but does not signal.
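Because the caller issues the arrive itself, the number of threads that pass the lane guard must match the consumer arrive count the barriers were initialized with. The Python fragment below only counts arrivals. If all 128 epilogue threads arrive, the default consumer_arv_count of 128 is met. A guard that lets only lane 0 of each warp arrive yields 4 arrivals, which would require init_barriers(consumer_arv_count=4). The helper `arrivals` and the lane-0 guard are illustrative assumptions; this page does not prescribe a specific pattern.

```python
WARP_SIZE = 32
NUM_EPILOGUE_WARPS = 4


def arrivals(guard) -> int:
    """Count arrivals on the empty barrier from the epilogue threads whose lane passes `guard`."""
    count = 0
    for tid in range(NUM_EPILOGUE_WARPS * WARP_SIZE):
        lane = tid % WARP_SIZE
        if guard(lane):
            count += 1  # this thread would arrive on the empty barrier
    return count


print(arrivals(lambda lane: True))       # 128: matches the default consumer_arv_count
print(arrivals(lambda lane: lane == 0))  # 4: one arrival per warp
```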

acquire_consumer

acquire_consumer[origin: MutOrigin, //](ref[origin] self) -> ConsumerStage[origin_of(origin._mlir_origin.pipeline), num_stages]

Acquire a consumer stage handle.

Returns:

ConsumerStage: ConsumerStage handle that must be released.

wait_producer

wait_producer(self)

Wait for producer to fill the current stage.

consumer_stage

consumer_stage(self) -> UInt32

Get the current consumer stage index.

Returns:

UInt32: Index of the current consumer stage, in [0, num_stages).

consumer_step

consumer_step(mut self)

Advance consumer to next stage.
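Together, consumer_stage and consumer_step walk the stages in ring order. The hypothetical Python generator `stage_sequence` below lists the index consumer_stage would report at each iteration, assuming the index wraps to 0 after num_stages - 1. The generator also tracks a phase bit that flips on each wrap. This is the usual way mbarrier-based pipelines tell apart consecutive uses of the same slot, but it is an assumption about ProducerConsumerPipeline that this page does not state.

```python
def stage_sequence(num_stages: int, steps: int):
    """Yield (stage index, phase) before each of `steps` consumer iterations."""
    stage, phase = 0, 0
    for _ in range(steps):
        yield stage, phase
        stage += 1                   # consumer_step
        if stage == num_stages:      # wrap around the ring and flip the phase
            stage, phase = 0, phase ^ 1


print(list(stage_sequence(2, 5)))  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
```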
