Mojo struct
EpilogueWarp
struct EpilogueWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]
Unified linear type for epilogue warp lifecycle.
Works as both a linear type (direct use) and within context managers.
Lifecycle:
- Created via create() after Sync.wait(); create() reads the TMEM address.
- Use output_pipeline or acquire_k_stage_linear() for epilogue stages.
- Must call release() to signal completion (compiler-enforced).
IMPORTANT: Call Sync.wait() BEFORE create() to ensure TMEM address is visible.
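The lifecycle above can be sketched as the fragment below. It is a sketch only, not the library's reference usage: it assumes it sits inside an epilogue warp's kernel body where opc, mma_threads, epilogue_threads and the four shared-memory arguments to create() are already in scope, and it assumes Sync.wait() is callable with no arguments, as the note above suggests. The names Epi, num_k_stages and process_tmem are hypothetical placeholders.

```mojo
# Sketch only: not standalone. Assumes the kernel has already set up
# tmem_addr_storage, accum_barriers_ptr, dealloc_mbar and mma_complete_mask.
# Epi, num_k_stages and process_tmem are hypothetical names.
comptime Epi = EpilogueWarp[opc, mma_threads, epilogue_threads]

# 1. Wait first so the TMEM address written to shared memory is visible.
Epi.Sync.wait()

# 2. create() reads the TMEM address and returns the linear handle.
var epi = Epi.create(
    tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask
)

# 3. Acquire, process and release one stage at a time.
for _ in range(num_k_stages):
    var stage = epi.acquire_k_stage_linear()  # waits for MMA to finish the stage
    process_tmem(stage.tmem())
    stage^.release()

# 4. release() consumes the handle; it is the only way to destroy it.
epi^.release()
```

Each stage handle is released before the next one is acquired, which keeps the code flat and mirrors the usage shown for acquire_k_stage_linear().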
Parameters
- opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
- mma_threads (Int): Number of MMA threads.
- epilogue_threads (Int): Number of epilogue threads.
Fields
- tmem (EpilogueWarp[opc, mma_threads, epilogue_threads].Tmem)
- output_pipeline (EpilogueWarp[opc, mma_threads, epilogue_threads].Pipeline)
- dealloc_barrier (EpilogueWarp[opc, mma_threads, epilogue_threads].Dealloc)
Implemented traits
comptime members
Dealloc
comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create an epilogue warp.
Reads TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible.
Args:
- tmem_addr_storage (SMemArray[UInt32, 1]): Shared storage containing the TMEM address.
- accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to accumulator pipeline barriers.
- dealloc_mbar (SMemArray[SharedMemBarrier, 1]): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized EpilogueWarp that must be released.
per_k_stage
per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[opc] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]
Get a per-K-stage context manager (kept for compatibility).
Prefer acquire_k_stage_linear() for flat code structure.
Returns:
EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]
acquire_k_stage_linear
acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), opc]
Acquire a per-K stage using linear types.
Waits for MMA to complete the stage, then returns a linear handle.
Usage:
    var stage = epi_handle.acquire_k_stage_linear()
    process_tmem(stage.tmem())
    stage^.release()
Returns:
EpilogueStage[origin_of(self.output_pipeline), opc]
release
release(deinit self)
Signal epilogue completion.
This is the only way to destroy this linear type.