
Mojo module

pipeline

Producer-consumer pipeline utilities for SM100 structured kernels.

This module provides pipeline synchronization primitives for warp-specialized GPU kernels, enabling efficient producer-consumer patterns between warps.

Key abstractions:

  • ProducerConsumerPipeline: Low-level barrier management for N-stage pipelines
  • ProducerStage / ConsumerStage: Unified stage handles (linear types)
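To make the N-stage handshake concrete, here is a minimal host-side sketch in Mojo. It assumes nothing about this module's internals. `ToyRing`, `can_produce`, `produce`, `can_consume`, and `consume` are hypothetical names, the model is single-threaded, and a `List[Bool]` stands in for the shared-memory barriers a real kernel would wait on. It only illustrates the ordering rule: the producer may refill a slot only after the consumer has drained it, and both sides visit the slots in ring order.

```mojo
from collections import List


struct ToyRing[num_stages: Int]:
    """Hypothetical single-threaded model of an N-stage ring buffer."""

    var full: List[Bool]  # full[i] is True while slot i holds unconsumed data
    var head: Int  # next slot the producer will fill
    var tail: Int  # next slot the consumer will drain

    fn __init__(out self):
        self.full = List[Bool]()
        for _ in range(Self.num_stages):
            self.full.append(False)
        self.head = 0
        self.tail = 0

    fn can_produce(self) -> Bool:
        # Producer must wait until the consumer has freed this slot.
        return not self.full[self.head]

    fn produce(mut self) -> Int:
        var slot = self.head
        self.full[slot] = True  # stand-in for signaling "slot is full"
        self.head = (self.head + 1) % Self.num_stages
        return slot

    fn can_consume(self) -> Bool:
        # Consumer must wait until the producer has filled this slot.
        return self.full[self.tail]

    fn consume(mut self) -> Int:
        var slot = self.tail
        self.full[slot] = False  # stand-in for signaling "slot is empty"
        self.tail = (self.tail + 1) % Self.num_stages
        return slot


fn main():
    var ring = ToyRing[2]()
    var produced = 0
    while ring.can_produce():
        _ = ring.produce()
        produced += 1
    print("produced before blocking:", produced)  # 2: both slots are full
    print("consumed slot:", ring.consume())  # slot 0 is drained first
    print("producer may continue:", ring.can_produce())  # True: slot 0 is free
```

In this model, two stages let the producer run up to two slots ahead of the consumer before it has to wait; more stages buy more slack at the cost of more buffer space.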

Unified Stage Types

ProducerStage and ConsumerStage are linear types (@explicit_destroy) that can be used through either of two APIs:

  1. Linear Type API (flat, explicit):

       var stage = pipeline.acquire_producer()
       # ... use stage.index(), stage.mbar() ...
       stage^.release()  # Compiler enforces this call

  2. Context Manager API (scoped, automatic):

       with pipeline.produce() as stage:
           # ... use stage.index(), stage.mbar() ...
       # release() called automatically

The context managers store the stage internally and return a ref to it, allowing access to the full stage API while managing lifetime automatically.
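The scoped form relies on Mojo's context-manager protocol (`__enter__` / `__exit__`). The sketch below uses a hypothetical `ScopedSlot` guard that is not part of this module; it only shows the ordering the `with` form guarantees: the value bound by `as` is usable inside the block, and the release in `__exit__` runs when the block ends, with no explicit call in user code.

```mojo
struct ScopedSlot:
    """Hypothetical guard: 'acquires' a slot on entry, 'releases' it on exit."""

    var index: Int

    fn __init__(out self, index: Int):
        self.index = index

    fn __enter__(self) -> Int:
        # The returned value is what the `as` target binds to.
        print("acquire", self.index)
        return self.index

    fn __exit__(self):
        # Runs automatically when the `with` block ends.
        print("release", self.index)


fn main():
    with ScopedSlot(2) as slot:
        print("use", slot)
    # Prints, in order: acquire 2 / use 2 / release 2
```

The real context managers presumably do more (per the text above, they hold the linear stage internally and hand out a reference to it), but the lifetime guarantee is the same: the release cannot be forgotten.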

API Examples

Producer side (e.g., the MMA warp producing results for the epilogue):

# Context manager:
with pipeline.produce() as stage:
    mma_op.mma(a, b, tmem_offset)
    mma_op.commit(stage.mbar())
# __exit__ calls stage^.release() -> producer_step()

# Linear type:
var stage = pipeline.acquire_producer()
mma_op.mma(a, b, tmem_offset)
mma_op.commit(stage.mbar())
stage^.release()

Consumer side (e.g., the epilogue consuming results from the MMA warp):

# Context manager:
with pipeline.consume() as stage:
    process(stage.index())
# __exit__ calls stage^.release() -> arrive + consumer_step()

# Linear type:
var stage = pipeline.acquire_consumer()
process(stage.index())
stage^.release()  # Signal + advance

# Explicit signaling:
var stage = pipeline.acquire_consumer()
if lane_id() < CLUSTER_SIZE:
    stage.arrive()
stage^.release_without_signal()  # Advance only

Direct API (for special cases):

  • pipeline.wait_producer() / wait_consumer()
  • pipeline.producer_step() / consumer_step()
  • pipeline.producer_mbar(stage) / consumer_mbar(stage)
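producer_step() and consumer_step() advance each side to its next stage. A common way to track an N-stage ring guarded by mbarriers is a stage index that wraps at N plus a phase bit that flips on every wrap, because mbarrier waits are expressed in terms of phase parity. The sketch below models only that bookkeeping. `StageCounter` and `step` are hypothetical names, and treating the step functions this way is an assumption about the technique, not a description of this module's implementation.

```mojo
struct StageCounter[num_stages: Int]:
    """Hypothetical model of per-side stage bookkeeping in an N-stage ring."""

    var index: Int  # current stage in [0, num_stages)
    var phase: Int  # parity bit, flips 0 <-> 1 each time index wraps to 0

    fn __init__(out self):
        self.index = 0
        self.phase = 0

    fn step(mut self):
        self.index += 1
        if self.index == Self.num_stages:
            self.index = 0
            self.phase ^= 1  # a new pass over the ring waits on the other parity


fn main():
    var counter = StageCounter[3]()
    for _ in range(4):
        counter.step()
    # 3 steps complete one pass (index 0, phase 1); the 4th moves to index 1.
    print(counter.index, counter.phase)  # 1 1
```

The phase bit is what lets a waiter distinguish "this barrier completed on the current pass" from "it completed one full pass ago", since the same N barriers are reused on every trip around the ring.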

comptime values

MbarPtr

comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]

UnsafePointer

comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]

Structs
