
Mojo struct

ProducerConsumerPipeline

@register_passable(trivial) struct ProducerConsumerPipeline[num_stages: Int]

A producer-consumer pipeline that uses shared memory barriers to synchronize producer and consumer warps.

This struct is commonly used with warp specialization to pipeline operations between two warps/warpgroups with data dependencies.
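To make the handshake concrete, here is a single-threaded Mojo model of the protocol. It is a sketch, not the GPU implementation. Assumptions: `RingModel`, `producer_can_go()`, and `consumer_can_go()` are hypothetical names; each mbarrier is reduced to a count of completed phases with an arrive count of 1; a wait on parity p passes once the barrier's current phase parity differs from p (the PTX `mbarrier.try_wait.parity` rule); and the producer starts at phase 1 while the consumer starts at phase 0. That last convention is common (CUTLASS pipelines use it), but this page does not spell it out.

```mojo
from testing import assert_true


# Hypothetical model: `full` and `empty` hold, per stage, how many phases each
# barrier has completed (arrive count 1, so every arrive completes a phase).
struct RingModel:
    var num_stages: Int
    var full: List[Int]
    var empty: List[Int]
    var prod_stage: Int
    var prod_phase: Int
    var cons_stage: Int
    var cons_phase: Int

    fn __init__(out self, num_stages: Int):
        var full = List[Int]()
        var empty = List[Int]()
        for _ in range(num_stages):
            full.append(0)
            empty.append(0)
        self.num_stages = num_stages
        self.full = full^
        self.empty = empty^
        self.prod_stage = 0
        self.prod_phase = 1  # assumed: first producer wait passes on fresh barriers
        self.cons_stage = 0
        self.cons_phase = 0  # consumer blocks until the producer's first arrive

    fn producer_can_go(self) -> Bool:
        # Mirrors wait_consumer(): has `empty` completed phase `prod_phase`?
        return self.empty[self.prod_stage] % 2 != self.prod_phase

    fn consumer_can_go(self) -> Bool:
        # Mirrors wait_producer(): has `full` completed phase `cons_phase`?
        return self.full[self.cons_stage] % 2 != self.cons_phase

    fn produce(mut self):
        var s = self.prod_stage
        self.full[s] += 1  # producer arrive: stage s is filled
        self.prod_stage += 1  # producer_step()
        if self.prod_stage == self.num_stages:
            self.prod_stage = 0
            self.prod_phase ^= 1

    fn consume(mut self):
        var s = self.cons_stage
        self.empty[s] += 1  # consumer arrive: stage s is freed
        self.cons_stage += 1  # consumer_step()
        if self.cons_stage == self.num_stages:
            self.cons_stage = 0
            self.cons_phase ^= 1


def main():
    var ring = RingModel(2)
    var produced = 0
    var consumed = 0
    var total = 6
    while consumed < total:
        # Producer warp: fill stages until every stage is full.
        while produced < total and ring.producer_can_go():
            ring.produce()
            produced += 1
            assert_true(produced - consumed <= ring.num_stages)
        # Consumer warp: drain one stage if it has been filled.
        if ring.consumer_can_go():
            ring.consume()
            consumed += 1
            assert_true(consumed <= produced)
    print("produced", produced, "consumed", consumed)
```

Dropping the `producer_can_go()` check would let the producer overwrite a stage the consumer has not read yet, which is the hazard the consumer-signaled barriers exist to prevent.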

Parameters

  • num_stages (Int): The number of pipeline stages.

Fields

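  • full (MbarPtr): Per-stage barriers that the producer signals when a stage has been filled (the barriers the consumer waits on).
  • empty (MbarPtr): Per-stage barriers that the consumer signals when a stage has been freed (the barriers the producer waits on).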

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self

Initialize the producer-consumer pipeline with default phases.

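Args:

  • ptr (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): Pointer to the shared memory barriers backing the pipeline (2 * num_stages barriers; see smem_bytes()).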

wait_producer

wait_producer(self)

Consumer-side wait: blocks until the producer has filled the current stage.

wait_consumer

wait_consumer(self)

Producer-side wait: blocks until the consumer has freed the current stage.
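Both waits are phase-parity waits. The sketch below models a single barrier with an arrive count of 1 in plain Mojo to show why, under assumed initial phases (producer 1, consumer 0), the producer's first wait_consumer() returns immediately while the consumer's first wait_producer() blocks. `ParityBarrier` and its `try_wait` method are hypothetical names; the parity rule follows the PTX `mbarrier.try_wait.parity` semantics, where the preceding phase of a freshly initialized barrier counts as complete.

```mojo
from testing import assert_true, assert_false


# Hypothetical model of one mbarrier with an arrive count of 1.
struct ParityBarrier:
    var completed_phases: Int

    fn __init__(out self):
        self.completed_phases = 0

    fn arrive(mut self):
        self.completed_phases += 1

    fn try_wait(self, parity: Int) -> Bool:
        # True once the phase with this parity has completed, i.e. the
        # barrier's current phase parity no longer equals `parity`.
        return self.completed_phases % 2 != parity


def main():
    var empty = ParityBarrier()
    var full = ParityBarrier()
    # Producer waits on `empty` with parity 1: passes on a fresh barrier.
    assert_true(empty.try_wait(1))
    # Consumer waits on `full` with parity 0: blocks until the producer arrives.
    assert_false(full.try_wait(0))
    full.arrive()
    assert_true(full.try_wait(0))
    print("parity model OK")
```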

try_wait_producer

try_wait_producer(self) -> Bool

Non-blocking check if producer data is ready.

Note: Use this with wait_producer_if_needed() for the try-acquire pattern:

var ready = pipeline.try_wait_producer()
# ... do other work ...
pipeline.wait_producer_if_needed(ready)

Returns:

Bool: True if the producer has filled the current stage, False otherwise.

try_wait_consumer

try_wait_consumer(self) -> Bool

Non-blocking check if consumer has freed the stage.

Note: Use this with wait_consumer_if_needed() for the try-acquire pattern.

Returns:

Bool: True if the consumer has freed the current stage, False otherwise.

wait_producer_if_needed

wait_producer_if_needed(self, already_ready: Bool)

Conditionally wait for producer if not already ready.

Args:

  • already_ready (Bool): Result from try_wait_producer(). If True, skips waiting.

wait_consumer_if_needed

wait_consumer_if_needed(self, already_ready: Bool)

Conditionally wait for consumer if not already ready.

Args:

  • already_ready (Bool): Result from try_wait_consumer(). If True, skips waiting.
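The try-acquire pattern lets a warp issue the non-blocking probe early, overlap independent work, and pay for a blocking wait only when the probe failed. The Mojo model below replaces the barrier with a Bool flag; `TryAcquireModel` and its `blocking_waits` counter are hypothetical and exist only to count how often the blocking path runs. The consumer-side pair (try_wait_consumer() with wait_consumer_if_needed()) follows the same shape.

```mojo
from testing import assert_equal


# Hypothetical stand-in: `filled` replaces the producer barrier, and
# `blocking_waits` counts how often the blocking path is taken.
struct TryAcquireModel:
    var filled: Bool
    var blocking_waits: Int

    fn __init__(out self, filled: Bool):
        self.filled = filled
        self.blocking_waits = 0

    fn try_wait_producer(self) -> Bool:
        return self.filled  # non-blocking probe

    fn wait_producer(mut self):
        # Real code spins on the barrier; the model just marks the stage filled.
        self.blocking_waits += 1
        self.filled = True

    fn wait_producer_if_needed(mut self, already_ready: Bool):
        if not already_ready:
            self.wait_producer()


def main():
    # Stage already filled: the probe succeeds and no blocking wait is issued.
    var fast = TryAcquireModel(True)
    var ready = fast.try_wait_producer()
    # ... independent work could run here ...
    fast.wait_producer_if_needed(ready)
    assert_equal(fast.blocking_waits, 0)

    # Stage not yet filled: the probe fails, so the blocking wait runs once.
    var slow = TryAcquireModel(False)
    var slow_ready = slow.try_wait_producer()
    slow.wait_producer_if_needed(slow_ready)
    assert_equal(slow.blocking_waits, 1)
    print("blocking waits:", fast.blocking_waits, slow.blocking_waits)
```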

producer_mbar

producer_mbar(self, stage: UInt32) -> MbarPtr

Get the producer barrier for a specific stage.

Args:

  • stage (UInt32): The pipeline stage.

Returns:

MbarPtr: The shared memory barrier that the producer signals.

consumer_mbar

consumer_mbar(self, stage: UInt32) -> MbarPtr

Get the consumer barrier for a specific stage.

Args:

  • stage (UInt32): The pipeline stage.

Returns:

MbarPtr: The shared memory barrier that the consumer signals.

producer_stage

producer_stage(self) -> UInt32

Get the current producer stage index.

Returns:

UInt32: The current stage index for the producer (0 to num_stages-1).

consumer_stage

consumer_stage(self) -> UInt32

Get the current consumer stage index.

Returns:

UInt32: The current stage index for the consumer (0 to num_stages-1).

consumer_step

consumer_step(mut self)

Advance the consumer to the next pipeline stage.

Increments the consumer stage and wraps to 0 when reaching num_stages, toggling the phase bit on wrap-around. The phase is switched only at the end of the pipeline because all stage barriers are assumed to be in the same consumer/producer phase until they are checked; once a stage's barrier has been checked, execution moves on to the next stage's barrier.

producer_step

producer_step(mut self)

Advance the producer to the next pipeline stage.

Increments the producer stage and wraps to 0 when reaching num_stages, toggling the phase bit on wrap-around.
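The stage/phase bookkeeping shared by producer_step() and consumer_step() fits in a few lines. In the Mojo sketch below, `StageCounter` is a hypothetical name and Int stands in for the UInt32 stage index used by the real methods. After k steps from stage 0 with phase p, the counter sits at stage k % num_stages with phase p XOR ((k // num_stages) % 2).

```mojo
from testing import assert_equal


# Hypothetical model of producer_step()/consumer_step().
struct StageCounter[num_stages: Int]:
    var index: Int
    var phase: Int

    fn __init__(out self, phase: Int):
        self.index = 0
        self.phase = phase

    fn step(mut self):
        self.index += 1
        if self.index == num_stages:
            self.index = 0  # wrap around the ring...
            self.phase ^= 1  # ...and toggle the phase bit


def main():
    var s = StageCounter[3](0)
    for k in range(7):
        print("step", k, "-> stage", s.index, "phase", s.phase)
        s.step()
    # 7 steps: 7 % 3 = 1, and 7 // 3 = 2 wraps leave the phase unchanged.
    assert_equal(s.index, 1)
    assert_equal(s.phase, 0)
```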

smem_bytes

static smem_bytes() -> UInt32

Calculate the shared memory bytes required for pipeline barriers.

Returns:

UInt32: The total number of bytes needed for all pipeline barriers (2 * num_stages barriers).
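As a worked size check: a PTX mbarrier object occupies 8 bytes, so, assuming SharedMemBarrier wraps exactly one such 64-bit object, the pipeline needs 2 * num_stages * 8 bytes. The helper below is a hypothetical plain-Mojo restatement of that arithmetic, not the library's smem_bytes().

```mojo
from testing import assert_equal


# Assumption: one SharedMemBarrier == one 8-byte PTX mbarrier object.
fn pipeline_barrier_bytes(num_stages: Int, barrier_bytes: Int = 8) -> Int:
    # One producer-signaled and one consumer-signaled barrier per stage.
    return 2 * num_stages * barrier_bytes


def main():
    for n in range(1, 5):
        print(n, "stages:", pipeline_barrier_bytes(n), "bytes")
    assert_equal(pipeline_barrier_bytes(4), 64)
```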

init_mbars

init_mbars(self, producer_arrive_count: Int32, consumer_arrive_count: Int32)

Initialize the smem barriers for the producer and consumer.

This function must be called by a single thread, and it must be called before the pipeline object is used.

Args:

  • producer_arrive_count (Int32): The number of threads that will arrive at the barrier marking data as produced.
  • consumer_arrive_count (Int32): The number of threads that will arrive at the barrier marking data as consumed.
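The arrive counts determine when a barrier phase completes: each arrival decrements a pending count, and the phase flips only when that count reaches zero. The Mojo model below shows why the count passed to init_mbars must match the number of threads that actually arrive: with one arrival missing, the phase never completes and the waiting side would hang. `CountingBarrier` is hypothetical, and 128 is an illustrative count for a 4-warp warpgroup in which every thread arrives.

```mojo
from testing import assert_equal


# Hypothetical model of one mbarrier initialized with an arrive count.
struct CountingBarrier:
    var expected: Int  # arrive count passed at initialization
    var pending: Int  # arrivals still missing in the current phase
    var phase: Int  # parity of the current (incomplete) phase

    fn __init__(out self, arrive_count: Int):
        self.expected = arrive_count
        self.pending = arrive_count
        self.phase = 0

    fn arrive(mut self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1  # the phase completes...
            self.pending = self.expected  # ...and the barrier re-arms


def main():
    var empty = CountingBarrier(128)
    for _ in range(127):
        empty.arrive()
    assert_equal(empty.phase, 0)  # one thread short: still incomplete
    empty.arrive()
    assert_equal(empty.phase, 1)  # all 128 arrived: phase completed
    print("phase after 128 arrivals:", empty.phase)
```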

producer_signal_and_step

producer_signal_and_step(mut self)

Wait for consumer, signal production, and advance stage.

Combined operation for CLC throttling (Load warp):

  1. Wait for consumer to finish with current stage
  2. Signal that producer has new data
  3. Advance to next stage

consumer_signal_and_step

consumer_signal_and_step(mut self)

Wait for producer, signal consumption, and advance stage.

Combined operation for CLC throttling (Scheduler warp):

  1. Wait for producer to have data ready
  2. Signal that consumer has consumed data
  3. Advance to next stage
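Together, the two fused calls throttle the Load warp so that it runs at most num_stages iterations ahead of the Scheduler warp. The single-threaded Mojo model below returns False wherever the real call would block in its wait. `ThrottleModel` is hypothetical, each barrier is reduced to a completed-phase count with an arrive count of 1, and the initial phases (producer 1, consumer 0) are assumed.

```mojo
from testing import assert_equal, assert_true, assert_false


# Hypothetical model: per-stage completed-phase counters (arrive count 1).
struct ThrottleModel:
    var num_stages: Int
    var full: List[Int]
    var empty: List[Int]
    var p_stage: Int
    var p_phase: Int
    var c_stage: Int
    var c_phase: Int

    fn __init__(out self, num_stages: Int):
        var full = List[Int]()
        var empty = List[Int]()
        for _ in range(num_stages):
            full.append(0)
            empty.append(0)
        self.num_stages = num_stages
        self.full = full^
        self.empty = empty^
        self.p_stage = 0
        self.p_phase = 1  # assumed initial producer phase
        self.c_stage = 0
        self.c_phase = 0  # assumed initial consumer phase

    fn producer_signal_and_step(mut self) -> Bool:
        var s = self.p_stage
        if self.empty[s] % 2 == self.p_phase:
            return False  # real call would block waiting for the consumer
        self.full[s] += 1  # signal: stage s has new data
        self.p_stage = (s + 1) % self.num_stages
        if self.p_stage == 0:
            self.p_phase ^= 1
        return True

    fn consumer_signal_and_step(mut self) -> Bool:
        var s = self.c_stage
        if self.full[s] % 2 == self.c_phase:
            return False  # real call would block waiting for the producer
        self.empty[s] += 1  # signal: stage s consumed
        self.c_stage = (s + 1) % self.num_stages
        if self.c_stage == 0:
            self.c_phase ^= 1
        return True


def main():
    var t = ThrottleModel(2)
    var lead = 0
    # The Load warp runs alone: it gets num_stages iterations ahead, then stalls.
    while t.producer_signal_and_step():
        lead += 1
    assert_equal(lead, 2)
    # One Scheduler step frees exactly one more Load iteration.
    assert_true(t.consumer_signal_and_step())
    assert_true(t.producer_signal_and_step())
    assert_false(t.producer_signal_and_step())
    print("Load warp lead is capped at", lead, "stages")
```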

produce

produce[origin: MutOrigin, //](ref[origin] self) -> ProduceContext[origin, num_stages]

Produce one pipeline stage with encapsulated barriers.

Usage:

with pipeline.produce() as stage:
    # stage.index() gives current stage
    # stage.mbar() gives barrier for signaling
    # exit calls producer_step()

Returns:

ProduceContext: Context that waits for consumer on enter, advances on exit.

consume

consume[origin: MutOrigin, //](ref[origin] self) -> ConsumeContext[origin, num_stages]

Consume one pipeline stage with encapsulated barriers.

Usage:

with pipeline.consume() as stage:
    # stage.index() gives current stage
    # exit signals consumer done and advances

Returns:

ConsumeContext: Context that waits for producer on enter, signals+advances on exit.

consume_explicit

consume_explicit[origin: MutOrigin, //](ref[origin] self) -> ExplicitConsumeContext[origin, num_stages]

Consume one pipeline stage with EXPLICIT barrier arrive.

Use this for kernels requiring lane-guarded or specialized signaling.

Usage:

with pipeline.consume_explicit() as stage:
    # ... do work ...
    if lane_id() < CLUSTER_SIZE:
        stage.arrive()  # Lane-guarded arrive
    # exit only advances, does NOT arrive

For specialized signaling (e.g., umma_arrive_leader_cta):

with pipeline.consume_explicit() as stage:
    if cta_group == 1:
        stage.arrive()
    else:
        umma_arrive_leader_cta(stage.mbar())

Returns:

ExplicitConsumeContext: Context that waits for producer on enter, advances only on exit.

acquire_producer

acquire_producer[origin: MutOrigin, //](ref[origin] self) -> ProducerStage[origin, num_stages]

Acquire a producer stage handle using linear types.

Waits for the consumer to free the current stage, then returns a linear type handle that MUST be released (compiler-enforced).

Usage:

var stage = pipeline.acquire_producer()
# ... produce data, signal via stage.mbar() ...
stage^.release()  # Advances to next stage

Returns:

ProducerStage: A ProducerStage handle that must be released.

acquire_consumer

acquire_consumer[origin: MutOrigin, //](ref[origin] self) -> ConsumerStage[origin, num_stages]

Acquire a consumer stage handle using linear types.

Waits for the producer to fill the current stage, then returns a linear type handle that MUST be released (compiler-enforced).

Usage:

var stage = pipeline.acquire_consumer()
# ... consume data ...
stage^.release()  # Signals complete and advances

For explicit signaling:

var stage = pipeline.acquire_consumer()
# ... consume data ...
if lane_id() < CLUSTER_SIZE:
    stage.arrive()
stage^.release_without_signal()

Returns:

ConsumerStage: A ConsumerStage handle that must be released.
