Mojo struct
EpilogueWarp
struct EpilogueWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]
Unified linear type for the epilogue warp lifecycle.
Works both as a linear type (used directly) and inside context managers.
Lifecycle:
- Created via create() after Sync.wait(); reads the TMEM address.
- Use output_pipeline or acquire_k_stage_linear() for epilogue stages.
- Must call release() to signal completion (compiler-enforced).
IMPORTANT: Call Sync.wait() BEFORE create() to ensure the TMEM address is visible.
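The ordering requirement can be illustrated with a host-side model. This is a minimal Python sketch under stated assumptions, not the kernel's code: a one-element list stands in for tmem_addr_storage, a two-party threading.Barrier stands in for Sync.wait(), and the names allocator_warp and epilogue_warp are hypothetical. Because the allocator writes the address before arriving at the barrier, and the epilogue reads it only after the barrier opens, the read always sees the published value; reading before the barrier could observe the initial 0. The model captures ordering only, not the GPU memory model.

```python
import threading

# Hypothetical model: a one-element list stands in for the
# SMemArray[UInt32, 1] tmem_addr_storage, and a two-party
# threading.Barrier stands in for Sync.wait().
tmem_addr_storage = [0]
sync = threading.Barrier(2)
observed = []


def allocator_warp(addr: int) -> None:
    tmem_addr_storage[0] = addr  # publish the allocated TMEM address
    sync.wait()                  # arrive only after the write


def epilogue_warp() -> None:
    sync.wait()                            # Sync.wait() BEFORE create()
    observed.append(tmem_addr_storage[0])  # create() reads the address


threads = [
    threading.Thread(target=allocator_warp, args=(0x40,)),
    threading.Thread(target=epilogue_warp),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(hex(observed[0]))  # 0x40
```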
Parameters
- num_accum_stages (Int): Number of accumulator pipeline stages.
- stage_stride_cols (Int): TMEM column stride between stages.
- cta_group (Int): CTA group size (1 or 2).
- mma_threads (Int): Number of MMA threads.
- epilogue_threads (Int): Number of epilogue threads.
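To make stage_stride_cols concrete, the arithmetic below assumes the accumulator stages are laid out back to back in TMEM, with stage s starting at base + s * stage_stride_cols, and assumes the Blackwell TMEM geometry of 512 32-bit columns per SM. Both assumptions and the helper names are hypothetical; the actual layout is determined by the kernel's TMEM allocation.

```python
# Assumption: Blackwell tensor memory exposes 512 32-bit columns per SM.
TMEM_COLUMNS = 512


def stage_column(base_col: int, stage: int, stage_stride_cols: int) -> int:
    """First TMEM column of accumulator stage `stage` (assumed linear layout)."""
    return base_col + stage * stage_stride_cols


def fits_in_tmem(num_accum_stages: int, stage_stride_cols: int) -> bool:
    """True if all accumulator stages fit side by side in TMEM."""
    return num_accum_stages * stage_stride_cols <= TMEM_COLUMNS


# Two stages of 256 columns each use all 512 columns: stage 1 starts at 256.
print([stage_column(0, s, 256) for s in range(2)])  # [0, 256]
print(fits_in_tmem(2, 256), fits_in_tmem(3, 256))   # True False
```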
Fields
- tmem (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem)
- output_pipeline (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline)
- dealloc_barrier (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc)
Implemented traits
comptime members
Dealloc
comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(out self, tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create the epilogue warp.
Reads the TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible.
Args:
- tmem_addr_storage (SMemArray): Shared storage containing the TMEM address.
- accum_barriers (SMemArray): Barrier array for the accumulator pipeline.
- dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized EpilogueWarp that must be released.
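The accum_barriers array holds num_accum_stages * 2 barriers, which suggests a full/empty pair per stage, as in a typical producer/consumer ring. The sketch below assumes one hypothetical layout (all full barriers first, then all empty barriers) and a hypothetical one-bit-per-CTA-rank encoding for mma_complete_mask; the real kernel may order the barriers or assign mask bits differently.

```python
def accum_barrier_indices(stage: int, num_accum_stages: int) -> tuple[int, int]:
    """Hypothetical split of the 2 * num_accum_stages accumulator barriers."""
    if not 0 <= stage < num_accum_stages:
        raise ValueError("stage out of range")
    full_idx = stage                      # MMA -> epilogue: accumulator ready
    empty_idx = num_accum_stages + stage  # epilogue -> MMA: TMEM stage free
    return full_idx, empty_idx


def cta_group_mask(first_rank: int, cta_group: int) -> int:
    """Set one bit per CTA rank in the group, starting at first_rank."""
    mask = 0
    for rank in range(first_rank, first_rank + cta_group):
        mask |= 1 << rank
    if mask >= 1 << 16:
        raise ValueError("mask does not fit in UInt16")
    return mask


print(accum_barrier_indices(1, 2))  # (1, 3)
print(bin(cta_group_mask(0, 2)))    # 0b11
print(bin(cta_group_mask(2, 2)))    # 0b1100
```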
per_k_stage
per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[input_origin] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin._mlir_origin.pipeline), num_accum_stages, stage_stride_cols, cta_group, num_group_stages]
Get a per-K stage context manager (kept for compatibility).
Prefer acquire_k_stage_linear() for a flatter code structure.
Returns:
EpilogueKContext
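The difference between per_k_stage() and acquire_k_stage_linear() is where the release happens. This Python sketch uses hypothetical stand-ins that only borrow the method names: the context-manager version releases when the with block exits, while the linear version hands ownership to the caller, who releases explicitly. Both produce the same wait, process, release sequence.

```python
from contextlib import contextmanager

events: list[str] = []


class Stage:
    """Hypothetical stand-in for EpilogueStage: waited on creation,
    released exactly once."""

    def __init__(self, idx: int) -> None:
        self.idx = idx
        self.released = False
        events.append(f"wait {idx}")  # models waiting for the MMA warp

    def tmem(self) -> int:
        return self.idx  # placeholder for the stage's TMEM view

    def release(self) -> None:
        assert not self.released, "stage released twice"
        self.released = True
        events.append(f"release {self.idx}")


@contextmanager
def per_k_stage(idx: int):
    """Context-manager style: release happens when the block exits."""
    stage = Stage(idx)
    try:
        yield stage
    finally:
        stage.release()


def acquire_k_stage_linear(idx: int) -> Stage:
    """Linear style: the caller owns the handle and must release it."""
    return Stage(idx)


with per_k_stage(0) as s:
    events.append(f"process {s.tmem()}")

stage = acquire_k_stage_linear(1)
events.append(f"process {stage.tmem()}")
stage.release()

print(events)
# ['wait 0', 'process 0', 'release 0', 'wait 1', 'process 1', 'release 1']
```

In Mojo, the linear version is checked by the compiler (stage^.release() consumes the handle); the Python stand-in can only assert at runtime.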
acquire_k_stage_linear
acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]
Acquire a per-K stage using linear types.
Waits for the MMA warp to complete the stage, then returns a linear handle.
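Usage:

```mojo
var stage = epi_handle.acquire_k_stage_linear()
process_tmem(stage.tmem())  # caller-defined processing of the stage's TMEM
stage^.release()            # ^ transfers ownership; release() consumes the handle
```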
Returns:
EpilogueStage
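The wait in acquire_k_stage_linear() and the matching release form a bounded ring between the MMA warp and the epilogue warp. The model below is a hypothetical host-side sketch: two counting semaphores play the role of the per-stage full and empty barriers, a Python list plays the TMEM stages, and t * t stands in for an accumulated tile. It checks that the MMA side never runs more than num_accum_stages stages ahead of the epilogue, and that the epilogue sees tiles in order.

```python
import threading


def run_pipeline(num_tiles: int, num_accum_stages: int):
    """Hypothetical model of the accumulator ring between MMA and epilogue."""
    full = threading.Semaphore(0)                  # MMA -> epilogue: stage ready
    empty = threading.Semaphore(num_accum_stages)  # epilogue -> MMA: stage free
    tmem = [None] * num_accum_stages               # one slot per accumulator stage
    lock = threading.Lock()
    in_flight = 0
    max_in_flight = 0
    consumed = []

    def mma_warp():
        nonlocal in_flight, max_in_flight
        for t in range(num_tiles):
            empty.acquire()                     # wait for a free stage
            tmem[t % num_accum_stages] = t * t  # "accumulate" tile t
            with lock:
                in_flight += 1
                max_in_flight = max(max_in_flight, in_flight)
            full.release()                      # commit: stage ready

    def epilogue_warp():
        nonlocal in_flight
        for t in range(num_tiles):
            full.acquire()                               # acquire_k_stage_linear()
            consumed.append(tmem[t % num_accum_stages])  # read stage.tmem()
            with lock:
                in_flight -= 1
            empty.release()                              # stage^.release()

    workers = [threading.Thread(target=mma_warp),
               threading.Thread(target=epilogue_warp)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return consumed, max_in_flight


consumed, max_in_flight = run_pipeline(num_tiles=6, num_accum_stages=2)
print(consumed)             # [0, 1, 4, 9, 16, 25]
print(max_in_flight <= 2)   # True
```

Under this model, with two stages the MMA side can fill one stage while the epilogue drains the other, and skipping a release would eventually stall the producer on empty.acquire().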
release
release(deinit self)
Signal epilogue completion.
This is the only way to destroy this linear type.
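Mojo enforces the single release() at compile time. Python has no linear types, so the sketch below approximates the rule at runtime: a hypothetical LinearScope tracks live handles and raises when one is left unreleased at scope exit, and a second release() is rejected. The class names are illustrative only.

```python
class LinearLeakError(RuntimeError):
    """Raised when a handle goes out of scope without release()."""


class LinearScope:
    """Runtime stand-in for Mojo's compile-time check: every handle
    created in the scope must be released before the scope ends."""

    def __init__(self) -> None:
        self.live: set = set()

    def __enter__(self) -> "LinearScope":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        if exc_type is None and self.live:
            raise LinearLeakError(f"{len(self.live)} handle(s) never released")
        return False  # never swallow other exceptions


class EpilogueHandle:
    """Hypothetical handle whose only exit is release()."""

    def __init__(self, scope: LinearScope) -> None:
        self._scope = scope
        self._released = False
        scope.live.add(self)

    def release(self) -> None:
        if self._released:
            raise RuntimeError("handle already released")
        self._released = True
        self._scope.live.discard(self)


# Correct use: the handle is consumed by release().
with LinearScope() as scope:
    handle = EpilogueHandle(scope)
    handle.release()

# Forgetting release() is reported when the scope closes.
try:
    with LinearScope() as scope:
        EpilogueHandle(scope)
except LinearLeakError as err:
    print(err)  # 1 handle(s) never released
```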