Mojo struct
MmaWarpContext
@register_passable(trivial)
struct MmaWarpContext[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]
MMA warp context - owns TMEM lifecycle and output pipeline.
- __enter__: Signals the epilogue that TMEM is allocated.
- __exit__: Waits for the epilogue, then deallocates TMEM.
Fields
- tmem (MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem)
- output_pipeline (MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline)
- dealloc_barrier (MmaWarpContext[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Dealloc
comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group]) -> Self
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create MMA warp context with all necessary components.
Allocates TMEM and creates output pipeline internally.
Args:
- tmem_addr_storage (SMemArray): Shared storage for TMEM address communication.
- accum_barriers (SMemArray): Barrier array for accumulator pipeline.
- dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized MmaWarpContext.
__enter__
__enter__(self) -> Self
__exit__
__exit__(self)
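The lifecycle described for this struct (create allocates TMEM, `__enter__` signals the epilogue, `__exit__` waits for the epilogue and then deallocates) can be illustrated with a small host-side model. This is a minimal Python sketch of the ordering only, not Mojo and not the library's implementation: `ToyMmaWarpContext` and its event names are hypothetical, and no real TMEM or barriers are involved.

```python
class ToyMmaWarpContext:
    """Hypothetical stand-in that records lifecycle events in order."""

    def __init__(self):
        self.events = []
        self.tmem_allocated = False

    @classmethod
    def create(cls):
        # Models `create`: TMEM is allocated before the context is entered.
        ctx = cls()
        ctx.tmem_allocated = True
        ctx.events.append("alloc_tmem")
        return ctx

    def __enter__(self):
        # Models `__enter__`: tell the epilogue that TMEM is ready.
        self.events.append("signal_epilogue")
        return self

    def __exit__(self, exc_type, exc, tb):
        # Models `__exit__`: wait for the epilogue, then free TMEM.
        self.events.append("wait_epilogue")
        self.tmem_allocated = False
        self.events.append("dealloc_tmem")
        return False  # do not suppress exceptions


ctx = ToyMmaWarpContext.create()
with ctx as mma_ctx:
    mma_ctx.events.append("mma_loop")  # the MMA work would run here

print(ctx.events)
```

The point of the pattern is that deallocation cannot be skipped or reordered: leaving the `with` block always runs the wait-then-deallocate step after the MMA work.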
per_k_stage
per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]
Get per-K stage for blockwise FP8 MMA loop.
Returns a context manager that acquires an output stage and signals mma_arrive on exit.
Example:
    for i in range(num_iters):
        with mma_ctx.per_k_stage() as mma_stage:
            mma(input_tiles, mma_op, AccumTensor(mma_stage.tmem.offset()))
        # exit signals mma_arrive automatically
Returns:
MmaKStage
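The acquire-on-entry, signal-on-exit behavior of `per_k_stage` can be modeled on the host with a toy pipeline. This is a hedged Python sketch, not the Mojo implementation: `ToyOutputPipeline` and `ToyStage` are hypothetical names, and both the round-robin stage selection and the `index * stage_stride_cols` offset are assumptions made for illustration.

```python
class ToyStage:
    """Hypothetical stage handle; signals completion when the block exits."""

    def __init__(self, pipeline, index):
        self.pipeline = pipeline
        self.index = index

    def tmem_offset(self):
        # Assumption: stages are laid out stage_stride_cols columns apart.
        return self.index * self.pipeline.stage_stride_cols

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Models the automatic mma_arrive signal on exit.
        self.pipeline.log.append(("mma_arrive", self.index))
        return False


class ToyOutputPipeline:
    """Hypothetical pipeline that hands out stages in round-robin order."""

    def __init__(self, num_accum_stages, stage_stride_cols):
        self.num_accum_stages = num_accum_stages
        self.stage_stride_cols = stage_stride_cols
        self.next_stage = 0
        self.log = []

    def per_k_stage(self):
        # Acquire the next stage (assumed round-robin) and log the acquisition.
        index = self.next_stage
        self.next_stage = (index + 1) % self.num_accum_stages
        self.log.append(("acquire", index))
        return ToyStage(self, index)


pipeline = ToyOutputPipeline(num_accum_stages=2, stage_stride_cols=128)
offsets = []
for i in range(3):
    with pipeline.per_k_stage() as stage:
        offsets.append(stage.tmem_offset())  # the MMA would target this offset

print(offsets)
print(pipeline.log)
```

Because the signal lives in `__exit__`, every acquired stage is paired with exactly one arrive, which is what the loop in the example above relies on.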