Mojo struct
EpilogueApplier
@register_passable(trivial)
struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]
Apply element-wise epilogue operations on register fragments.
Computes global coordinates for each element and applies a lambda function. Handles different MMA layouts (A/B/D/F) and transpose modes.
Template Parameters:
- MMA_M (Int): MMA M dimension.
- stageN (Int): Stage width in elements.
- num_stages (Int): Number of output stages.
- repeats (Int): Number of repetitions per load.
- cta_group (Int): Number of CTAs cooperating (1 or 2).
- transpose_c (Bool): Whether output is transposed.
Fields
- coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords)
- warp_id (UInt32)
- lane_id (UInt32)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Coords
comptime Coords = FragmentCoords[stageN, repeats]
Methods
__init__
__init__(warp_id: UInt32, lane_id: UInt32) -> Self
Initialize the epilogue applier.
Args:
- warp_id (UInt32)
- lane_id (UInt32)
compute_staged_coords
compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]
Compute staged row and column coordinates.
Args:
- stage (UInt32): Current stage index.
- c_row (UInt32): Base row coordinate.
- c_col (UInt32): Base column coordinate.
Returns:
Tuple: Tuple of (staged_row, staged_col).
apply_to_fragment
apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)
Apply epilogue lambda to a fragment.
Args:
- frag (SIMD)
- staged_row (UInt32)
- staged_col (UInt32)
- is_upper (Bool)
apply_to_both_fragments
apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]
Apply epilogue lambda to both upper and lower fragments.
This is the main entry point for the register-based epilogue and replaces the standalone register_epilogue function.
Args:
- upper_frag (SIMD): Upper fragment to apply epilogue to.
- lower_frag (SIMD): Lower fragment to apply epilogue to.
- stage (UInt32): Current stage index.
- c_row (UInt32): Base row coordinate.
- c_col (UInt32): Base column coordinate.
Returns:
Tuple: Tuple of (modified upper_frag, modified lower_frag).
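The idea of applying a coordinate-aware lambda to a register fragment can be modeled in a few lines of plain Python. This is only a conceptual sketch, not the Mojo API: `apply_to_fragment_model`, `ComputeFn`, and `add_column_bias` are hypothetical names, and the mapping from element index to global coordinate is an assumption (elements laid out contiguously along one row). The real mapping depends on the MMA layout (A/B/D/F) and is not described on this page.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for elementwise_compute_lambda_type:
# takes a global (row, col) coordinate and a value, returns the new value.
ComputeFn = Callable[[Tuple[int, int], float], float]


def apply_to_fragment_model(
    frag: List[float],
    staged_row: int,
    staged_col: int,
    compute_fn: ComputeFn,
) -> List[float]:
    """Apply compute_fn to each fragment element together with its coordinate.

    ASSUMPTION: element i sits at (staged_row, staged_col + i). The actual
    per-lane layout used by EpilogueApplier is not modeled here.
    """
    out = []
    for i, value in enumerate(frag):
        coord = (staged_row, staged_col + i)
        out.append(compute_fn(coord, value))
    return out


def add_column_bias(coord: Tuple[int, int], value: float) -> float:
    # Example epilogue: add a bias that depends only on the column index.
    bias = [0.5, 1.0, 1.5, 2.0]
    return value + bias[coord[1] % len(bias)]


result = apply_to_fragment_model(
    [1.0, 1.0, 1.0, 1.0], staged_row=3, staged_col=4, compute_fn=add_column_bias
)
print(result)
```

Because the lambda receives the coordinate as well as the value, position-dependent epilogues such as a per-column bias need no extra indexing logic in the caller.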