
Mojo struct

MmaWarpContext

struct MmaWarpContext[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

MMA warp context - owns TMEM lifecycle and output pipeline.

  • __enter__: Signals the epilogue that TMEM is allocated.
  • __exit__: Waits for the epilogue, then deallocates TMEM.

Fields

  • tmem (MmaWarpContext[opc, mma_threads, epilogue_threads].Tmem):
  • output_pipeline (MmaWarpContext[opc, mma_threads, epilogue_threads].Pipeline):
  • dealloc_barrier (MmaWarpContext[opc, mma_threads, epilogue_threads].Dealloc):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group]) -> Self

create

static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create MMA warp context with all necessary components.

Allocates TMEM and creates output pipeline internally.

Args:

  • tmem_addr_storage (SMemArray): Shared storage for TMEM address communication.
  • accum_barriers_ptr (UnsafePointer): Pointer to accumulator pipeline barriers.
  • dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
  • mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.

Returns:

Self: Fully initialized MmaWarpContext.

__enter__

__enter__(self) -> Self

__exit__

__exit__(self)

per_k_stage

per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), opc]

Get per-K stage for blockwise FP8 MMA loop.

Returns a context manager that acquires an output stage and signals mma_arrive on exit.

Example:

    for i in range(num_iters):
        with mma_ctx.per_k_stage() as mma_stage:
            mma(input_tiles, mma_op, AccumTensor(mma_stage.tmem.offset()))
        # exit signals mma_arrive automatically

Returns:

MmaKStage
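The acquire-on-enter, signal-on-exit behavior described for `per_k_stage` can be pictured with a small Python model. This is a sketch under stated assumptions, not the Mojo implementation. `ToyOutputPipeline`, its round-robin stage selection, and the `arrivals` list (standing in for `mma_arrive`) are all hypothetical.

```python
from contextlib import contextmanager


class ToyOutputPipeline:
    """Hypothetical model of an output pipeline with a fixed number of stages."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.next_stage = 0
        self.arrivals = []  # records each modelled mma_arrive signal

    @contextmanager
    def per_k_stage(self):
        stage = self.next_stage  # acquire an output stage on enter
        try:
            yield stage
        finally:
            # Signal on exit, even if the body raised (a choice of this model).
            self.arrivals.append(stage)
            # Assumption: stages are reused in round-robin order.
            self.next_stage = (stage + 1) % self.num_stages


def run_k_loop(num_iters, num_stages=2):
    pipeline = ToyOutputPipeline(num_stages)
    used = []
    for _ in range(num_iters):
        with pipeline.per_k_stage() as stage:
            used.append(stage)  # stand-in for issuing the MMA into this stage
    return used, pipeline.arrivals
```

Each loop iteration produces exactly one arrival, so the caller never has to signal by hand. That is the convenience the context-manager form provides.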
