
Mojo struct

EpilogueWarp

struct EpilogueWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]

Unified linear type for epilogue warp lifecycle.

Works both as a linear type (direct use) and within context managers.

Lifecycle:

  1. Create via create() after Sync.wait(); create() reads the TMEM address from shared memory.
  2. Use output_pipeline or acquire_k_stage_linear() to process epilogue stages.
  3. Call release() to signal completion. The compiler enforces this call.

IMPORTANT: Call Sync.wait() BEFORE create() to ensure TMEM address is visible.
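
The lifecycle above can be sketched as a device-side helper. This is an illustrative fragment, not a complete kernel, and it makes several assumptions: imports are omitted because this page does not give the module path for these types; Sync.wait() is assumed to be callable statically through the Sync comptime member; num_k_stages is a hypothetical loop count; and process_tmem is the same placeholder used in the acquire_k_stage_linear() usage note, standing in for whatever work the epilogue does on the accumulator.

```mojo
# Sketch only. Imports omitted: the module path is not given on this page.
fn run_epilogue[
    num_accum_stages: Int,
    stage_stride_cols: Int,
    cta_group: Int,
    mma_threads: Int,
    epilogue_threads: Int,
](
    tmem_addr_storage: SMemArray[UInt32, 1],
    accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)],
    dealloc_mbar: SMemArray[SharedMemBarrier, 1],
    mma_complete_mask: UInt16,
    num_k_stages: Int,  # hypothetical: number of stages to drain
):
    comptime Epi = EpilogueWarp[
        num_accum_stages,
        stage_stride_cols,
        cta_group,
        mma_threads,
        epilogue_threads,
    ]

    # Step 0: wait first so the TMEM address written to shared memory is
    # visible. Assumption: wait() is a static method on the Sync member.
    Epi.Sync.wait()

    # Step 1: create the handle; this reads the TMEM address.
    var epi = Epi.create(
        tmem_addr_storage, accum_barriers, dealloc_mbar, mma_complete_mask
    )

    # Step 2: acquire, process, and release one linear stage handle at a time.
    for _ in range(num_k_stages):
        var stage = epi.acquire_k_stage_linear()
        process_tmem(stage.tmem())  # placeholder for the epilogue work
        stage^.release()

    # Step 3: release the warp handle; it cannot be destroyed any other way.
    epi^.release()
```

Each handle is consumed with the `^` transfer operator when released, matching the `deinit self` signature of release().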

Parameters

  • num_accum_stages (Int): Number of accumulator pipeline stages.
  • stage_stride_cols (Int): TMEM column stride between stages.
  • cta_group (Int): CTA group size (1 or 2).
  • mma_threads (Int): Number of MMA threads.
  • epilogue_threads (Int): Number of epilogue threads.

Fields

  • tmem (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem):
  • output_pipeline (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline):
  • dealloc_barrier (EpilogueWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc):

Implemented traits

AnyType

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(out self, tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])

create

static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create an EpilogueWarp.

Reads TMEM address from shared memory. IMPORTANT: Call Sync.wait() BEFORE this to ensure the address is visible.

Args:

  • tmem_addr_storage (SMemArray): Shared storage containing TMEM address.
  • accum_barriers (SMemArray): Barrier array for accumulator pipeline.
  • dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
  • mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.

Returns:

Self: Fully initialized EpilogueWarp that must be released.

per_k_stage

per_k_stage[input_origin: MutOrigin, Payload: TilePayload, num_group_stages: Int, k_group_size: Int](mut self, ref[input_origin] input_pipeline: InputTilePipeline[Payload, num_group_stages, k_group_size]) -> EpilogueKContext[origin_of(self.output_pipeline), origin_of(input_origin._mlir_origin.pipeline), num_accum_stages, stage_stride_cols, cta_group, num_group_stages]

Get per-K stage context manager (for compatibility).

Prefer acquire_k_stage_linear() for flat code structure.

Returns:

EpilogueKContext

acquire_k_stage_linear

acquire_k_stage_linear(mut self) -> EpilogueStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]

Acquire a per-K stage using linear types.

Waits for MMA to complete the stage, returns a linear handle.

Usage:

    var stage = epi_handle.acquire_k_stage_linear()
    process_tmem(stage.tmem())
    stage^.release()

Returns:

EpilogueStage

release

release(deinit self)

Signal epilogue completion.

This is the only way to destroy this linear type.
