Mojo module

tile_pipeline

Tile pipeline for SM100 producer-consumer synchronization.

Provides staged tile storage with producer-consumer barrier synchronization for TMA-MMA pipeline coordination. All barrier operations are encapsulated in context managers for safety and clarity.

Key Abstractions

  • InputTilePipeline[Payload]: Generic input pipeline parameterized on the tile payload type
  • TilePipeline: Standard pipeline with explicit A/B tile types
  • OutputTilePipeline: TMEM accumulator stages for MMA→Epilogue pipeline

Naming Conventions

  • *Pipeline: Multi-stage buffer (InputTilePipeline, OutputTilePipeline)
  • *Producer/*Consumer: Role handles (InputProducer, OutputConsumer)
  • acquire(): Context manager to get one pipeline stage

Context Manager Semantics

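Each with block handles barrier synchronization automatically: entering the block waits until the stage is available, and leaving it signals the other side and advances to the next stage.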

with producer.acquire() as tiles:   # BLOCKS until consumer releases stage
    load_tiles(tiles)                # safe to write
                                     # EXIT: signals producer barrier, advances

with consumer.acquire() as tiles:   # BLOCKS until producer fills stage
    use_tiles(tiles)                 # safe to read
                                     # EXIT: signals consumer barrier, advances
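
To make the entry/exit behavior concrete, here is a small host-side model in Python. It is only a sketch of the semantics described above, not the Mojo implementation: the class ToyTilePipeline, its methods producer_acquire and consumer_acquire, and the use of threading.Semaphore in place of shared-memory barriers are all assumptions made for illustration.

```python
import threading
from contextlib import contextmanager


class ToyTilePipeline:
    """Hypothetical host-side model of a multi-stage buffer (not the Mojo API)."""

    def __init__(self, num_stages):
        self.stages = [{} for _ in range(num_stages)]
        # One full/empty semaphore pair per stage stands in for the barriers.
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]
        self.prod_idx = 0  # touched only by the producer thread
        self.cons_idx = 0  # touched only by the consumer thread

    @contextmanager
    def producer_acquire(self):
        i = self.prod_idx
        self.empty[i].acquire()        # ENTRY: block until the consumer released stage i
        try:
            yield self.stages[i]       # safe to write
        finally:
            self.full[i].release()     # EXIT: signal "stage filled"
            self.prod_idx = (i + 1) % len(self.stages)  # advance

    @contextmanager
    def consumer_acquire(self):
        i = self.cons_idx
        self.full[i].acquire()         # ENTRY: block until the producer filled stage i
        try:
            yield self.stages[i]       # safe to read
        finally:
            self.empty[i].release()    # EXIT: signal "stage free"
            self.cons_idx = (i + 1) % len(self.stages)  # advance


def run_demo(num_tiles=8, num_stages=2):
    """Run one producer and one consumer thread; return tiles in consumed order."""
    pipe = ToyTilePipeline(num_stages)
    seen = []

    def producer():
        for k in range(num_tiles):
            with pipe.producer_acquire() as tiles:
                tiles["a"] = k
                tiles["b"] = k * k

    def consumer():
        for _ in range(num_tiles):
            with pipe.consumer_acquire() as tiles:
                seen.append((tiles["a"], tiles["b"]))

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen
```

With only two stages and eight tiles, the producer has to reuse each stage several times. The consumer still sees every tile exactly once and in order, because a stage cannot be rewritten until its previous contents have been released.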

Example: TMA Load Warp (Producer)

with input_pipeline.producer() as producer:  # producer role for this warp
    while work_iter.has_work():
        with work_iter.next() as current:
            for i in range(num_iters):
                with producer.acquire() as tiles:  # waits for consumer
                    tma_load(tiles.a_tile(), tiles.b_tile())
    producer.drain()  # wait for all stages consumed before CTA exits
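
The source does not show how drain() is implemented. One plausible reading of "wait for all stages consumed" is that the producer waits once more on every stage's free signal before exiting. The Python sketch below models that reading with semaphores; the function run_with_drain and its internals are hypothetical and exist only to illustrate the idea.

```python
import threading


def run_with_drain(num_tiles=6, num_stages=3):
    """Model a producer that drains before exit; return tiles consumed at that point."""
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    empty = [threading.Semaphore(1) for _ in range(num_stages)]
    consumed = []
    state = {"consumed_at_drain": None}

    def producer():
        for k in range(num_tiles):
            s = k % num_stages
            empty[s].acquire()   # wait until stage s is free
            full[s].release()    # publish stage s
        # Assumed drain behavior: wait on every stage's "free" signal once more,
        # so no published stage is still unread when the producer exits.
        for s in range(num_stages):
            empty[s].acquire()
        state["consumed_at_drain"] = len(consumed)

    def consumer():
        for k in range(num_tiles):
            s = k % num_stages
            full[s].acquire()    # wait until stage s is filled
            consumed.append(k)
            empty[s].release()   # release stage s back to the producer

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return state["consumed_at_drain"]
```

In this model the drain loop cannot finish until the consumer has released every stage the producer filled, so the value returned always equals num_tiles. Without the drain loop, the producer could exit while tiles were still unread.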

Example: MMA Warp (Consumer + Output Producer)

with mma_ctx:  # TMEM lifecycle
    while work_iter.has_work():
        with work_iter.wait_and_advance():  # blocks on CLC response
            with output_pipeline.producer() as output_stage:  # waits for epilogue
                with input_pipeline.consumer() as consumer:
                    for i in range(num_iters):
                        with consumer.acquire() as input_tiles:  # waits for TMA
                            mma(output_stage.tmem, input_tiles)

Example: Epilogue Warp (Output Consumer)

with epi_ctx:  # signals TMEM dealloc on exit
    while work_iter.has_work():
        with work_iter.next() as current:
            with output_pipeline.consumer() as output_stage:  # waits for MMA
                write_output(output_stage)

comptime values

MbarPtr

comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]

Structs

Traits

  • TilePayload: Trait for tile payload types. Implementing types must be @register_passable("trivial").