
Mojo struct

MmaKStage

@register_passable(trivial) struct MmaKStage[_mlir_origin: LITMutOrigin, //, origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]

Per-K-stage context for the MMA warp in blockwise FP8.

  • __enter__: Acquires the stage and waits for the epilogue to release the previous stage.
  • __exit__: Signals mma_arrive to notify the epilogue and advances the producer stage.
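The enter/exit contract can be modeled as a context manager. The following is a Python model, not the Mojo API. `MockPipeline`, its `acquire`/`mma_arrive`/`advance`/`epilogue_release` methods, and the event log are hypothetical stand-ins for `OutputTilePipeline`. Because the model has no concurrent epilogue, the blocking wait is reduced to a check that raises if the slot has not been released.

```python
class MockPipeline:
    """Hypothetical stand-in for OutputTilePipeline: a ring of num_stages slots."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.producer_stage = 0               # next slot the MMA warp will fill
        self.released = [True] * num_stages   # True once the epilogue has drained a slot
        self.events = []                      # log of protocol steps, for inspection

    def acquire(self):
        idx = self.producer_stage
        # A real pipeline would block on a barrier here; the model only checks.
        if not self.released[idx]:
            raise RuntimeError(f"stage {idx} still held by the epilogue")
        self.released[idx] = False
        self.events.append(("acquire", idx))
        return idx

    def mma_arrive(self, idx):
        self.events.append(("mma_arrive", idx))

    def advance(self):
        self.producer_stage = (self.producer_stage + 1) % self.num_stages

    def epilogue_release(self, idx):
        self.released[idx] = True
        self.events.append(("release", idx))


class MmaKStageModel:
    """Hypothetical Python analogue of MmaKStage's __enter__/__exit__ pair."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.stage = None

    def __enter__(self):
        # Acquire the slot (the model's version of waiting for the epilogue).
        self.stage = self.pipeline.acquire()
        return self.stage

    def __exit__(self, exc_type, exc, tb):
        self.pipeline.mma_arrive(self.stage)  # notify the epilogue
        self.pipeline.advance()               # move to the next producer stage
        return False                          # do not swallow exceptions


pipe = MockPipeline(num_stages=2)
for k in range(3):
    with MmaKStageModel(pipe) as stage:
        pipe.events.append(("mma", stage))  # placeholder for issuing the MMA
    # Model the epilogue draining the stage immediately afterwards.
    pipe.epilogue_release(stage)
```

With two stages and three K iterations, the log shows slot 0 being reused on the third iteration, but only after the epilogue released it. Entering a slot that is still held raises in this model, which stands in for the blocking wait.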

Fields

  • pipeline_ptr (Pointer[MmaKStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]):
  • stage (MmaKStage[origin, num_stages, stage_stride_cols, cta_group].Stage):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Stage

comptime Stage = OutputStage[num_stages, stage_stride_cols, cta_group]

TilePipelineType

comptime TilePipelineType = OutputTilePipeline[num_stages, stage_stride_cols, cta_group]
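Multi-stage pipelines commonly track the producer position as a ring-buffer index plus a phase bit that flips each time the index wraps past `num_stages`. CUTLASS's `PipelineState` uses this scheme. It is an assumption that `OutputTilePipeline` stores its state the same way. The column-offset formula `index * stage_stride_cols` is also a guess at what the `stage_stride_cols` parameter controls. A minimal sketch under those assumptions (`StageState` is hypothetical):

```python
from dataclasses import dataclass


@dataclass
class StageState:
    """Hypothetical ring-buffer position: which slot, and which pass over the ring."""
    num_stages: int
    stage_stride_cols: int
    index: int = 0
    phase: int = 0

    def advance(self):
        # Move to the next slot; flip the phase bit whenever the ring wraps.
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1

    def col_offset(self):
        # Assumed layout: consecutive stage slots sit stage_stride_cols columns apart.
        return self.index * self.stage_stride_cols


state = StageState(num_stages=2, stage_stride_cols=128)
trace = []
for _ in range(5):
    trace.append((state.index, state.phase, state.col_offset()))
    state.advance()
```

The phase bit lets a waiter tell the current pass over a slot apart from the previous pass over that same slot, because the index alone repeats every `num_stages` iterations.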

Methods

__init__

__init__(pipeline_ptr: Pointer[MmaKStage[origin, num_stages, stage_stride_cols, cta_group].TilePipelineType, origin]) -> Self

__enter__

__enter__(mut self) -> MmaKStage[origin, num_stages, stage_stride_cols, cta_group].Stage

Returns:

MmaKStage[origin, num_stages, stage_stride_cols, cta_group].Stage

__exit__

__exit__(mut self)
