
Mojo struct

MmaWarp

struct MmaWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]

Unified linear type for MMA warp TMEM lifecycle.

Works both as a linear type (direct use) and within context managers.

Lifecycle:

  1. Created via create() - allocates TMEM, signals sync barrier
  2. Use output_pipeline or acquire_k_stage_linear() for MMA stages
  3. Must call release() to wait for epilogue and deallocate (compiler-enforced)
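
The lifecycle above can be modeled outside Mojo. The following Python sketch is a toy model, not the Mojo implementation: `ToyMmaWarp`, its fields, and the error messages are hypothetical. It checks at run time what the Mojo compiler enforces at compile time, namely that `release()` ends the handle's life and nothing may use the handle afterwards.

```python
class ToyMmaWarp:
    """Toy run-time model of the create -> stages -> release lifecycle."""

    def __init__(self):
        self.stages_acquired = 0
        self.released = False

    @classmethod
    def create(cls):
        # Stands in for TMEM allocation and the sync-barrier signal.
        return cls()

    def acquire_k_stage_linear(self):
        # A released handle must not hand out further stages.
        if self.released:
            raise RuntimeError("MmaWarp used after release()")
        self.stages_acquired += 1
        return self.stages_acquired - 1

    def release(self):
        # Stands in for waiting on the epilogue and deallocating TMEM.
        if self.released:
            raise RuntimeError("release() called twice")
        self.released = True


warp = ToyMmaWarp.create()
tiles = [warp.acquire_k_stage_linear() for _ in range(3)]
warp.release()

try:
    warp.acquire_k_stage_linear()
    used_after_release = True
except RuntimeError:
    used_after_release = False
```

In Mojo the use-after-release and missing-release cases are rejected by the compiler, so no run-time flag like `released` is needed there.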

Parameters

  • num_accum_stages (Int): Number of accumulator pipeline stages.
  • stage_stride_cols (Int): TMEM column stride between stages.
  • cta_group (Int): CTA group size (1 or 2).
  • mma_threads (Int): Number of MMA threads.
  • epilogue_threads (Int): Number of epilogue threads.
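
To make the roles of num_accum_stages and stage_stride_cols concrete, here is a small Python sketch of the column arithmetic (a toy, not the Mojo implementation). It assumes that stages are reused round-robin and that stage s starts s * stage_stride_cols columns past a base column. The function name, `tile_index`, and `base_col` are hypothetical and do not appear in the API.

```python
def stage_tmem_col_offset(tile_index, num_accum_stages, stage_stride_cols, base_col=0):
    """Return the assumed starting TMEM column for the given output tile."""
    # Assumption: accumulator stages are handed out round-robin.
    stage = tile_index % num_accum_stages
    # Assumption: consecutive stages sit stage_stride_cols columns apart.
    return base_col + stage * stage_stride_cols


# Two stages, 128 columns apart: tiles alternate between the two column ranges.
offsets = [
    stage_tmem_col_offset(i, num_accum_stages=2, stage_stride_cols=128)
    for i in range(4)
]
```

Under these assumptions, more stages let the MMA warp run further ahead of the epilogue at the cost of more TMEM columns.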

Fields

  • tmem (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem)
  • output_pipeline (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline)
  • dealloc_barrier (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc)

Implemented traits

AnyType

comptime members

Dealloc

comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc

Pipeline

comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline

Sync

comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync

Tmem

comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem

Methods

__init__

__init__(out self, tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])

create

static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self

Create MMA warp with TMEM allocation.

Allocates TMEM and signals the warp group sync barrier.

Args:

  • tmem_addr_storage (SMemArray): Shared storage for TMEM address communication.
  • accum_barriers (SMemArray): Barrier array for accumulator pipeline.
  • dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
  • mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.

Returns:

Self: Fully initialized MmaWarp that must be released.

per_k_stage

per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]

Get a per-K-stage context manager (kept for compatibility).

Prefer acquire_k_stage_linear() for flat code structure.

Returns:

MmaKStage

acquire_k_stage_linear

acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]

Acquire a per-K stage using linear types.

Waits for the epilogue to free the stage, then returns a linear handle.

Usage:

    var stage = mma_handle.acquire_k_stage_linear()
    mma_op.mma(a, b, stage.tmem_offset())
    mma_op.commit(stage.mbar())
    stage^.release()

Returns:

MmaStage

release

release(deinit self)

Wait for epilogue and deallocate TMEM.

This is the only way to destroy this linear type.
