
Mojo module

epilogue_load_pipeline

Epilogue load pipeline types for SM100 Conv2D kernel.

This module provides runtime pipeline types for coordinating the epilogue load warp with other kernel components:

  • EpiLoadPipeline: Pipeline for EpilogueLoad warp → Epilogue warps transfer
  • LoadOrderBarrier: Barrier for MainLoad → EpilogueLoad coordination

Pipeline Architecture

The epilogue load warp (warp ID 7) prefetches the source tensor C via TMA for residual operations (D = Conv(A,B) + beta*C). Because the C load overlaps with MMA computation, its latency is hidden behind the main loop instead of being paid at the start of the epilogue.

MainLoad warp                 EpilogueLoad warp              Epilogue warps
    |                              |                              |
    |-- prologue loads --|         |                              |
    |                    |         |                              |
    |-- arrive() --------|-------->| wait()                       |
    |                              |                              |
    |-- steady-state ----|         |-- TMA load C --|             |
    |                              |                |             |
    |                              |-- produce() ---|------------>| consume()
    |                              |                              |
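
The ordering in the diagram can be modelled on the host as a rough sketch. This is not the Mojo API: it is a plain Python simulation in which `threading.Event` stands in for LoadOrderBarrier and a bounded `queue.Queue` stands in for EpiLoadPipeline. The function name `run_pipeline`, the event strings, and the stage counts are all hypothetical and exist only for illustration.

```python
import queue
import threading


def run_pipeline(num_prologue=2, num_steady=3, num_c_tiles=4, num_stages=2):
    """Run three modelled warp roles and return the ordered event log."""
    events = []
    log_lock = threading.Lock()

    def log(message):
        # Serialize appends so the log reflects one global order.
        with log_lock:
            events.append(message)

    load_order = threading.Event()              # models LoadOrderBarrier
    c_stages = queue.Queue(maxsize=num_stages)  # models EpiLoadPipeline stages

    def main_load():
        for i in range(num_prologue):
            log(f"main:prologue:{i}")
        log("main:arrive")
        load_order.set()                        # arrive(): release EpilogueLoad
        for i in range(num_steady):
            log(f"main:steady:{i}")

    def epilogue_load():
        load_order.wait()                       # wait(): prologue loads were issued
        for tile in range(num_c_tiles):
            log(f"epi_load:tma_c:{tile}")
            c_stages.put(tile)                  # produce(): blocks when all stages are full

    def epilogue():
        for _ in range(num_c_tiles):
            tile = c_stages.get()               # consume(): blocks until a tile is ready
            log(f"epilogue:consume:{tile}")

    threads = [
        threading.Thread(target=role)
        for role in (epilogue, epilogue_load, main_load)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return events


if __name__ == "__main__":
    for event in run_pipeline():
        print(event)
```

The exact interleaving of steady-state loads and C loads varies between runs, which mirrors the intended overlap. Two orderings always hold: every prologue load precedes the first C load, and each C tile is produced before it is consumed.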

Usage

Barrier Initialization

if elect_one_thread:
    load_order_barrier.init(arrive_count=1)  # MainLoad arrives
    epi_load_pipeline.init_barriers(
        producer_arv_count=1,    # EpilogueLoad (TMA)
        consumer_arv_count=128,  # Epilogue warps (4 × 32)
    )
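
To illustrate what the arrival counts above mean, here is a simplified single-threaded model of an arrival-count barrier. It assumes mbarrier-like semantics (a phase completes once the expected number of arrivals has been recorded, then the count resets) and ignores TMA transaction-byte tracking entirely. The class `ArrivalBarrier` is hypothetical and is not one of this module's types.

```python
class ArrivalBarrier:
    """Toy model: a phase completes after `arrive_count` arrivals."""

    def __init__(self, arrive_count):
        if arrive_count < 1:
            raise ValueError("arrive_count must be at least 1")
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed in this phase
        self.phase = 0               # parity bit, flips on each completion

    def arrive(self):
        """Record one arrival; return True if it completed the phase."""
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.arrive_count  # re-arm for the next phase
            return True
        return False


# Consumer side: 4 epilogue warps x 32 threads must all release a stage.
consumer_barrier = ArrivalBarrier(arrive_count=4 * 32)
completions = [consumer_barrier.arrive() for _ in range(128)]

# Producer side: a single arrival from the EpilogueLoad warp is enough.
producer_barrier = ArrivalBarrier(arrive_count=1)
producer_done = producer_barrier.arrive()
```

Under this model, a stage is handed back to the producer only after the 128th consumer arrival, while a single producer arrival is enough to complete the producer-side phase.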

MainLoad Warp (prologue/steady-state split)

if WarpRole.is_main_load():
    # Issue prologue loads
    for _ in range(num_prologue_stages):
        load_input_tiles(...)

    # Signal that the EpilogueLoad warp can start
    load_order_barrier.arrive()

    # Continue with steady-state loads
    for _ in range(remaining_stages):
        load_input_tiles(...)

EpilogueLoad Warp

if WarpRole.is_epilogue_load():
    # Wait until MainLoad has issued its prologue loads
    load_order_barrier.wait()

    with epi_load_pipeline.produce() as stage:
        # Load C tile via TMA
        tma_load(c_tile, stage.mbar(), ...)

Epilogue Warps

if WarpRole.is_epilogue():
    with epi_load_pipeline.consume() as c_stage:
        # C tile now in SMEM
        c_tile = smem.src_tiles()[c_stage.index()]
        tile_writer.write_with_residual(accum, c_tile, beta, ...)
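
For reference, the arithmetic behind the residual write (D = Conv(A,B) + beta*C) is an elementwise update of the accumulator tile. The sketch below shows only that math on nested Python lists. The helper `apply_residual` is hypothetical and says nothing about how `write_with_residual` is actually implemented or how tiles are laid out in SMEM.

```python
def apply_residual(accum, c_tile, beta):
    """Return accum + beta * c_tile, elementwise, for equally shaped 2D tiles."""
    if len(accum) != len(c_tile):
        raise ValueError("tiles must have the same number of rows")
    result = []
    for acc_row, c_row in zip(accum, c_tile):
        if len(acc_row) != len(c_row):
            raise ValueError("tile rows must have the same length")
        result.append([acc + beta * c for acc, c in zip(acc_row, c_row)])
    return result


# accum plays the role of Conv(A,B); c_tile is the prefetched source tensor C.
accum = [[1.0, 2.0], [3.0, 4.0]]
c_tile = [[10.0, 20.0], [30.0, 40.0]]
d_tile = apply_residual(accum, c_tile, beta=0.5)
```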

comptime values

UnsafePointer

comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]

Structs
