
Mojo struct

MmaWarp

struct MmaWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

Unified linear type that manages the tensor memory (TMEM) lifecycle of the MMA (matrix multiply-accumulate) warp.

It can be used directly as a linear type or through context managers.

Lifecycle:

  1. Create the warp with create(), which allocates TMEM and signals the sync barrier.
  2. Acquire MMA stages through output_pipeline or acquire_k_stage_linear().
  3. Call release() to wait for the epilogue and deallocate TMEM. The compiler enforces this call.

Parameters

  • opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
  • mma_threads (Int): Number of MMA threads.
  • epilogue_threads (Int): Number of epilogue threads.
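One way the stages and stride in opc could translate into TMEM locations is a ring of equally sized regions reused round-robin. The sketch below assumes exactly that; stage_index and stage_tmem_offset are hypothetical helpers, not part of the MmaWarp API, and the real OutputTilePipeline layout may differ.

```mojo
# Hypothetical helpers (not part of the MmaWarp API). They assume the
# output pipeline is a ring of `num_stages` TMEM regions, each `stride`
# columns wide, reused round-robin.


fn stage_index(acquisition: Int, num_stages: Int) -> Int:
    # The i-th acquisition reuses slot i mod num_stages.
    return acquisition % num_stages


fn stage_tmem_offset(acquisition: Int, num_stages: Int, stride: Int) -> Int:
    # Column offset of the slot relative to the start of the allocation.
    return stage_index(acquisition, num_stages) * stride
```

Under these assumptions, with two stages and a stride of 128, acquisitions 0, 1, 2, 3 map to offsets 0, 128, 0, 128, so the MMA warp can fill one region while the epilogue drains the other.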

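Fields

  • tmem (MmaWarp[opc, mma_threads, epilogue_threads].Tmem): TMEM allocation owned by this warp.
  • output_pipeline (MmaWarp[opc, mma_threads, epilogue_threads].Pipeline): Output tile pipeline through which MMA stages are acquired.
  • dealloc_barrier (MmaWarp[opc, mma_threads, epilogue_threads].Dealloc): Barrier that release() uses to wait for the epilogue before deallocating TMEM.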

Implemented traits

AnyType

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])

create

static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create MMA warp with TMEM allocation.

Allocates TMEM and signals the warp group sync barrier.
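How many TMEM columns to allocate is constrained by the hardware: PTX tcgen05.alloc accepts only power-of-two column counts between 32 and 512, and an SM has 512 TMEM columns in total. The sketch below assumes the allocation covers num_stages regions of stride columns each; tmem_alloc_columns and fits_in_tmem are hypothetical helpers, not what create() is documented to do.

```mojo
# Hypothetical helper (not part of the MmaWarp API): how many TMEM columns
# a warp might request for `num_stages` stages of `stride` columns each.
# PTX tcgen05.alloc accepts only power-of-two column counts in [32, 512].


fn tmem_alloc_columns(num_stages: Int, stride: Int) -> Int:
    var needed = num_stages * stride
    var cols = 32  # smallest legal allocation
    while cols < needed:
        cols *= 2  # round up to the next power of two
    return cols


fn fits_in_tmem(num_stages: Int, stride: Int) -> Bool:
    # An SM has 512 TMEM columns in total.
    return tmem_alloc_columns(num_stages, stride) <= 512
```

Rounding up means a three-stage ring of 128 columns still requests all 512 columns, which is worth keeping in mind when choosing the stage count.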

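Args:

  • tmem_addr_storage (SMemArray[UInt32, 1]): Shared-memory slot that receives the allocated TMEM address.
  • accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to the accumulator barriers in shared memory.
  • dealloc_mbar (SMemArray[SharedMemBarrier, 1]): Shared-memory barrier used to coordinate TMEM deallocation.
  • mma_complete_mask (UInt16): Mask used when signaling MMA completion.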

Returns:

Self: Fully initialized MmaWarp that must be released.
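The multicast form of PTX tcgen05.commit takes a 16-bit CTA mask with one bit per CTA in the cluster, which matches the UInt16 type of mma_complete_mask. Assuming the argument is such a mask, it could be built as below; cta_mask is a hypothetical helper, not part of this API.

```mojo
# Hypothetical helper (not part of the MmaWarp API): build a CTA mask with
# one bit per cluster rank in [first_rank, first_rank + count). The result
# fits in 16 bits for ranks below 16 and can be converted with UInt16(...).


fn cta_mask(first_rank: Int, count: Int) -> Int:
    var mask = 0
    for rank in range(first_rank, first_rank + count):
        mask |= 1 << rank  # set the bit for this CTA
    return mask
```

For example, a CTA pair at ranks 2 and 3 (as with cta_group = 2) gives cta_mask(2, 2) == 0b1100.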

per_k_stage

per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), opc]

Get a per-K stage context manager (provided for compatibility).

Prefer acquire_k_stage_linear(), which keeps the code flat instead of nesting it inside a context manager.

Returns:

MmaKStage[origin_of(self.output_pipeline), opc]

acquire_k_stage_linear

acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), opc]

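Acquire a per-K stage using linear types.

Waits for the epilogue to free the stage, then returns a linear handle that must be released.

Usage:

    # Wait for the epilogue to free the next stage.
    var stage = mma_handle.acquire_k_stage_linear()
    # Accumulate into this stage's TMEM region.
    mma_op.mma(a, b, stage.tmem_offset())
    # Signal the stage's barrier once the MMAs complete.
    mma_op.commit(stage.mbar())
    # Transfer ownership of the handle and release the stage.
    stage^.release()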
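Waiting for the epilogue to free a stage usually means waiting on that stage's mbarrier with the expected phase parity, which flips every time the stage ring wraps around. A minimal sketch of that bookkeeping, assuming a ring of num_stages stages; stage_phase and its initial_phase parameter are hypothetical and not part of this API, and the starting parity differs between pipelines.

```mojo
# Hypothetical helper (not part of the MmaWarp API): the mbarrier phase
# parity a waiter expects on the i-th acquisition from a ring of
# `num_stages` stages. The parity flips each time the ring wraps; the
# starting parity is passed in because conventions differ.


fn stage_phase(acquisition: Int, num_stages: Int, initial_phase: Int) -> Int:
    # Number of completed trips around the ring, folded into one bit.
    return (initial_phase + acquisition // num_stages) % 2
```

With two stages and an initial parity of 0, acquisitions 0 and 1 wait on parity 0, acquisitions 2 and 3 on parity 1, and acquisition 4 returns to parity 0.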

Returns:

MmaStage[origin_of(self.output_pipeline), opc]

release

release(deinit self)

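Wait for the epilogue, then deallocate TMEM.

This is the only way to destroy an MmaWarp. Because it is a linear type, the compiler rejects code that lets an instance go out of scope without calling release().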