Mojo struct
MmaWarp
struct MmaWarp[opc: OutputPipelineConfig, mma_threads: Int, epilogue_threads: Int]
Unified linear type for MMA warp TMEM lifecycle.
Works as both a linear type (direct use) and within context managers.
Lifecycle:
- Created via create(): allocates TMEM, signals sync barrier
- Use output_pipeline or acquire_k_stage_linear() for MMA stages
- Must call release() to wait for epilogue and deallocate (compiler-enforced)
Parameters
- opc (OutputPipelineConfig): Output pipeline configuration (stages, stride, cta_group).
- mma_threads (Int): Number of MMA threads.
- epilogue_threads (Int): Number of epilogue threads.
Fields
- tmem (MmaWarp[opc, mma_threads, epilogue_threads].Tmem)
- output_pipeline (MmaWarp[opc, mma_threads, epilogue_threads].Pipeline)
- dealloc_barrier (MmaWarp[opc, mma_threads, epilogue_threads].Dealloc)
Implemented traits
comptime members
Dealloc
comptime Dealloc = _WarpContextTypes[opc, mma_threads, epilogue_threads].Dealloc
Pipeline
comptime Pipeline = _WarpContextTypes[opc, mma_threads, epilogue_threads].Pipeline
Sync
comptime Sync = _WarpContextTypes[opc, mma_threads, epilogue_threads].Sync
Tmem
comptime Tmem = _WarpContextTypes[opc, mma_threads, epilogue_threads].Tmem
Methods
__init__
__init__(out self, tmem: TmemAllocation[opc.cta_group], output_pipeline: OutputTilePipeline[opc], dealloc_barrier: TmemDeallocBarrier[opc.cta_group])
create
static create(tmem_addr_storage: SMemArray[UInt32, 1], accum_barriers_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], dealloc_mbar: SMemArray[SharedMemBarrier, 1], mma_complete_mask: UInt16) -> Self
Create MMA warp with TMEM allocation.
Allocates TMEM and signals the warp group sync barrier.
Args:
- tmem_addr_storage (SMemArray[UInt32, 1]): Shared storage for TMEM address communication.
- accum_barriers_ptr (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to accumulator pipeline barriers.
- dealloc_mbar (SMemArray[SharedMemBarrier, 1]): Barrier for TMEM deallocation synchronization.
- mma_complete_mask (UInt16): Multicast mask for MMA completion signaling.
Returns:
Self: Fully initialized MmaWarp that must be released.
per_k_stage
per_k_stage(mut self) -> MmaKStage[origin_of(self.output_pipeline), opc]
Get per-K stage context manager (for compatibility).
Prefer acquire_k_stage_linear() for flat code structure.
Returns:
MmaKStage[origin_of(self.output_pipeline), opc]
acquire_k_stage_linear
acquire_k_stage_linear(mut self) -> MmaStage[origin_of(self.output_pipeline), opc]
Acquire a per-K stage using linear types.
Waits for the epilogue to free the stage, then returns a linear handle.
Usage:
    var stage = mma_handle.acquire_k_stage_linear()
    mma_op.mma(a, b, stage.tmem_offset())
    mma_op.commit(stage.mbar())
    stage^.release()
Returns:
MmaStage[origin_of(self.output_pipeline), opc]
release
release(deinit self)
Wait for epilogue and deallocate TMEM.
This is the only way to destroy this linear type.
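The full lifecycle (create, acquire stages, release) might look roughly like the sketch below. It is a fragment, not a complete kernel: it assumes it runs inside the MMA warp, that the arguments to create() have already been set up by the caller, and that mma_op, a, and b are the same placeholders used in the acquire_k_stage_linear() usage note. The loop bound num_stages_to_issue is a hypothetical name introduced only for illustration.

```mojo
# Sketch only, not a complete kernel. Assumed to run inside the MMA warp,
# with the create() arguments, mma_op, a, and b already set up by the caller.
# num_stages_to_issue is a hypothetical loop bound for illustration.
var mma_handle = MmaWarp[opc, mma_threads, epilogue_threads].create(
    tmem_addr_storage, accum_barriers_ptr, dealloc_mbar, mma_complete_mask
)

for _ in range(num_stages_to_issue):
    # Waits until the epilogue has freed a stage, then returns a linear handle.
    var stage = mma_handle.acquire_k_stage_linear()
    mma_op.mma(a, b, stage.tmem_offset())
    mma_op.commit(stage.mbar())
    stage^.release()  # Each stage handle is consumed explicitly.

# Waits for the epilogue and deallocates TMEM. Because MmaWarp is a linear
# type, the compiler enforces that this call is made.
mma_handle^.release()
```

Keeping the stage handle and the warp handle as separate linear values means the compiler checks both scopes: every acquired stage has to be released, and the warp itself has to be released before it goes out of scope.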