Mojo struct

EpiLoadPipeline

struct EpiLoadPipeline[num_stages: Int]

Pipeline from the epilogue load warp to the epilogue store warps.

  • Producer: Epilogue load warp (1 warp, 32 threads)
  • Consumer: Epilogue store warps (4 warps, 128 threads)

This pipeline synchronizes source tensor C loading with epilogue consumption for residual add operations (D = accum + beta * C).
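
For reference, the residual add that this pipeline feeds can be written element by element. The following is a minimal host-side Python sketch rather than Mojo kernel code; the function name `residual_add` and the sample values are illustrative assumptions, not part of the API.

```python
def residual_add(accum, c, beta):
    """Return D = accum + beta * C, computed element by element."""
    if len(accum) != len(c):
        raise ValueError("accum and C tiles must have the same length")
    return [a + beta * x for a, x in zip(accum, c)]


# Hypothetical 3-element tiles, used only to show the arithmetic.
d_tile = residual_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], beta=0.5)
print(d_tile)  # [6.0, 12.0, 18.0]
```

In the kernel, the C tile is what the load warp stages in shared memory, and the store warps apply this formula once the stage is marked full.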

Barrier Layout in SMEM

The pipeline uses 2 × num_stages barriers:

  • Full barriers [0..num_stages): Producer signals data ready
  • Empty barriers [num_stages..2*num_stages): Consumer signals stage free
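
The two index ranges above can be sketched as a pair of lookup helpers. This is a host-side Python model, not Mojo; the helper names are hypothetical, and the assumption that stage `s` uses the `s`-th barrier within each range is an inference from the ranges, not something this page states.

```python
def full_barrier_index(stage, num_stages):
    """Index of the full barrier for `stage`, in [0, num_stages)."""
    if not 0 <= stage < num_stages:
        raise IndexError("stage out of range")
    return stage


def empty_barrier_index(stage, num_stages):
    """Index of the empty barrier for `stage`, in [num_stages, 2 * num_stages)."""
    if not 0 <= stage < num_stages:
        raise IndexError("stage out of range")
    return num_stages + stage


num_stages = 2  # double-buffering
layout = [
    (s, full_barrier_index(s, num_stages), empty_barrier_index(s, num_stages))
    for s in range(num_stages)
]
print(layout)  # [(0, 0, 2), (1, 1, 3)]
```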

Arrive Counts​

  • Producer arrive count: 1 (single TMA transaction per stage)
  • Consumer arrive count: 128 (all epilogue threads)
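
To illustrate why the two counts differ, here is a toy arrive-count barrier in host-side Python. It is a simplified model under stated assumptions: the class `ArriveCountBarrier` is hypothetical, and it ignores the transaction-byte tracking that hardware barriers use for TMA loads.

```python
class ArriveCountBarrier:
    """Toy barrier that completes one phase after `arrive_count` arrivals."""

    def __init__(self, arrive_count):
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.completed_phases = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            # All expected arrivals are in: finish the phase and re-arm.
            self.completed_phases += 1
            self.pending = self.arrive_count


EPILOGUE_WARPS = 4
THREADS_PER_WARP = 32

full = ArriveCountBarrier(1)  # one producer arrival per stage
empty = ArriveCountBarrier(EPILOGUE_WARPS * THREADS_PER_WARP)  # 128 threads

full.arrive()  # a single arrival marks the stage full
for _ in range(127):
    empty.arrive()
still_busy = empty.completed_phases == 0  # 127 of 128 threads are not enough
empty.arrive()  # the last thread frees the stage

print(full.completed_phases, still_busy, empty.completed_phases)  # 1 True 1
```

The stage is handed back to the producer only after every one of the 128 consumer threads has arrived, which is why a lower consumer count requires a matching change in who signals.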

Parameters

  • num_stages (Int): Number of pipeline stages (typically 2 for double-buffering).

Fields

  • pipeline (ProducerConsumerPipeline[num_stages]): The underlying producer-consumer pipeline that this struct wraps.

Implemented traits​

AnyType, ImplicitlyDeletable

Methods​

__init__​

def __init__(out self, ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED])

Initialize the epilogue load pipeline.

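Args:

  • ptr (UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED]): Pointer to the pipeline's barriers in shared memory.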

init_barriers​

def init_barriers(self, producer_arv_count: Int32 = Int32(1), consumer_arv_count: Int32 = Int32(128))

Initialize the pipeline barriers.

Should be called by a single thread (elect_one_thread) during kernel initialization.

Args:

  • producer_arv_count (Int32): Arrive count for producer (default 1 for TMA).
  • consumer_arv_count (Int32): Arrive count for consumer (default 128 for 4 epilogue warps × 32 threads).

produce​

def produce[origin: MutOrigin, //](ref[num_stages] self) -> ProduceContext[origin_of(origin.pipeline), num_stages]

Produce one pipeline stage with encapsulated barriers.

Usage:

    with epi_load_pipeline.produce() as stage:
        tma_load(c_tile, stage.mbar(), ...)
    # exit advances producer stage

Returns:

ProduceContext[origin_of(origin.pipeline), num_stages]: Context that waits for consumer on enter, advances on exit.

acquire_producer​

def acquire_producer[origin: MutOrigin, //](ref[num_stages] self) -> ProducerStage[origin_of(origin.pipeline), num_stages]

Acquire a producer stage handle.

Returns:

ProducerStage[origin_of(origin.pipeline), num_stages]: ProducerStage handle that must be released.

wait_consumer​

def wait_consumer(self)

Wait for consumer to free the current stage.

producer_mbar​

def producer_mbar(self) -> MbarPtr

Get the producer barrier for the current stage.

Returns:

MbarPtr: Barrier pointer for TMA arrive.

producer_step​

def producer_step(mut self)

Advance producer to next stage.

consume​

def consume[origin: MutOrigin, //](ref[num_stages] self) -> ConsumeContext[origin_of(origin.pipeline), num_stages]

Consume one pipeline stage with encapsulated barriers.

Usage:

    with epi_load_pipeline.consume() as stage:
        c_tile = smem.src_tiles()[stage.index()]
        # Use C tile for residual add
    # exit signals consumption and advances

Returns:

ConsumeContext[origin_of(origin.pipeline), num_stages]: Context that waits for producer on enter, signals+advances on exit.
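
The enter/exit behavior of the produce and consume contexts can be modeled on the host with Python context managers. This is a single-threaded sketch under stated assumptions, not the Mojo implementation: `PipelineModel` is a hypothetical class, a "wait" is replaced by a check that raises when the real pipeline would block, and stage indices are assumed to wrap modulo `num_stages`.

```python
from contextlib import contextmanager


class PipelineModel:
    """Single-threaded stand-in for a num_stages producer/consumer pipeline."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.filled = [False] * num_stages  # True while a stage holds unread data
        self.producer_stage = 0
        self.consumer_stage = 0

    @contextmanager
    def produce(self):
        stage = self.producer_stage
        # Enter: the real pipeline waits for the consumer; the model only checks.
        if self.filled[stage]:
            raise RuntimeError("producer would block: stage not yet freed")
        yield stage
        # Exit: mark the stage full and advance the producer.
        self.filled[stage] = True
        self.producer_stage = (stage + 1) % self.num_stages

    @contextmanager
    def consume(self):
        stage = self.consumer_stage
        # Enter: the real pipeline waits for the producer; the model only checks.
        if not self.filled[stage]:
            raise RuntimeError("consumer would block: stage not yet filled")
        yield stage
        # Exit: signal that the stage is free and advance the consumer.
        self.filled[stage] = False
        self.consumer_stage = (stage + 1) % self.num_stages


pipe = PipelineModel(num_stages=2)
trace = []
for step in ["load", "load", "use", "load", "use", "use"]:
    ctx = pipe.produce() if step == "load" else pipe.consume()
    with ctx as stage:
        trace.append((step, stage))

print(trace)
# [('load', 0), ('load', 1), ('use', 0), ('load', 0), ('use', 1), ('use', 0)]
```

With two stages, the producer can run at most two loads ahead of the consumer; a third `produce()` before any `consume()` raises in this model, where the real producer would wait on the empty barrier.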

consume_explicit​

def consume_explicit[origin: MutOrigin, //](ref[num_stages] self) -> ExplicitConsumeContext[origin_of(origin.pipeline), num_stages]

Consume with explicit barrier arrive.

Use for lane-guarded signaling patterns.

Returns:

ExplicitConsumeContext[origin_of(origin.pipeline), num_stages]: Context that waits for the producer on enter and only advances on exit; the caller must issue the barrier arrive explicitly.

acquire_consumer​

def acquire_consumer[origin: MutOrigin, //](ref[num_stages] self) -> ConsumerStage[origin_of(origin.pipeline), num_stages]

Acquire a consumer stage handle.

Returns:

ConsumerStage[origin_of(origin.pipeline), num_stages]: ConsumerStage handle that must be released.

wait_producer​

def wait_producer(self)

Wait for producer to fill the current stage.

consumer_stage​

def consumer_stage(self) -> UInt32

Get the current consumer stage index.

Returns:

UInt32: Index of the current consumer stage.

consumer_step​

def consumer_step(mut self)

Advance consumer to next stage.