IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MmaWarpContext

struct MmaWarpContext[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]

MMA warp context that owns the TMEM lifecycle and the output pipeline.

  • __enter__: Signals the epilogue that TMEM is allocated.
  • __exit__: Waits for the epilogue, then deallocates TMEM.
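The enter/exit handshake described above can be illustrated with a small host-side model. This is only a sketch in Python, not the Mojo implementation: `ToyMmaWarpContext`, `epilogue`, `run_lifecycle`, and the two `threading.Event` objects are hypothetical stand-ins for the real TMEM allocation and barriers, chosen to show the ordering guarantee (TMEM stays alive until the epilogue has finished).

```python
import threading


class ToyMmaWarpContext:
    """Hypothetical host-side model of the enter/exit handshake (not the Mojo API)."""

    def __init__(self):
        self.tmem_allocated = threading.Event()  # set by __enter__
        self.epilogue_done = threading.Event()   # set by the epilogue side
        self.tmem_live = True                    # stands in for the TMEM allocation
        self.log = []

    def __enter__(self):
        # Signal the epilogue that TMEM is allocated and may be read.
        self.log.append("mma: signal tmem allocated")
        self.tmem_allocated.set()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Wait for the epilogue to finish, then release TMEM.
        self.epilogue_done.wait()
        self.tmem_live = False
        self.log.append("mma: tmem deallocated")
        return False


def epilogue(ctx):
    # The epilogue must not touch TMEM before the MMA side has signalled.
    ctx.tmem_allocated.wait()
    assert ctx.tmem_live
    ctx.log.append("epilogue: read accumulators")
    ctx.epilogue_done.set()


def run_lifecycle():
    ctx = ToyMmaWarpContext()
    worker = threading.Thread(target=epilogue, args=(ctx,))
    worker.start()
    with ctx:
        ctx.log.append("mma: issue MMAs")
    worker.join()
    return ctx
```

In this model the first log entry is always the allocation signal and the last is always the deallocation, while the MMA work and the epilogue read may interleave in either order.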

Fields

  • tmem (MmaWarpContext[opc, mma_threads, epilogue_threads].Tmem)
  • output_pipeline (MmaWarpContext[opc, mma_threads, epilogue_threads].Pipeline)
  • dealloc_barrier (MmaWarpContext[opc, mma_threads, epilogue_threads].Dealloc)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem

Methods

__init__

def __init__(tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group]) -> Self

create

static def create(tmem_addr_storage: SMemArray[UInt32, Int(1)], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, Int(1)], mma_complete_mask: UInt16) -> Self

Create MMA warp context with all necessary components.

Allocates TMEM and creates output pipeline internally.

Args:

Returns:

Self: Fully initialized MmaWarpContext.

__enter__

def __enter__(self) -> Self

__exit__

def __exit__(self)

per_k_stage

def per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), opc]

Get per-K stage for blockwise FP8 MMA loop.

Returns a context manager that acquires an output stage and signals mma_arrive on exit.

Example:

    for i in range(num_iters):
        with mma_ctx.per_k_stage() as mma_stage:
            mma(input_tiles, mma_op, AccumTensor(mma_stage.tmem.offset()))
        # __exit__ signals mma_arrive automatically

Returns:

MmaKStage[origin_of(self.output_pipeline), opc]
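The acquire-on-entry, signal-on-exit behavior of per_k_stage can be modeled on the host as follows. This is a rough Python sketch, not the Mojo type: `ToyOutputPipeline`, `run_k_loop`, the round-robin stage selection, and the `arrivals` list are all assumptions made for illustration, and the blocking wait that a real acquire would perform on the epilogue is omitted.

```python
from contextlib import contextmanager


class ToyOutputPipeline:
    """Hypothetical stand-in for the output tile pipeline (not the Mojo type)."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.next_stage = 0
        self.arrivals = []  # stages for which the modeled mma_arrive fired

    @contextmanager
    def per_k_stage(self):
        # Acquire an output stage (assumed round-robin; no blocking modeled).
        stage = self.next_stage
        self.next_stage = (stage + 1) % self.num_stages
        try:
            yield stage  # the body issues the MMA into this stage
        finally:
            # Model of the mma_arrive signal sent when the block exits.
            self.arrivals.append(stage)


def run_k_loop(num_iters, num_stages=2):
    pipeline = ToyOutputPipeline(num_stages)
    used = []
    for _ in range(num_iters):
        with pipeline.per_k_stage() as stage:
            used.append(stage)
    return used, pipeline.arrivals
```

Because the signal sits in the exit path, the loop body cannot forget to send it, which is the main reason to expose the stage as a context manager rather than as a pair of acquire and arrive calls.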