Mojo struct
EpilogueWarpContext
struct EpilogueWarpContext[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]
Epilogue warp context - consumes TMEM data, signals completion.
IMPORTANT: Call Sync.wait() BEFORE constructing this context, so that the TMEM address stored in shared memory is visible.
Fields
- tmem (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Tmem):
- output_pipeline (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Pipeline):
- dealloc_barrier (EpilogueWarpContext[opc, mma_threads, epilogue_threads].Dealloc):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
Dealloc
comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group]) -> Self
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create an epilogue warp context with all necessary components.
Reads the TMEM address from shared memory and creates the output pipeline. IMPORTANT: Call Sync.wait() BEFORE calling this to ensure the TMEM address is visible.
Args:
- tmem_addr_storage (SMemArray[UInt32, 1]): Shared storage containing TMEM address.
- accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to accumulator pipeline barriers.
- dealloc_mbar (SMemArray[SharedMemBarrier, 1]): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized EpilogueWarpContext.
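The ordering requirement above can be sketched as follows. This is a minimal, hedged fragment rather than a complete kernel: the wrapper function epilogue_warp_sketch is hypothetical, the import statements are omitted because this page does not list module paths, and calling Sync.wait() as a static method with no arguments is an assumption based only on the note above. The argument types are copied from the create signature.

```mojo
# Sketch only. Imports are omitted (module paths are not given on this page),
# and `epilogue_warp_sketch` is a hypothetical wrapper used for illustration.
fn epilogue_warp_sketch[
    opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int
](
    tmem_addr_storage: SMemArray[UInt32, 1],
    accum_barriers_ptr: UnsafePointer[
        SharedMemBarrier, MutAnyOrigin, address_space = AddressSpace.SHARED
    ],
    dealloc_mbar: SMemArray[SharedMemBarrier, 1],
    mma_complete_mask: UInt16,
):
    comptime EpiCtx = EpilogueWarpContext[opc, mma_threads, epilogue_threads]

    # 1. Wait first, so the TMEM address in shared memory is visible.
    #    (Assumed to be a static, zero-argument call.)
    EpiCtx.Sync.wait()

    # 2. Only then build the context; create() reads the TMEM address
    #    and sets up the output pipeline.
    var epi_ctx = EpiCtx.create(
        tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask
    )

    # 3. Use it as a context manager; __exit__ runs when the block ends.
    with epi_ctx:
        pass  # epilogue work goes here
```

Swapping steps 1 and 2 is the mistake the IMPORTANT note warns about: create() would read the shared-memory slot before the TMEM address is guaranteed to be visible.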
__enter__
__enter__(self) -> Self
__exit__
__exit__(self)
per_k_stage
per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[opc] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]
Get per-K stage context for blockwise FP8 epilogue.
Bundles the output pipeline (MMA→Epilogue sync) and the input pipeline (A-scales consumption) into a single context manager.
Example:
    for k_iter in range(num_iters):
        with epi_ctx.per_k_stage(input_pipeline) as epi_stage:
            accum.promote(epi_stage, ...)
        # Both pipelines signaled automatically
Args:
- input_pipeline (InputTilePipeline[Payload, num_group_stages, k_group_size]): The InputTilePipeline (extracts .pipeline internally).
Returns:
EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin.pipeline), opc, num_group_stages]: EpilogueKContext context manager that handles both pipelines.