Mojo struct
EpilogueApplier
struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]
Apply element-wise epilogue lambda to register fragments.
Fields
- coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords)
- warp_id (UInt32)
- lane_id (UInt32)
- M (UInt32)
- N (UInt32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
Coords
comptime Coords = FragmentCoords[stageN, repeats]
Methods
__init__
__init__(warp_id: UInt32, lane_id: UInt32, c_shape: Tuple[UInt32, UInt32]) -> Self
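The fields listed above can be mirrored in plain Python to get a rough mental model of what the constructor stores. This is a hypothetical sketch, not the Mojo API. It assumes that `c_shape` is the `(M, N)` shape of the output matrix C, which the field names suggest but this page does not state. It also omits `coords`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class EpilogueApplierModel:
    """Hypothetical plain-Python stand-in for EpilogueApplier's fields."""

    warp_id: int
    lane_id: int
    M: int  # assumed: number of rows of the output matrix C
    N: int  # assumed: number of columns of the output matrix C

    @classmethod
    def create(
        cls, warp_id: int, lane_id: int, c_shape: Tuple[int, int]
    ) -> "EpilogueApplierModel":
        # Mirrors __init__(warp_id, lane_id, c_shape): unpack the shape into M and N.
        m, n = c_shape
        return cls(warp_id=warp_id, lane_id=lane_id, M=m, N=n)


applier = EpilogueApplierModel.create(warp_id=2, lane_id=5, c_shape=(512, 1024))
print(applier)  # EpilogueApplierModel(warp_id=2, lane_id=5, M=512, N=1024)
```

The dataclass is frozen because every method on this page takes `self` by value rather than `mut self`. Once constructed, the applier only supplies coordinates and shape to the per-fragment methods.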
compute_staged_coords
compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]
Compute global coords with warp and stage offsets (layout-dependent).
Returns:
Tuple: Global (row, col) coordinates with the warp and stage offsets applied.
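The page says only that this computation is layout-dependent. The sketch below is hypothetical and models one simple layout. It assumes that each warp owns 32 consecutive rows of the tile, which is consistent with the upper (rows 0-15) and lower (rows 16-31) fragment halves described further down. It also assumes that each output stage advances by `stageN` columns. The real method also depends on `MMA_M`, `cta_group`, and `transpose_c`, and none of that is modeled here.

```python
from typing import Tuple

WARP_TILE_ROWS = 32  # assumed: one warp covers rows 0-31 (upper 0-15, lower 16-31)


def compute_staged_coords(
    warp_id: int, stage: int, c_row: int, c_col: int, stage_n: int
) -> Tuple[int, int]:
    """Hypothetical single-layout model of compute_staged_coords.

    Assumption: warps are stacked along rows and stages advance along columns.
    The real method selects its layout from MMA_M, cta_group and transpose_c.
    """
    staged_row = c_row + warp_id * WARP_TILE_ROWS  # warp offset
    staged_col = c_col + stage * stage_n           # stage offset
    return staged_row, staged_col


# Tile origin (128, 256), warp 3, stage 2, stageN = 32.
print(compute_staged_coords(warp_id=3, stage=2, c_row=128, c_col=256, stage_n=32))  # (224, 320)
```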
apply_to_fragment
apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)
Apply epilogue lambda to fragment elements with global coords.
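The per-element work can be modeled in plain Python. The sketch rests on several hypothetical choices:

- The real element-to-coordinate mapping comes from the `Coords` type (`FragmentCoords[stageN, repeats]`), which this page does not describe. The `element_coords` helper therefore assumes an mma-style accumulator layout: lane `l` holds rows `l // 4` and `l // 4 + 8` of each 8-column chunk, at columns `(l % 4) * 2` and `+ 1`.
- The compute lambda is modeled as a Python callable taking `(row, col, value)`.
- The bounds check against `M` and `N` is guessed from the struct's fields.

```python
from typing import Callable, List, Tuple

ComputeFn = Callable[[int, int, float], float]


def element_coords(lane_id: int, i: int, is_upper: bool) -> Tuple[int, int]:
    # Hypothetical mma-style layout: 4 elements per 8-column chunk per lane.
    chunk, k = divmod(i, 4)
    row = lane_id // 4 + (8 if k >= 2 else 0)
    col = chunk * 8 + (lane_id % 4) * 2 + (k % 2)
    if not is_upper:
        row += 16  # lower half covers rows 16-31 of the warp tile
    return row, col


def apply_to_fragment(
    frag: List[float], lane_id: int, staged_row: int, staged_col: int,
    is_upper: bool, compute_fn: ComputeFn, M: int, N: int,
) -> None:
    for i, value in enumerate(frag):
        r, c = element_coords(lane_id, i, is_upper)
        row, col = staged_row + r, staged_col + c  # global coordinates
        if row < M and col < N:  # assumed bounds check against the output shape
            frag[i] = compute_fn(row, col, value)  # lambda returns the new value


frag = [1.0] * 8
apply_to_fragment(frag, lane_id=5, staged_row=0, staged_col=0, is_upper=True,
                  compute_fn=lambda row, col, v: v + 100 * row + col, M=64, N=64)
print(frag)  # [103.0, 104.0, 903.0, 904.0, 111.0, 112.0, 911.0, 912.0]
```

The fragment is updated in place, which mirrors the `mut frag` argument. A bias add or activation would be written as the compute lambda.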
apply_to_both_fragments
apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]
Apply epilogue to both fragments (main entry point).
Returns:
Tuple: Updated (upper_frag, lower_frag) tuple.
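The entry point can be sketched as a small driver: offset the tile origin, process the upper half, and process the lower half only when `is_lower_frag_required` is set. This is a hypothetical plain-Python model. It assumes 32 rows per warp and `stageN` columns per stage, and it deliberately simplifies the fragment layout so that element `i` sits at column offset `i`.

```python
from typing import Callable, List, Tuple

Frag = List[float]
ComputeFn = Callable[[int, int, float], float]


def _apply(frag: Frag, row0: int, col0: int, fn: ComputeFn) -> None:
    # Simplified stand-in for apply_to_fragment: element i sits at (row0, col0 + i).
    for i, v in enumerate(frag):
        frag[i] = fn(row0, col0 + i, v)


def apply_to_both_fragments(
    upper: Frag, lower: Frag, warp_id: int, stage: int, c_row: int, c_col: int,
    stage_n: int, fn: ComputeFn, is_lower_frag_required: bool,
) -> Tuple[Frag, Frag]:
    # 1. Offset the tile origin by warp and stage (single assumed layout).
    row = c_row + warp_id * 32
    col = c_col + stage * stage_n
    # 2. The upper half covers rows 0-15 of the warp tile.
    _apply(upper, row, col, fn)
    # 3. The lower half (rows 16-31) is processed only when the tile shape needs it.
    if is_lower_frag_required:
        _apply(lower, row + 16, col, fn)
    return upper, lower


seen: List[Tuple[int, int]] = []


def relu(row: int, col: int, value: float) -> float:
    seen.append((row, col))  # record which output coordinates were visited
    return max(value, 0.0)


up, lo = apply_to_both_fragments([-1.0, 2.0], [-3.0, 4.0], warp_id=1, stage=2,
                                 c_row=0, c_col=64, stage_n=16, fn=relu,
                                 is_lower_frag_required=False)
print(up, lo)  # [0.0, 2.0] [-3.0, 4.0]
print(seen)    # [(32, 96), (32, 97)]
```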
apply_elementwise_epilogue_to_fragment
apply_elementwise_epilogue_to_fragment[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: elementwise_epilogue_type](self, frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)
Apply elementwise epilogue lambda to fragment elements with global coords.
Unlike apply_to_fragment, which uses a compute lambda that returns modified values, this method calls an elementwise epilogue lambda (which returns None) that stores results directly to global memory.
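The contrast with `apply_to_fragment` can be shown in plain Python: the callback performs the store itself and returns nothing, so the fragment is only read. This is a hypothetical model. A dict stands in for global memory, element `i` is assumed to map to column `staged_col + i`, and the lower half is assumed to be offset by 16 rows.

```python
from typing import Callable, Dict, List, Tuple

StoreFn = Callable[[int, int, float], None]


def apply_elementwise_epilogue_to_fragment(
    frag: List[float], staged_row: int, staged_col: int, is_upper: bool,
    store_fn: StoreFn,
) -> None:
    # Simplified layout (assumption): element i maps to column staged_col + i,
    # and the lower half is offset by 16 rows.
    row = staged_row + (0 if is_upper else 16)
    for i, value in enumerate(frag):
        # The fragment is only read; results leave through store_fn.
        store_fn(row, staged_col + i, value)


# Stand-in for global memory: a sparse map from (row, col) to value.
global_c: Dict[Tuple[int, int], float] = {}


def store_scaled(row: int, col: int, value: float) -> None:
    # An elementwise epilogue: transform, write, and return nothing.
    global_c[(row, col)] = 2.0 * value


frag = [1.0, 2.0, 3.0]
apply_elementwise_epilogue_to_fragment(frag, staged_row=1, staged_col=4,
                                       is_upper=True, store_fn=store_scaled)
print(global_c)  # {(1, 4): 2.0, (1, 5): 4.0, (1, 6): 6.0}
print(frag)      # [1.0, 2.0, 3.0]
```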
apply_elementwise_epilogue_to_both_fragments
apply_elementwise_epilogue_to_both_fragments[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: elementwise_epilogue_type, is_lower_frag_required: Bool](self, upper_frag: SIMD[epilogue_dtype, frag_size], lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32)
Apply elementwise epilogue to both fragments.
Similar to apply_to_both_fragments, but uses elementwise_epilogue_type, which writes directly to global memory and returns None.
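Because nothing is returned, the only observable effect of the two-fragment variant is the set of stores it issues. The hypothetical plain-Python sketch below makes that visible. It uses the same assumptions as the other sketches on this page: 32 rows per warp, `stageN` columns per stage, element `i` at column offset `i`, and the lower half at row offset 16. The `make_store` helper is illustrative only.

```python
from typing import Callable, Dict, List, Tuple

StoreFn = Callable[[int, int, float], None]


def apply_elementwise_epilogue_to_both_fragments(
    upper: List[float], lower: List[float], warp_id: int, stage: int,
    c_row: int, c_col: int, stage_n: int, store_fn: StoreFn,
    is_lower_frag_required: bool,
) -> None:
    # Assumed layout: warps stacked by 32 rows, stages advance by stage_n columns.
    row = c_row + warp_id * 32
    col = c_col + stage * stage_n
    for i, v in enumerate(upper):
        store_fn(row, col + i, v)
    if is_lower_frag_required:
        for i, v in enumerate(lower):
            store_fn(row + 16, col + i, v)  # lower half: rows 16-31
    # Nothing is returned: all results reach memory through store_fn.


def make_store(target: Dict[Tuple[int, int], float]) -> StoreFn:
    # Hypothetical helper: build an epilogue that writes into `target`.
    def store(row: int, col: int, value: float) -> None:
        target[(row, col)] = value
    return store


upper_only: Dict[Tuple[int, int], float] = {}
both: Dict[Tuple[int, int], float] = {}
for target, needed in ((upper_only, False), (both, True)):
    apply_elementwise_epilogue_to_both_fragments(
        [1.0, 2.0], [3.0, 4.0], warp_id=1, stage=1, c_row=0, c_col=0,
        stage_n=16, store_fn=make_store(target), is_lower_frag_required=needed)

print(sorted(upper_only))  # [(32, 16), (32, 17)]
print(sorted(both))        # [(32, 16), (32, 17), (48, 16), (48, 17)]
```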
add_residual_to_fragment
add_residual_to_fragment[epilogue_dtype: DType, frag_size: Int, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], local_row: UInt32, local_col: UInt32, is_upper: Bool, src_ptr: UnsafePointer[Scalar[c_type], src_ptr.origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype])
Add beta * C to fragment elements by loading C from swizzled SMEM.
Uses the same per-lane coordinate mapping as apply_to_fragment, but instead of applying a lambda, loads source C values from SMEM at the matching swizzled addresses and adds beta * C to each element.
Args:
- frag (InlineArray): Fragment register values to modify in place.
- local_row (UInt32): Tile-local row offset (warp offset within tile).
- local_col (UInt32): Tile-local column offset (stage offset within tile).
- is_upper (Bool): Whether this is the upper (rows 0-15) or lower (rows 16-31) fragment half.
- src_ptr (UnsafePointer): Pointer to the source C SMEM tile (same TMA swizzle as the output).
- beta (Scalar): Residual scale factor.
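The residual path reads C through the same swizzle that was used to lay the tile out in SMEM, then accumulates `beta * C`. The plain-Python sketch below models this under several hypothetical choices:

- The swizzle is a CuTe-style XOR swizzle with `bits=3, base=3, shift=3` in element units, chosen only for illustration. The real `swizzle` parameter comes from the TMA configuration.
- Element `i` sits at column `local_col + i`.
- The lower half is offset by 16 rows.

```python
from typing import List


def swizzle(offset: int, bits: int = 3, base: int = 3, shift: int = 3) -> int:
    # CuTe-style XOR swizzle (assumed for illustration): XOR the `bits` bits
    # starting at base + shift down into the bits starting at base.
    yyy_mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & yyy_mask) >> shift)


def add_residual_to_fragment(
    frag: List[float], local_row: int, local_col: int, is_upper: bool,
    smem: List[float], c_smem_stride: int, beta: float,
) -> None:
    # Simplified layout (assumption): element i sits at column local_col + i,
    # and the lower half is offset by 16 rows.
    row = local_row + (0 if is_upper else 16)
    for i in range(len(frag)):
        logical = row * c_smem_stride + local_col + i  # row-major offset
        frag[i] += beta * smem[swizzle(logical)]        # read through the swizzle


# Build a swizzled SMEM copy of a 32 x 64 source tile with C[r][c] = 1000 * r + c.
ROWS, STRIDE = 32, 64
C = [[1000.0 * r + c for c in range(STRIDE)] for r in range(ROWS)]
smem = [0.0] * (ROWS * STRIDE)
for r in range(ROWS):
    for c in range(STRIDE):
        smem[swizzle(r * STRIDE + c)] = C[r][c]

frag = [0.5, 0.5, 0.5, 0.5]
add_residual_to_fragment(frag, local_row=5, local_col=8, is_upper=True,
                         smem=smem, c_smem_stride=STRIDE, beta=2.0)
print(frag)  # [10016.5, 10018.5, 10020.5, 10022.5]
```

Reading through the matching swizzle recovers the logical value `C[5][8 + i]`. An unswizzled read of offset `5 * 64 + 8` would instead land on `C[5][32]` in this example.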
add_residual_to_both_fragments
add_residual_to_both_fragments[epilogue_dtype: DType, frag_size: Int, is_lower_frag_required: Bool, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, src_ptr: UnsafePointer[Scalar[c_type], src_ptr.origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype]) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]
Add beta * C to both fragment halves from swizzled SMEM.
Computes tile-local coordinates from stage and warp ID, then loads source C from SMEM and adds beta * C to each fragment element.
Args:
- upper_frag (InlineArray): Upper fragment (rows 0-15 within warp tile).
- lower_frag (InlineArray): Lower fragment (rows 16-31 within warp tile).
- stage (UInt32): Output stage index (used to compute the column offset).
- src_ptr (UnsafePointer): Pointer to the source C SMEM tile.
- beta (Scalar): Residual scale factor.
Returns:
Tuple: Updated (upper_frag, lower_frag) tuple.
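Unlike `apply_to_both_fragments`, this method takes no `c_row` or `c_col`. The SMEM tile is addressed with tile-local coordinates derived only from the warp and the stage. The plain-Python sketch below is hypothetical and makes several simplifications:

- 32 rows per warp and `stageN` columns per stage.
- Element `i` at column `local_col + i`.
- C read unswizzled from a 2-D list (the swizzle is left out to keep the focus on the coordinate computation).

```python
from typing import List, Tuple

Frag = List[float]


def add_residual_to_both_fragments(
    upper: Frag, lower: Frag, warp_id: int, stage: int, stage_n: int,
    c_tile: List[List[float]], beta: float, is_lower_frag_required: bool,
) -> Tuple[Frag, Frag]:
    # Tile-local origin: only warp and stage offsets, no global c_row/c_col
    # (assumed layout: 32 rows per warp, stage_n columns per stage).
    local_row = warp_id * 32
    local_col = stage * stage_n
    # Simplified (assumption): element i sits at column local_col + i,
    # and C is read unswizzled from a 2-D list.
    for i in range(len(upper)):
        upper[i] += beta * c_tile[local_row][local_col + i]
    if is_lower_frag_required:
        for i in range(len(lower)):
            lower[i] += beta * c_tile[local_row + 16][local_col + i]  # rows 16-31
    return upper, lower


# 64 x 32 source tile with C[r][c] = 100 * r + c.
c_tile = [[100.0 * r + c for c in range(32)] for r in range(64)]
up, lo = add_residual_to_both_fragments([1.0, 1.0], [1.0, 1.0], warp_id=1, stage=1,
                                        stage_n=8, c_tile=c_tile, beta=0.5,
                                        is_lower_frag_required=True)
print(up, lo)  # [1605.0, 1605.5] [2405.0, 2405.5]
```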