
Mojo struct

EpilogueKStage

@register_passable(trivial) struct EpilogueKStage[num_output_stages: Int, stage_stride_cols: Int, cta_group: Int, num_input_stages: Int]

Per-K stage for the epilogue warp in blockwise FP8.

Returned by EpilogueKContext.__enter__(). It bundles:

  • output_stage: TMEM access (the offset for reading MMA results)
  • input_stage_index: index of the current A-scales stage
  • input_pipeline: the pipeline used to signal A-scales consumption
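As a rough mental model, the bundle behaves like a small immutable record handed out by a context manager once per K iteration. The Python sketch below is not the Mojo API: EpilogueKStageModel and EpilogueKContextModel are hypothetical names, the TMEM offset is a plain integer, and the round-robin choice of input_stage_index (k modulo num_input_stages) is an assumption made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpilogueKStageModel:
    """Hypothetical host-side stand-in for EpilogueKStage: plain values, cheap to copy."""
    output_stage_offset: int  # stands in for output_stage (TMEM offset of MMA results)
    input_stage_index: int    # which A-scales stage this K iteration reads
    input_pipeline: object    # stands in for the pipeline used to signal consumption


class EpilogueKContextModel:
    """Hypothetical context whose __enter__ returns the per-K bundle."""

    def __init__(self, k_iter, num_input_stages, output_stage_offset, pipeline):
        self.k_iter = k_iter
        self.num_input_stages = num_input_stages
        self.output_stage_offset = output_stage_offset
        self.pipeline = pipeline

    def __enter__(self):
        # Assumption: A-scales stages are consumed round-robin, one per K iteration.
        index = self.k_iter % self.num_input_stages
        return EpilogueKStageModel(self.output_stage_offset, index, self.pipeline)

    def __exit__(self, exc_type, exc, tb):
        return False  # do not suppress exceptions


seen = []
for k in range(5):
    with EpilogueKContextModel(k, num_input_stages=2, output_stage_offset=0, pipeline=None) as stage:
        seen.append(stage.input_stage_index)
print(seen)  # [0, 1, 0, 1, 0]
```

The frozen dataclass loosely mirrors the trivially copyable, register-passable nature of the real struct: the bundle is a value, not a handle to shared mutable state.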

Fields

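  • output_stage (EpilogueKStage[num_output_stages, stage_stride_cols, cta_group, num_input_stages].OutputStageType): TMEM access for this stage (offset for reading MMA results).
  • input_stage_index (UInt32): Index of the current A-scales stage.
  • input_pipeline (EpilogueKStage[num_output_stages, stage_stride_cols, cta_group, num_input_stages].InputPipelineType): Pipeline used to signal A-scales consumption.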

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

InputPipelineType

comptime InputPipelineType = ProducerConsumerPipeline[num_input_stages]
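Multi-stage producer/consumer pipelines on GPUs are commonly tracked with an (index, phase) pair over a ring of num_input_stages slots. Whether ProducerConsumerPipeline uses exactly this scheme is an assumption; the hypothetical PipelineStateModel below only illustrates the general pattern of wrapping the index and flipping a phase bit on each wrap.

```python
class PipelineStateModel:
    """Hypothetical (index, phase) tracker for an N-stage ring of buffers."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.index = 0
        self.phase = 0

    def step(self):
        # Advance to the next slot; flip the phase bit each time the ring wraps,
        # so a waiter can tell a fresh fill of slot i from the previous one.
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


state = PipelineStateModel(num_stages=3)
trace = []
for _ in range(7):
    trace.append((state.index, state.phase))
    state.step()
print(trace)  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```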

OutputStageType

comptime OutputStageType = OutputStage[num_output_stages, stage_stride_cols, cta_group]
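The OutputStage parameters suggest that the num_output_stages TMEM stages are spaced stage_stride_cols columns apart. The sketch below assumes exactly that (stages laid out back to back starting at column 0) and ignores cta_group; the helper tmem_stage_offset is hypothetical and not part of the library.

```python
def tmem_stage_offset(stage_index, num_output_stages, stage_stride_cols):
    """Hypothetical: TMEM column offset of an output stage, assuming stages are
    laid out back to back, stage_stride_cols columns apart (cta_group not modeled)."""
    if not 0 <= stage_index < num_output_stages:
        raise ValueError("stage_index out of range")
    return stage_index * stage_stride_cols


offsets = [tmem_stage_offset(i, num_output_stages=2, stage_stride_cols=128) for i in range(2)]
print(offsets)  # [0, 128]
```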

Methods

__init__

__init__(output_stage: OutputStage[num_output_stages, stage_stride_cols, cta_group], input_stage_index: UInt32, input_pipeline: ProducerConsumerPipeline[num_input_stages]) -> Self

arrive_input

arrive_input(self)

Arrive on the input pipeline's consumer barrier.

Use it inside a lane guard, for example:

    if lane_id() < cluster_size:
        epi_stage.arrive_input()
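The lane guard matters because every lane of the epilogue warp executes the same code. The Python model below assumes the consumer barrier expects exactly cluster_size arrivals per round (an assumption, not taken from the library); CountingBarrierModel and epilogue_warp_arrive are hypothetical names. With the guard, one round of arrivals completes the barrier exactly once.

```python
WARP_SIZE = 32


class CountingBarrierModel:
    """Hypothetical stand-in for a consumer barrier that completes after `expected` arrivals."""

    def __init__(self, expected):
        self.expected = expected
        self.pending = expected
        self.phase = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            # All expected arrivals seen: flip the phase and re-arm for the next round.
            self.phase ^= 1
            self.pending = self.expected


def epilogue_warp_arrive(barrier, cluster_size):
    # Every lane runs this; only the first `cluster_size` lanes arrive,
    # so the barrier sees exactly `cluster_size` arrivals per call.
    for lane_id in range(WARP_SIZE):
        if lane_id < cluster_size:
            barrier.arrive()


barrier = CountingBarrierModel(expected=2)
epilogue_warp_arrive(barrier, cluster_size=2)
print(barrier.phase, barrier.pending)  # 1 2
```

In this model, dropping the guard would deliver 32 arrivals and complete the barrier 16 times in a single round, which is why only a bounded set of lanes should call arrive_input().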
