
Mojo struct

EpilogueApplier

@register_passable(trivial) struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]

Apply element-wise epilogue operations on register fragments.

Computes global coordinates for each element and applies a lambda function. Handles different MMA layouts (A/B/D/F) and transpose modes.
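The coordinate-aware pattern described above can be modeled in plain Python. This is a sketch only, not the Mojo implementation: it assumes element i of a fragment sits at global position (row, col + i), models the compute lambda as a function of a (row, col) pair and a value, and treats transpose_c as swapping the reported coordinates. The real MMA fragment layouts (A/B/D/F) are more involved, and apply_epilogue and add_bias_relu are hypothetical names.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a compute lambda: receives global (row, col) and a value.
ComputeFn = Callable[[Tuple[int, int], float], float]


def apply_epilogue(frag: List[float], row: int, col: int,
                   fn: ComputeFn, transpose_c: bool = False) -> List[float]:
    """Apply fn to each element, assuming element i sits at (row, col + i)."""
    out = []
    for i, value in enumerate(frag):
        r, c = row, col + i          # assumed contiguous-columns layout
        if transpose_c:
            r, c = c, r              # report coordinates in the transposed output
        out.append(fn((r, c), value))
    return out


# Example epilogue: add a per-column bias, then apply ReLU.
bias = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]


def add_bias_relu(coords: Tuple[int, int], value: float) -> float:
    _, c = coords
    return max(value + bias[c], 0.0)


print(apply_epilogue([1.0, -5.0, 0.0, 2.0], row=3, col=2, fn=add_bias_relu))
# [2.5, 0.0, 2.5, 5.0]
```

Because the lambda receives global coordinates, column-dependent operations such as a bias add can be fused into the epilogue without a separate pass over the output.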

Template Parameters:

  • MMA_M (Int): MMA M dimension.
  • stageN (Int): Stage width in elements.
  • num_stages (Int): Number of output stages.
  • repeats (Int): Number of repetitions per load.
  • cta_group (Int): Number of CTAs cooperating (1 or 2).
  • transpose_c (Bool): Whether the output is transposed.

Fields

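  • coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords): Fragment coordinate helper of type FragmentCoords[stageN, repeats] (see Coords).
  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.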

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Coords

comptime Coords = FragmentCoords[stageN, repeats]

Methods

__init__

__init__(warp_id: UInt32, lane_id: UInt32) -> Self

Initialize the epilogue applier.

Args:

  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.
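A caller typically derives both constructor arguments from the thread's flat index within the CTA. The sketch below assumes a 32-thread warp (as on NVIDIA GPUs) and a one-dimensional thread block. warp_and_lane is a hypothetical helper, not part of this API.

```python
from typing import Tuple

WARP_SIZE = 32  # assumed NVIDIA warp width


def warp_and_lane(thread_idx: int) -> Tuple[int, int]:
    """Split a flat thread index within the CTA into (warp_id, lane_id)."""
    return thread_idx // WARP_SIZE, thread_idx % WARP_SIZE


print(warp_and_lane(0))    # (0, 0)
print(warp_and_lane(37))   # (1, 5)
print(warp_and_lane(127))  # (3, 31)
```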

compute_staged_coords

compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]

Compute staged row and column coordinates.

Args:

  • stage (UInt32): Current stage index.
  • c_row (UInt32): Base row coordinate.
  • c_col (UInt32): Base column coordinate.

Returns:

Tuple: Tuple of (staged_row, staged_col).
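One plausible reading of "staged" coordinates is that the output tile is written in num_stages chunks of stageN elements, so each stage advances the base coordinate by stageN. The Python model below encodes that reading. The exact formula, including the choice to shift the row rather than the column when transpose_c is set, is an assumption, not the documented behavior of compute_staged_coords.

```python
from typing import Tuple


def staged_coords_model(stage: int, c_row: int, c_col: int,
                        stage_n: int, transpose_c: bool = False) -> Tuple[int, int]:
    """Hypothetical model: stage s covers the s-th block of stage_n elements.

    Without transpose, the stage offset moves along columns; with transpose_c
    it is assumed to move along rows instead.
    """
    offset = stage * stage_n
    if transpose_c:
        return c_row + offset, c_col
    return c_row, c_col + offset


print(staged_coords_model(0, 128, 256, stage_n=32))                    # (128, 256)
print(staged_coords_model(2, 128, 256, stage_n=32))                    # (128, 320)
print(staged_coords_model(1, 128, 256, stage_n=32, transpose_c=True))  # (160, 256)
```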

apply_to_fragment

apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply epilogue lambda to a fragment.

Args:

  • frag (SIMD): Fragment to apply epilogue to (modified in place).
  • staged_row (UInt32): Staged row coordinate.
  • staged_col (UInt32): Staged column coordinate.
  • is_upper (Bool): True if this is the upper fragment, False if it is the lower fragment.
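The is_upper flag selects which rows a fragment covers. The Python sketch below assumes the lower fragment's rows are the upper fragment's rows shifted by a fixed lower_row_offset. That parameter, and the value 8 used in the example, are illustrative assumptions; the real offset depends on the MMA layout. The list is updated in place to mirror the mut argument.

```python
from typing import Callable, List, Tuple


def apply_to_fragment_model(frag: List[float], staged_row: int, staged_col: int,
                            is_upper: bool, lower_row_offset: int,
                            fn: Callable[[Tuple[int, int], float], float]) -> None:
    """Apply fn in place; lower-fragment rows are shifted by lower_row_offset (assumed)."""
    row = staged_row if is_upper else staged_row + lower_row_offset
    for i in range(len(frag)):
        # Element i is assumed to sit at column staged_col + i.
        frag[i] = fn((row, staged_col + i), frag[i])


def tag(coords: Tuple[int, int], value: float) -> float:
    # Encode the element's position in its value so the mapping is visible.
    return value + coords[0] * 1000 + coords[1]


upper = [0.0, 0.0]
lower = [0.0, 0.0]
apply_to_fragment_model(upper, 4, 16, True, 8, tag)
apply_to_fragment_model(lower, 4, 16, False, 8, tag)
print(upper)  # [4016.0, 4017.0]
print(lower)  # [12016.0, 12017.0]
```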

apply_to_both_fragments

apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]

Apply epilogue lambda to both upper and lower fragments.

This is the main entry point for the register-based epilogue and replaces the standalone register_epilogue function.

Args:

  • upper_frag (SIMD): Upper fragment to apply epilogue to.
  • lower_frag (SIMD): Lower fragment to apply epilogue to.
  • stage (UInt32): Current stage index.
  • c_row (UInt32): Base row coordinate.
  • c_col (UInt32): Base column coordinate.

Returns:

Tuple: Tuple of (modified upper_frag, modified lower_frag).
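Taken together, the method can be read as "compute staged coordinates once, then apply the lambda to the upper fragment and, if needed, the lower fragment." The Python model below follows that outline under the same assumptions as a simplified model: each stage advances the column by stage_n, the lower fragment is offset by a fixed number of rows, and is_lower_frag_required = False leaves the lower fragment untouched. All of these, and the name apply_both_model, are hypothetical; only the overall shape (two fragments in, a tuple of two fragments out) comes from the signature above.

```python
from typing import Callable, List, Tuple

Fn = Callable[[Tuple[int, int], float], float]


def apply_both_model(upper: List[float], lower: List[float], stage: int,
                     c_row: int, c_col: int, stage_n: int, lower_row_offset: int,
                     fn: Fn, is_lower_frag_required: bool = True
                     ) -> Tuple[List[float], List[float]]:
    # Step 1: staged coordinates (assumed: each stage advances the column by stage_n).
    row, col = c_row, c_col + stage * stage_n
    # Step 2: the upper fragment uses the staged row directly.
    new_upper = [fn((row, col + i), v) for i, v in enumerate(upper)]
    # Step 3: the lower fragment is shifted in rows (assumed) and skipped if not required.
    if is_lower_frag_required:
        new_lower = [fn((row + lower_row_offset, col + i), v) for i, v in enumerate(lower)]
    else:
        new_lower = list(lower)
    return new_upper, new_lower


def position(coords: Tuple[int, int], value: float) -> float:
    # Replace each value with an encoding of its (row, col) position.
    return float(coords[0] * 100 + coords[1])


print(apply_both_model([0.0, 0.0], [0.0, 0.0], stage=1, c_row=0, c_col=0,
                       stage_n=32, lower_row_offset=8, fn=position))
# ([32.0, 33.0], [832.0, 833.0])
```

Returning both fragments as a tuple mirrors the Mojo signature, which returns Tuple[SIMD, SIMD] rather than relying only on in-place mutation.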
