
Mojo struct

EpilogueWarpContext

struct EpilogueWarpContext[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

Epilogue warp context: consumes accumulator data from tensor memory (TMEM) and signals completion.

IMPORTANT: Call Sync.wait() before constructing this context so that the TMEM address stored in shared memory is visible.
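The ordering requirement can be modeled on the host in plain Python. This is a sketch, not Mojo GPU code. A one-element list stands in for the shared-memory slot holding the TMEM address, and threading.Barrier stands in for Sync.wait(). The warp functions and the address value 0x40 are hypothetical.

```python
import threading

# Hypothetical host-side model (not the Mojo API):
# - tmem_addr_storage models the one-slot shared-memory word with the TMEM address
# - sync models Sync.wait(): no party proceeds until both have arrived
tmem_addr_storage = [0]
sync = threading.Barrier(2)
seen = []

def mma_warp():
    # The allocating warp publishes the TMEM address...
    tmem_addr_storage[0] = 0x40
    # ...then arrives at the barrier.
    sync.wait()

def epilogue_warp():
    # Wait first, so the published address is guaranteed to be visible...
    sync.wait()
    # ...then "construct" the context by reading the address.
    seen.append(tmem_addr_storage[0])

threads = [threading.Thread(target=mma_warp), threading.Thread(target=epilogue_warp)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(hex(seen[0]))  # 0x40
```

Reading the slot before the barrier would race with the allocating warp's write. Waiting first is what makes the read well defined.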

Fields

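  • tmem (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Tmem): TMEM allocation whose accumulator data the epilogue consumes.
  • output_pipeline (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Pipeline): Output tile pipeline that synchronizes the MMA and epilogue warps.
  • dealloc_barrier (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Dealloc): Barrier used to signal that TMEM can be deallocated.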

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

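Methods

__init__

def __init__(tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group]) -> Self

Constructs the context from an existing TMEM allocation, output tile pipeline, and TMEM deallocation barrier.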

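create

static def create(tmem_addr_storage: SMemArray[UInt32, Int(1)], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, Int(1)], mma_complete_mask: UInt16) -> Self

Creates an epilogue warp context with all necessary components.

Reads the TMEM address from shared memory and creates the output pipeline.

IMPORTANT: Call Sync.wait() before calling this method so that the TMEM address is visible.

Args:

  • tmem_addr_storage (SMemArray[UInt32, Int(1)]): One-element shared-memory array holding the TMEM address.
  • accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED]): Pointer to the accumulator barriers in shared memory.
  • dealloc_mbar (SMemArray[SharedMemBarrier, Int(1)]): One-element shared-memory array holding the TMEM deallocation barrier.
  • mma_complete_mask (UInt16): 16-bit mask used when signaling MMA completion.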

Returns:

Self: Fully initialized EpilogueWarpContext.
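The relationship between create and __init__ can be sketched as a factory that assembles the three fields from raw shared-memory pieces. This is a hypothetical Python model, not the Mojo API. All classes here are stand-ins, and passing accum_barriers_ptr and mma_complete_mask to the output pipeline is an assumption inferred from the docstring ("creates output pipeline").

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the Mojo types (not the real API).
@dataclass
class TmemAllocation:
    addr: int

@dataclass
class OutputTilePipeline:
    barriers: list
    mma_complete_mask: int  # assumed to be consumed by the output pipeline

@dataclass
class TmemDeallocBarrier:
    mbar: list

@dataclass
class EpilogueWarpContext:
    tmem: TmemAllocation
    output_pipeline: OutputTilePipeline
    dealloc_barrier: TmemDeallocBarrier

    @staticmethod
    def create(tmem_addr_storage, accum_barriers, dealloc_mbar, mma_complete_mask):
        # The caller must already have waited on Sync, so slot 0 holds the
        # address written by the allocating warp.
        tmem = TmemAllocation(addr=tmem_addr_storage[0])
        # Mask to 16 bits to mirror the UInt16 parameter type.
        pipeline = OutputTilePipeline(accum_barriers, mma_complete_mask & 0xFFFF)
        return EpilogueWarpContext(tmem, pipeline, TmemDeallocBarrier(dealloc_mbar))

ctx = EpilogueWarpContext.create([0x40], ["acc0", "acc1"], ["dealloc"], 0b11)
print(ctx.tmem.addr)  # 64
```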

__enter__

def __enter__(self) -> Self

__exit__

def __exit__(self)
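Because the struct implements __enter__ and __exit__, it can be used in a with block. The Python model below assumes, without confirmation from this page, that __exit__ is where the epilogue "signals completion" by releasing the dealloc barrier. The class and the threading.Event used as a barrier are hypothetical. Python's __exit__ also takes three exception arguments, whereas the Mojo signature takes only self.

```python
import threading

class EpilogueWarpContext:
    """Hypothetical model: __exit__ signals the dealloc barrier (assumption)."""

    def __init__(self, dealloc_event):
        self.dealloc_event = dealloc_event

    def __enter__(self):
        # Mirrors `def __enter__(self) -> Self`: hand back the context itself.
        return self

    def __exit__(self, exc_type, exc, tb):
        # Assumed behavior: tell the allocating warp that TMEM may be freed.
        self.dealloc_event.set()
        return False  # do not suppress exceptions

dealloc_event = threading.Event()
with EpilogueWarpContext(dealloc_event) as ctx:
    inside = dealloc_event.is_set()  # still reading TMEM here, so not yet signaled
print(inside, dealloc_event.is_set())  # False True
```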

per_k_stage

def per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[opc] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]

Get per-K stage context for blockwise FP8 epilogue.

Bundles the output pipeline (MMA→epilogue synchronization) and the input pipeline (A-scales consumption) into a single context manager.

Example:

    for k_iter in range(num_iters):
        with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
            accum.promote(epi_stage, ...)
            # Both pipelines are signaled automatically

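Args:

  • input_pipeline (InputTilePipeline[Payload, num_group_stages, k_group_size]): Input tile pipeline whose A-scales stage is consumed by the returned context.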

Returns:

EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]: EpilogueKContext context manager that handles both pipelines.
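The bundling that per_k_stage performs can be modeled with contextlib.contextmanager. This Python sketch assumes that each pipeline stage is released exactly once when the with block exits. Pipeline, the integer stage indices, and the log list are hypothetical, and the real waits on shared-memory barriers are elided.

```python
from contextlib import contextmanager

class Pipeline:
    """Hypothetical pipeline model: counts how many stages were released."""

    def __init__(self):
        self.released = 0

    def release(self):
        self.released += 1

@contextmanager
def per_k_stage(output_pipeline, input_pipeline, log):
    # On entry: assume a stage is ready on both pipelines (waits elided).
    stage = (output_pipeline.released, input_pipeline.released)
    try:
        yield stage
    finally:
        # On exit: release the output stage (MMA->epilogue) and the
        # input stage (A-scales) so both producers can reuse them.
        output_pipeline.release()
        input_pipeline.release()
        log.append(stage)

out_p, in_p, log = Pipeline(), Pipeline(), []
for k_iter in range(3):
    with per_k_stage(out_p, in_p, log) as epi_stage:
        pass  # accum.promote(epi_stage, ...) would go here
print(out_p.released, in_p.released)  # 3 3
```

Releasing both stages in one finally clause keeps the two pipelines in lockstep: a K iteration can never free its output stage while leaving its A-scales stage held, or the reverse.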