
Mojo struct

EpilogueWarp

struct EpilogueWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

Unified linear type for epilogue warp lifecycle.

Can be used directly as a linear type or from within context managers.

Lifecycle:

  1. Create with create() after Sync.wait(); create() reads the TMEM address from shared memory
  2. Use output_pipeline or acquire_k_stage_linear() to process epilogue stages
  3. Call release() to signal completion (required; compiler-enforced)

IMPORTANT: Call Sync.wait() BEFORE create() to ensure TMEM address is visible.
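The lifecycle above can be sketched as the following kernel-body fragment. It is an illustration only, has not been compiled, and uses just the calls named on this page. The loop bound num_output_tiles is hypothetical, process_tmem is the placeholder from the acquire_k_stage_linear() usage note, and calling Sync.wait() with no arguments through the Sync comptime member is an assumption based on the note above. The arguments to create() are assumed to be already set up by the surrounding kernel.

```mojo
# Sketch only: a fragment from inside an epilogue-warp branch of a kernel.
# opc, mma_threads, epilogue_threads and the create() arguments come from
# the surrounding kernel; num_output_tiles and process_tmem are hypothetical.
comptime Epi = EpilogueWarp[opc, mma_threads, epilogue_threads]

# Wait first so the TMEM address in shared memory is visible.
# (Assumes Sync.wait() takes no arguments, as the note above implies.)
Epi.Sync.wait()

# Step 1: create() reads the TMEM address from shared memory.
var epi = Epi.create(
    tmem_addr_storage,
    accum_barriers_ptr,
    dealloc_mbar,
    mma_complete_mask,
)

# Step 2: acquire, process, and release one stage per iteration.
for _ in range(num_output_tiles):
    var stage = epi.acquire_k_stage_linear()  # waits for MMA to finish the stage
    process_tmem(stage.tmem())
    stage^.release()  # the stage handle is itself linear

# Step 3: release() is the only way to destroy the EpilogueWarp.
epi^.release()
```

Because both the warp handle and each stage handle are linear, omitting either release() call should be rejected at compile time rather than surfacing as a hang at run time.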

Parameters

  • opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
  • mma_threads (Int): Number of MMA threads.
  • epilogue_threads (Int): Number of epilogue threads.

Fields

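  • tmem (EpilogueWarp[opc, mma_threads, epilogue_threads].Tmem): TMEM allocation (passed to __init__ as TmemAllocation[opc.cta_group]).
  • output_pipeline (EpilogueWarp[opc, mma_threads, epilogue_threads].Pipeline): Output tile pipeline (passed to __init__ as OutputTilePipeline[opc]).
  • dealloc_barrier (EpilogueWarp[opc, mma_threads, epilogue_threads].Dealloc): Barrier for TMEM deallocation synchronization (passed to __init__ as TmemDeallocBarrier[opc.cta_group]).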

Implemented traits

AnyType

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])

create

static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create an epilogue warp.

Reads TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible.

Args:

  • tmem_addr_storage (SMemArray): Shared storage containing TMEM address.
  • accum_barriers_ptr (UnsafePointer): Pointer to accumulator pipeline barriers.
  • dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
  • mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.

Returns:

Self: Fully initialized EpilogueWarp that must be released.

per_k_stage

per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[opc] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]

Get a per-K-stage context manager (kept for compatibility).

Prefer acquire_k_stage_linear() for flat code structure.

Returns:

EpilogueKContext

acquire_k_stage_linear

acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), opc]

Acquire a per-K stage using linear types.

Waits for the MMA to complete the stage, then returns a linear handle.

Usage:

    var stage = epi_handle.acquire_k_stage_linear()
    process_tmem(stage.tmem())
    stage^.release()

Returns:

EpilogueStage

release

release(deinit self)

Signal epilogue completion.

This is the only way to destroy this linear type.
