
Mojo struct

MmaWarp

struct MmaWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

Unified linear type for MMA warp TMEM lifecycle.

Works both as a linear type (direct use) and within context managers.

Lifecycle:

  1. Created via create() - allocates TMEM, signals sync barrier
  2. Use output_pipeline or acquire_k_stage_linear() for MMA stages
  3. Must call release() to wait for epilogue and deallocate (compiler-enforced)
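
The lifecycle above can be sketched as follows. This is a minimal, non-authoritative outline that uses only names appearing on this page. The wrapper function `mma_warp_body` is hypothetical, the import paths are omitted because this page does not list them, and the MMA issue step is left as comments because `mma_op`, `a`, and `b` are defined elsewhere.

```mojo
# Sketch only: imports for MmaWarp, SMemArray, SharedMemBarrier, etc. are
# omitted because their module paths are not given on this page.
# `mma_warp_body` is a hypothetical wrapper, not part of the API.
def mma_warp_body[
    opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int
](
    tmem_addr_storage: SMemArray[UInt32, 1],
    accum_barriers_ptr: UnsafePointer[
        SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED
    ],
    dealloc_mbar: SMemArray[SharedMemBarrier, 1],
    mma_complete_mask: UInt16,
):
    # 1. create() allocates TMEM and signals the sync barrier.
    var mma_handle = MmaWarp[opc, mma_threads, epilogue_threads].create(
        tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask
    )

    # 2. Acquire a per-K stage as a linear handle.
    var stage = mma_handle.acquire_k_stage_linear()
    # Issue the MMA here, for example:
    #   mma_op.mma(a, b, stage.tmem_offset())
    #   mma_op.commit(stage.mbar())
    stage^.release()  # the stage handle must be consumed explicitly

    # 3. release() waits for the epilogue and deallocates TMEM.
    mma_handle^.release()
```

Because `release()` takes `deinit self`, omitting the final `mma_handle^.release()` is expected to be rejected at compile time rather than leaking TMEM at run time.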

Parameters

  • opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
  • mma_threads (Int): Number of MMA threads.
  • epilogue_threads (Int): Number of epilogue threads.

Fields

  • tmem (MmaWarp[opc, mma_threads, epilogue_threads].Tmem)
  • output_pipeline (MmaWarp[opc, mma_threads, epilogue_threads].Pipeline)
  • dealloc_barrier (MmaWarp[opc, mma_threads, epilogue_threads].Dealloc)

Implemented traits

AnyType

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods

__init__

def __init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])

create

static def create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create MMA warp with TMEM allocation.

Allocates TMEM and signals the warp group sync barrier.

Args:

Returns:

Self: Fully initialized MmaWarp that must be released.

per_k_stage

def per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), opc]

Get the per-K stage context manager (kept for compatibility).

Prefer acquire_k_stage_linear() for flat code structure.

Returns:

MmaKStage[origin_of(self.output_pipeline), opc]

acquire_k_stage_linear

def acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), opc]

Acquire a per-K stage using linear types.

Waits for the epilogue to free the stage, then returns a linear handle.

Usage:

    var stage = mma_handle.acquire_k_stage_linear()
    mma_op.mma(a, b, stage.tmem_offset())
    mma_op.commit(stage.mbar())
    stage^.release()

Returns:

MmaStage[origin_of(self.output_pipeline), opc]

release

def release(deinit self)

Wait for the epilogue to finish, then deallocate TMEM.

This is the only way to destroy this linear type.