Mojo struct

MmaStage

struct MmaStage[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]

Unified linear-type handle for an MMA stage in the output pipeline.

Works as both a linear type (direct use) and within context managers.

Lifecycle:

  1. Create the handle via output_pipeline.acquire_mma_linear(), which waits for the epilogue.
  2. Use tmem(), tmem_offset(), and mbar() for MMA operations.
  3. Call release() to signal mma_arrive and advance to the next stage. The compiler enforces this call.
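
The lifecycle above can be sketched as the fragment below. It is not a standalone program: it assumes it runs inside an MMA warp of a GPU kernel, that `output_pipeline` is an existing OutputTilePipeline instance, and that the MMA issue step (left as a comment) is supplied by the caller. Only the method names come from this page. The `^` transfer on the final call is one way to hand ownership to a `deinit self` method and should be treated as an assumption about the calling convention.

```mojo
# Sketch only. `output_pipeline` is assumed to be an OutputTilePipeline
# already constructed by the surrounding kernel code.

# 1. Acquire the stage handle. Per the lifecycle, this waits for the epilogue.
var stage = output_pipeline.acquire_mma_linear()

# 2. Query what the MMA operations need from the stage.
var acc_offset = stage.tmem_offset()  # TMEM offset for the MMA accumulator
var commit_bar = stage.mbar()         # producer barrier for the MMA commit

# ... issue MMA operations here using acc_offset and commit_bar
#     (hypothetical caller code, not part of this API) ...

# 3. Release the handle. This consumes `stage`, so it cannot be used afterwards.
stage^.release()
```

Because the handle is a linear type, omitting the final release() call is expected to be a compile-time error rather than a silent runtime hang.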

Parameters

  • origin (MutOrigin): Origin of the pipeline reference.
  • num_stages (Int): Number of pipeline stages.
  • stage_stride_cols (Int): TMEM column stride between stages.
  • cta_group (Int): CTA group size (1 or 2).

Fields

  • pipeline_ptr (Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]):

Implemented traits

AnyType

comptime members

Stage

comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]

TilePipelineType

comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]

Methods

__init__

__init__(out self, pipeline_ptr: Pointer[MmaStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin], stage: OutputStage[num_stages, stage_stride_cols, cta_group])

tmem

tmem(self) -> MmaStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem

Get the TMEM stage handle.

Returns:

MmaStage[origin, num_stages, stage_stride_cols, cta_group].Stage.Tmem

tmem_offset

tmem_offset(self) -> Int

Get the TMEM offset for MMA accumulator.

Returns:

Int

index

index(self) -> UInt32

Get the current stage index.

Returns:

UInt32

mbar

mbar(self) -> MbarPtr

Get the producer barrier for MMA commit.

Returns:

MbarPtr

release

release(deinit self)

Signal MMA completion and advance to next stage.

Calling release() is the only way to destroy a value of this linear type. Internally, it calls mma_arrive (1-SM) or mma_arrive_multicast (2-SM).