
Mojo struct

EpilogueWarpContext

@register_passable(trivial) struct EpilogueWarpContext[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]

Context for the epilogue warp: consumes TMEM data and signals completion.

IMPORTANT: Call Sync.wait() before constructing this context, so that the TMEM address is visible from shared memory.
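The ordering requirement above can be sketched as follows. This is a minimal illustration, not the library's reference usage. The parameter values, the `Ctx` alias, and the `run_epilogue` helper are hypothetical. The sketch also assumes that `wait()` is a static method on the `Sync` type and that the three handles have already been created elsewhere. The import is left out because this page does not show the module path.

```mojo
# Sketch only. The import for EpilogueWarpContext is omitted because this
# page does not show its module path.

# Hypothetical parameter values, chosen only for illustration.
comptime Ctx = EpilogueWarpContext[
    num_accum_stages=2,
    stage_stride_cols=128,
    cta_group=1,
    mma_threads=32,
    epilogue_threads=128,
]


# Hypothetical helper; the three handles are assumed to be created elsewhere.
fn run_epilogue(
    tmem: Ctx.Tmem,
    output_pipeline: Ctx.Pipeline,
    dealloc_barrier: Ctx.Dealloc,
):
    # Wait first, so the TMEM address is visible from shared memory.
    # Assumption: wait() is a static method on the Sync type.
    Ctx.Sync.wait()

    # Construct only after the wait, then use the context in a with block,
    # which calls __enter__ on entry and __exit__ on exit.
    with Ctx(tmem, output_pipeline, dealloc_barrier) as ctx:
        # Hypothetical epilogue work would go here.
        _ = ctx
```

Because the struct defines `__enter__` and `__exit__`, a `with` block is the natural way to scope its use. What `__exit__` does on exit is not described on this page.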

Fields

  • tmem (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem)
  • output_pipeline (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline)
  • dealloc_barrier (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Dealloc

comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group]) -> Self

__enter__

__enter__(self) -> Self

__exit__

__exit__(self)
