Mojo struct
MmaWarp
struct MmaWarp[num_accum_stages: Int, stage_stride_cols: Int, cta_group: Int, mma_threads: Int, epilogue_threads: Int]
Unified linear type that manages the TMEM (Tensor Memory) lifecycle of the MMA warp.
Works both as a linear type used directly and inside context managers.
Lifecycle:
- Created via create(): allocates TMEM and signals the sync barrier.
- Use output_pipeline or acquire_k_stage_linear() for MMA stages.
- Must call release() to wait for the epilogue and deallocate TMEM (compiler-enforced).
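Mojo enforces this lifecycle at compile time: an MmaWarp that is never passed to release() is rejected. As a rough illustration only, the Python sketch below models the same ordering with runtime checks. MmaWarpModel, use_stage and LinearHandleError are hypothetical names, not part of the Mojo API, and the log strings just record the order of operations from the lifecycle list above.

```python
class LinearHandleError(RuntimeError):
    """Raised when the modeled MmaWarp lifecycle is violated."""


class MmaWarpModel:
    """Hypothetical runtime model of the MmaWarp lifecycle (not the Mojo API)."""

    def __init__(self):
        self.log = []
        self.released = False

    @classmethod
    def create(cls):
        # Mirrors create(): allocate TMEM, then signal the sync barrier.
        warp = cls()
        warp.log += ["alloc_tmem", "signal_sync"]
        return warp

    def use_stage(self, k):
        # Stands in for MMA work done through output_pipeline or a K stage.
        if self.released:
            raise LinearHandleError("MmaWarp used after release()")
        self.log.append(f"mma_stage_{k}")

    def release(self):
        # Mirrors release(): wait for the epilogue, then deallocate TMEM.
        # In Mojo, a second release (or a missing one) is a compile error.
        if self.released:
            raise LinearHandleError("release() called twice")
        self.log += ["wait_epilogue", "dealloc_tmem"]
        self.released = True


warp = MmaWarpModel.create()
for k in range(2):
    warp.use_stage(k)
warp.release()
```

The runtime checks are only a stand-in; the point of the Mojo linear type is that these mistakes never reach execution.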
Parameters
- num_accum_stages (Int): Number of accumulator pipeline stages.
- stage_stride_cols (Int): TMEM column stride between stages.
- cta_group (Int): CTA group size (1 or 2).
- mma_threads (Int): Number of MMA threads.
- epilogue_threads (Int): Number of epilogue threads.
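A minimal sketch of how num_accum_stages and stage_stride_cols might combine, under several assumptions: stage i begins at column base + i * stage_stride_cols, each stage occupies stage_stride_cols columns, output tiles cycle through the stages round-robin, and TMEM has 512 columns (the Blackwell per-SM figure). stage_offsets and stage_for_tile are hypothetical helpers; the Mojo library computes the real layout and may differ.

```python
TMEM_COLUMNS = 512  # assumption: 512 32-bit TMEM columns per SM


def stage_offsets(num_accum_stages, stage_stride_cols, base_col=0):
    """Return the assumed first TMEM column of each accumulator stage."""
    if num_accum_stages < 1 or stage_stride_cols < 1:
        raise ValueError("stage count and stride must be positive")
    offsets = [base_col + i * stage_stride_cols for i in range(num_accum_stages)]
    # Assumption: the last stage still needs stage_stride_cols columns.
    if offsets[-1] + stage_stride_cols > TMEM_COLUMNS:
        raise ValueError("accumulator stages do not fit in TMEM")
    return offsets


def stage_for_tile(tile_index, num_accum_stages):
    # Assumption: consecutive output tiles reuse the stages round-robin.
    return tile_index % num_accum_stages
```

Under these assumptions, two stages with a 256-column stride exactly fill TMEM, while a third stage would not fit.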
Fields
- tmem (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem): TMEM allocation owned by this warp.
- output_pipeline (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline): Accumulator pipeline shared with the epilogue.
- dealloc_barrier (MmaWarp[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc): Barrier used to wait for the epilogue before deallocating TMEM.
Implemented traits
comptime members
Dealloc
comptime Dealloc = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[num_accum_stages, stage_stride_cols, cta_group, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(out self, tmem: TmemAllocation[cta_group], output_pipeline: OutputTilePipeline[num_accum_stages, stage_stride_cols, cta_group], dealloc_barrier: TmemDeallocBarrier[cta_group])
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers: SMemArray[SharedMemBarrier, (num_accum_stages * 2)], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create the MMA warp with a TMEM allocation.
Allocates TMEM and signals the warp-group sync barrier.
Args:
- tmem_addr_storage (SMemArray): Shared storage for communicating the TMEM address.
- accum_barriers (SMemArray): Barrier array for the accumulator pipeline (num_accum_stages * 2 barriers).
- dealloc_mbar (SMemArray): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized MmaWarp that must be released.
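The argument shapes of create() hint at how the barriers are organized. accum_barriers holds num_accum_stages * 2 entries, which fits one "full" barrier (accumulator ready) and one "empty" barrier (stage free) per stage, and mma_complete_mask is a UInt16, the same width as the 16-bit ctaMask operand of PTX tcgen05.commit multicast. The Python sketch below assumes both readings; the first-half/second-half barrier split and the bit-per-CTA-rank mask are illustrative assumptions, not the documented layout.

```python
def accum_barrier_indices(stage, num_accum_stages):
    """Hypothetical layout: full barriers in [0, N), empty barriers in [N, 2N)."""
    if not 0 <= stage < num_accum_stages:
        raise ValueError("stage out of range")
    full = stage                      # MMA -> epilogue: accumulator ready
    empty = num_accum_stages + stage  # epilogue -> MMA: stage free again
    return full, empty


def cta_mask(ranks):
    """Build a 16-bit multicast mask, assuming bit r selects CTA rank r."""
    mask = 0
    for r in ranks:
        if not 0 <= r < 16:
            raise ValueError(f"CTA rank {r} does not fit in a 16-bit mask")
        mask |= 1 << r
    return mask


# With cta_group == 2, a CTA pair at ranks 0 and 1 might be signaled together.
pair_mask = cta_mask([0, 1])
```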
per_k_stage
per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]
Get a per-K-stage context manager (kept for compatibility).
Prefer acquire_k_stage_linear() for a flat code structure.
Returns:
MmaKStage
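Judging from the descriptions, per_k_stage() and acquire_k_stage_linear() differ in code shape: a context manager releases the stage when its block exits, which nests the MMA code inside that block, while a linear handle keeps the code flat and makes the release an explicit call. The Python sketch below contrasts the two shapes with a hypothetical event log; k_stage and LinearStage do not model the Mojo types themselves.

```python
from contextlib import contextmanager


@contextmanager
def k_stage(events, stage):
    # Context-manager style: the stage is released when the with-block exits.
    events.append(("acquire", stage))
    try:
        yield stage
    finally:
        events.append(("release", stage))


class LinearStage:
    # Linear style: the caller releases the stage with an explicit call.
    def __init__(self, events, stage):
        self.events = events
        self.stage = stage
        events.append(("acquire", stage))

    def release(self):
        self.events.append(("release", self.stage))


nested = []
with k_stage(nested, 0) as s:
    nested.append(("mma", s))

flat = []
stage = LinearStage(flat, 0)
flat.append(("mma", stage.stage))
stage.release()
```

Both produce the same acquire, MMA, release sequence; only the structure of the calling code changes.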
acquire_k_stage_linear
acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), num_accum_stages, stage_stride_cols, cta_group]
Acquire a per-K stage as a linear type.
Waits for the epilogue to free the stage, then returns a linear handle.
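Usage:
    var stage = mma_handle.acquire_k_stage_linear()  # waits for the epilogue to free a stage
    mma_op.mma(a, b, stage.tmem_offset())            # accumulate into this stage's TMEM columns
    mma_op.commit(stage.mbar())                      # signal completion on the stage barrier
    stage^.release()                                 # consume the linear handle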
Returns:
MmaStage
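The wait in acquire_k_stage_linear() behaves like the producer side of a ring buffer: the MMA warp reuses num_accum_stages accumulator slots and must not overwrite one the epilogue is still reading. The single-threaded Python model below sketches that bookkeeping, assuming the common mbarrier convention of taking the slot as k % num_stages and a parity (phase) bit from k // num_stages. StageRing and its methods are hypothetical, and a real acquire blocks on a barrier instead of returning None.

```python
class StageRing:
    """Hypothetical model of an accumulator ring with num_stages slots."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.free = [True] * num_stages
        self.produced = 0  # stages acquired by the MMA side so far
        self.consumed = 0  # stages released by the epilogue so far

    def acquire(self):
        slot = self.produced % self.num_stages
        phase = (self.produced // self.num_stages) & 1
        if not self.free[slot]:
            # On the GPU this would be a barrier wait; the model reports it.
            return None
        self.free[slot] = False
        self.produced += 1
        return slot, phase

    def epilogue_release(self):
        slot = self.consumed % self.num_stages
        if self.free[slot]:
            raise RuntimeError("epilogue released a stage that was never acquired")
        self.free[slot] = True
        self.consumed += 1
        return slot
```

With two stages, a third acquire stalls until the epilogue frees slot 0, and the reacquired slot carries the flipped phase bit.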
release
release(deinit self)
Wait for the epilogue, then deallocate TMEM.
This is the only way to destroy this linear type.
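release() must block until the epilogue has finished with TMEM; presumably, deallocating earlier would race with the readout of the last tile. The Python sketch below models that ordering, with threading.Barrier standing in for the dealloc barrier. run_release_model, the thread count and the log strings are illustrative; the real synchronization uses shared-memory mbarriers, not OS threads.

```python
import threading


def run_release_model(num_epilogue_threads):
    """Model release(): wait for every epilogue thread, then free TMEM."""
    log = []
    lock = threading.Lock()
    # Epilogue threads plus the MMA side all meet at the dealloc barrier.
    dealloc_barrier = threading.Barrier(num_epilogue_threads + 1)

    def epilogue(tid):
        with lock:
            log.append(f"epilogue_{tid}_done")
        dealloc_barrier.wait()  # signal that this thread is done with TMEM

    workers = [
        threading.Thread(target=epilogue, args=(t,))
        for t in range(num_epilogue_threads)
    ]
    for w in workers:
        w.start()
    dealloc_barrier.wait()  # release(): block until the epilogue has arrived
    with lock:
        log.append("dealloc_tmem")
    for w in workers:
        w.join()
    return log
```

Because every epilogue thread logs before it reaches the barrier, "dealloc_tmem" is always the last entry.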