Mojo struct

EpilogueApplier

struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]

Apply element-wise epilogue lambda to register fragments.

Fields​

  • coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords)
  • warp_id (UInt32)
  • lane_id (UInt32)
  • M (UInt32)
  • N (UInt32)

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

Coords​

comptime Coords = FragmentCoords[stageN, repeats]

Methods​

__init__​

__init__(warp_id: UInt32, lane_id: UInt32, c_shape: Tuple[UInt32, UInt32]) -> Self

compute_staged_coords​

compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]

Compute global coordinates by adding the warp and stage offsets to (c_row, c_col). The offsets depend on the layout.

Returns:

Tuple[UInt32, UInt32]
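The page does not spell out the offset formula, so the following is only a rough Python model of the idea, not the Mojo implementation. It assumes each warp owns a band of rows and each stage a band of stageN columns. The helper name staged_coords and the rows_per_warp value are hypothetical, and the real mapping also varies with MMA_M, cta_group, and transpose_c.

```python
def staged_coords(stage, c_row, c_col, warp_id, stage_n, rows_per_warp=32):
    """Hypothetical model of a staged coordinate computation.

    Assumption: the warp offset moves down the rows and the stage offset
    moves across the columns. The real layout-dependent rules may differ.
    """
    row = c_row + warp_id * rows_per_warp  # warp offset (assumed band height)
    col = c_col + stage * stage_n          # stage offset (stage_n columns per stage)
    return row, col


# Tile origin (128, 256), warp 1, stage 2, 32 columns per stage.
print(staged_coords(stage=2, c_row=128, c_col=256, warp_id=1, stage_n=32))
# prints (160, 320)
```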

apply_to_fragment​

apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply the epilogue lambda to each fragment element, passing the element's global coordinates.

apply_to_both_fragments​

apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width], is_lower_frag_required: Bool](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

Apply epilogue to both fragments (main entry point).

Returns:

Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

apply_elementwise_epilogue_to_fragment​

apply_elementwise_epilogue_to_fragment[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None](self, frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply the elementwise epilogue lambda to each fragment element, passing the element's global coordinates.

Unlike apply_to_fragment, which uses a compute lambda that returns modified values, this method calls an elementwise epilogue lambda that returns None and stores its results directly to global memory.
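The difference between the two lambda styles can be modeled in plain Python. This is a sketch of the calling convention only, not the Mojo API: the helper names, the coordinate list, and the dictionary standing in for global memory are all hypothetical.

```python
def apply_compute(frag, coords, compute_fn):
    """Compute style: the lambda returns a new value, written back to the fragment."""
    for i, (row, col) in enumerate(coords):
        frag[i] = compute_fn((row, col), frag[i])


def apply_elementwise(frag, coords, store_fn):
    """Elementwise style: the lambda returns None and performs the store itself."""
    for (row, col), value in zip(coords, frag):
        store_fn((row, col), value)


# Hypothetical global coordinates held by one lane.
coords = [(0, 0), (0, 1), (8, 0), (8, 1)]
frag = [1.0, 2.0, 3.0, 4.0]

# Compute lambda: scale each value and add its column index.
apply_compute(frag, coords, lambda rc, v: v * 2.0 + rc[1])
print(frag)  # prints [2.0, 5.0, 6.0, 9.0]

# Elementwise lambda: a dict stands in for global memory.
global_mem = {}

def store(rc, v):
    global_mem[rc] = v

apply_elementwise(frag, coords, store)
print(global_mem[(8, 1)])  # prints 9.0
```

In the compute style the caller still owns the fragment and decides where to write it. In the elementwise style the fragment is left untouched and the lambda is responsible for the store.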

apply_elementwise_epilogue_to_both_fragments​

apply_elementwise_epilogue_to_both_fragments[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None, is_lower_frag_required: Bool](self, upper_frag: SIMD[epilogue_dtype, frag_size], lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32)

Apply elementwise epilogue to both fragments.

Similar to apply_to_both_fragments, but takes an elementwise epilogue lambda (elementwise_lambda_fn), which writes directly to global memory and returns None.

add_residual_to_fragment​

add_residual_to_fragment[epilogue_dtype: DType, frag_size: Int, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], local_row: UInt32, local_col: UInt32, is_upper: Bool, src_ptr: UnsafePointer[Scalar[c_type], address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype])

Add beta * C to fragment elements by loading C from swizzled SMEM.

Uses the same per-lane coordinate mapping as apply_to_fragment, but instead of applying a lambda, loads source C values from SMEM at the matching swizzled addresses and adds beta * C to each element.
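The update frag[i] += beta * C[row, col] with a swizzled lookup can be modeled in plain Python. This is a sketch under stated assumptions, not the Mojo implementation: the XOR swizzle below is a hypothetical stand-in for the Swizzle parameter, the shared-memory tile is a flat list, and the coordinate list replaces the per-lane mapping.

```python
def xor_swizzle(offset):
    """Hypothetical swizzle: XOR the low 3 bits with the next 3 bits."""
    return offset ^ ((offset >> 3) & 0b111)


def add_residual(frag, coords, smem, smem_stride, swizzle, beta):
    """Add beta * C to each fragment element, reading C at swizzled offsets."""
    for i, (row, col) in enumerate(coords):
        offset = swizzle(row * smem_stride + col)
        frag[i] += beta * smem[offset]
    return frag


# Build an 8x8 source tile C with C[r][c] = r * 8 + c, stored swizzled.
stride = 8
smem = [0.0] * 64
for r in range(8):
    for c in range(8):
        smem[xor_swizzle(r * stride + c)] = float(r * stride + c)

frag = [1.0, 1.0]
coords = [(1, 2), (3, 4)]  # hypothetical tile-local coordinates for one lane
print(add_residual(frag, coords, smem, stride, xor_swizzle, beta=0.5))
# prints [6.0, 15.0]
```

The writer and the reader must use the same swizzle so that the logical element (row, col) is found at the same physical offset. Here C[1][2] = 10 and C[3][4] = 28, so the fragment becomes 1 + 0.5 * 10 and 1 + 0.5 * 28.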

add_residual_to_both_fragments​

add_residual_to_both_fragments[epilogue_dtype: DType, frag_size: Int, is_lower_frag_required: Bool, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, src_ptr: UnsafePointer[Scalar[c_type], address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype]) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

Add beta * C to both fragment halves from swizzled SMEM.

Computes tile-local coordinates from stage and warp ID, then loads source C from SMEM and adds beta * C to each fragment element.

Returns:

Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]: Updated (upper_frag, lower_frag) tuple.