Mojo struct
EpilogueWarpContext
@register_passable(trivial)
struct EpilogueWarpContext[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]
Epilogue warp context: consumes TMEM data and signals completion.
IMPORTANT: Call Sync.wait() before constructing this context, so that the TMEM address is visible from shared memory.
Fields
- tmem (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem)
- output_pipeline (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline)
- dealloc_barrier (EpilogueWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Dealloc
comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group]) -> Self
__enter__
__enter__(self) -> Self
__exit__
__exit__(self)
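Usage sketch
Because the struct defines __enter__ and __exit__, it can presumably be used in a with statement. The sketch below illustrates only that ordering: wait first, then construct, then work inside the block. It uses a hypothetical stand-in type, MockEpilogueContext, rather than the real EpilogueWarpContext, whose TmemAllocation, OutputTilePipeline, and TmemDeallocBarrier arguments are created inside a GPU kernel and are not shown on this page. The print calls are placeholders: what the real Sync.wait() and __exit__ do internally is not documented here.

```mojo
# Hypothetical stand-in for EpilogueWarpContext; NOT the library's type.
# It only mirrors the documented method shapes: __init__, __enter__, __exit__.
@register_passable("trivial")
struct MockEpilogueContext:
    var num_stages: Int

    fn __init__(out self, num_stages: Int):
        self.num_stages = num_stages

    fn __enter__(self) -> Self:
        # Same signature as the documented __enter__(self) -> Self.
        return self

    fn __exit__(self):
        # Placeholder: the real __exit__ body is not documented on this page.
        print("exit: leaving epilogue context")


fn consume_all(num_stages: Int) -> Int:
    var consumed = 0
    # Placeholder for Sync.wait(), which must run BEFORE construction.
    print("wait: TMEM address assumed visible after this point")
    with MockEpilogueContext(num_stages) as ctx:
        for stage in range(ctx.num_stages):
            # Placeholder for consuming one accumulator stage.
            print("consume stage", stage)
            consumed += 1
    return consumed


fn main():
    _ = consume_all(2)
```

In real kernel code the stand-in would be replaced by EpilogueWarpContext[...] constructed from tmem, output_pipeline, and dealloc_barrier, with the Sync.wait() call placed immediately before the with statement.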