Mojo struct

EpilogueWarp

struct EpilogueWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

Unified linear type for epilogue warp lifecycle.

Works both as a linear type (used directly) and inside context managers.

Lifecycle:

  1. Created via create() after Sync.wait(); create() reads the TMEM address from shared memory
  2. Use output_pipeline or acquire_k_stage_linear() to process epilogue stages
  3. Must call release() to signal completion (compiler-enforced)

IMPORTANT: Call Sync.wait() BEFORE create() to ensure the TMEM address is visible.
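
The lifecycle above can be sketched as a fragment of the epilogue warp's code path. This is a minimal sketch, not a complete kernel: it assumes `opc`, `mma_threads`, `epilogue_threads`, and the four `create()` arguments are already in scope, and that `Sync.wait()` can be invoked through the struct's `Sync` member as written here. `num_k_stages` and `process_tmem` are hypothetical placeholders for the caller's loop bound and per-stage work.

```mojo
# Sketch only: a fragment from inside the epilogue warp's branch of a kernel.
# Assumed in scope (not defined here): opc, mma_threads, epilogue_threads,
# tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask.
comptime Epi = EpilogueWarp[opc, mma_threads, epilogue_threads]

# 1. Wait first so the TMEM address in shared memory is visible, then create.
Epi.Sync.wait()
var epi = Epi.create(
    tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask
)

# 2. Process each stage through a linear handle.
for _ in range(num_k_stages):  # num_k_stages: hypothetical loop bound
    var stage = epi.acquire_k_stage_linear()
    process_tmem(stage.tmem())  # process_tmem: placeholder for epilogue work
    stage^.release()

# 3. Consume the warp handle to signal completion.
epi^.release()
```

Each handle is consumed with the `^` transfer operator, so every acquire is paired with exactly one release in straight-line code.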

Parameters​

  • ​opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
  • ​mma_threads (Int): Number of MMA threads.
  • ​epilogue_threads (Int): Number of epilogue threads.

Fields​

  • ​tmem (EpilogueWarp[opc, mma_threads, epilogue_threads].Tmem):
  • ​output_pipeline (EpilogueWarp[opc, mma_threads, epilogue_threads].Pipeline):
  • ​dealloc_barrier (EpilogueWarp[opc, mma_threads, epilogue_threads].Dealloc):

Implemented traits​

AnyType

comptime members​

Dealloc​

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline​

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync​

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem​

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods​

__init__​

def __init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])

create​

static def create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create an epilogue warp.

Reads the TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible.

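Args:

  • tmem_addr_storage (SMemArray[UInt32, 1]):
  • accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED]):
  • dealloc_mbar (SMemArray[SharedMemBarrier, 1]):
  • mma_complete_mask (UInt16):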

Returns:

Self: Fully initialized EpilogueWarp that must be released.

per_k_stage​

def per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[opc] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]

Get a per-K-stage context manager (kept for compatibility).

Prefer acquire_k_stage_linear() for a flatter code structure.

Returns:

EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]

acquire_k_stage_linear​

def acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), opc]

Acquire a per-K stage using linear types.

Waits for the MMA to complete the stage, then returns a linear handle.

Usage:

    var stage = epi_handle.acquire_k_stage_linear()
    process_tmem(stage.tmem())
    stage^.release()

Returns:

EpilogueStage[origin_of(self.output_pipeline), opc]

release​

def release(deinit self)

Signal epilogue completion.

This is the only way to destroy this linear type.