Mojo struct
MmaStage
struct MmaStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]
Unified linear type handle for an MMA stage in the output pipeline.
Works as both a linear type (direct use) and within context managers.
Lifecycle:
- Created via output_pipeline.acquire_mma_linear(), which waits for the epilogue
- Use tmem(), tmem_offset(), mbar() for MMA operations
- Must call release() to signal mma_arrive and advance (compiler-enforced)
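The lifecycle above can be sketched roughly as follows. This is an illustrative fragment, not a complete kernel: it assumes a variable named output_pipeline (an OutputTilePipeline built elsewhere by the surrounding kernel code) is in scope, and the MMA instructions themselves are elided. Only the method names listed on this page are used.

```mojo
# Sketch only: `output_pipeline` is assumed to be an OutputTilePipeline
# constructed elsewhere in the kernel (not shown here).
var stage = output_pipeline.acquire_mma_linear()  # waits for the epilogue

var tmem = stage.tmem()  # TMEM stage handle
var mbar = stage.mbar()  # producer barrier for the MMA commit

# ... issue the MMA operations using `tmem` and `mbar` ...

# Consume the handle: signals mma_arrive and advances to the next stage.
# Because MmaStage is a linear type, omitting this call should be rejected
# by the compiler rather than silently destroying the handle.
stage^.release()
```

The explicit transfer (`stage^`) is written out here to make the hand-off of ownership to release() visible.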
Parameters
- origin (MutOrigin): Origin of the pipeline reference.
- num_stages (Int): Number of pipeline stages.
- stage_stride_cols (Int): TMEM column stride between stages.
- cta_group (Int): CTA group size (1 or 2).
Fields
- pipeline_ptr (Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]):
Implemented traits
comptime members
Stage
comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]
TilePipelineType
comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]
Methods
__init__
__init__(out self, pipeline_ptr: Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin], stage: OutputStage[num_stages, stage_stride_cols, cta_group])
tmem
tmem(self) -> MmaStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem
Get the TMEM stage handle.
Returns:
MmaStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem
tmem_offset
index
mbar
mbar(self) -> MbarPtr
Get the producer barrier for MMA commit.
Returns:
MbarPtr
release
release(deinit self)
Signal MMA completion and advance to next stage.
This is the only way to destroy this linear type. Internally calls mma_arrive (1-SM) or mma_arrive_multicast (2-SM).