
Mojo struct

EpilogueApplier

@register_passable(trivial) struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]

Apply element-wise epilogue lambda to register fragments.

Fields

  • coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords)
  • warp_id (UInt32)
  • lane_id (UInt32)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Coords

comptime Coords = FragmentCoords[stageN, repeats]

Methods

__init__

__init__(warp_id: UInt32, lane_id: UInt32) -> Self

compute_staged_coords

compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]

Compute global coords with warp and stage offsets (layout-dependent).

Returns:

Tuple: The computed global coordinates.

apply_to_fragment

apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply epilogue lambda to fragment elements with global coords.
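
As a rough illustration of what applying an element-wise lambda with coordinates means, the following is a host-side Python model, not the Mojo implementation. The names `apply_epilogue` and `add_column_bias` are hypothetical, and the per-element coordinates are passed in explicitly because the actual per-lane mapping is not described on this page.

```python
def apply_epilogue(frag, coords, compute_lambda):
    """Return a new fragment with compute_lambda applied to each element.

    frag:   list of element values held by one lane (stand-in for a SIMD vector)
    coords: list of (row, col) global coordinates, one per element (assumed given)
    """
    return [compute_lambda(rc, value) for rc, value in zip(coords, frag)]


# Hypothetical epilogue: add a per-column bias, which needs the global column.
bias = [0.5, 1.0, 1.5, 2.0]


def add_column_bias(coord, value):
    _row, col = coord
    return value + bias[col]


frag = [10.0, 20.0]
coords = [(0, 1), (8, 3)]  # illustrative coordinates only
result = apply_epilogue(frag, coords, add_column_bias)
print(result)
```

The lambda receives coordinates as well as the value so that epilogues such as bias or masking can depend on an element's position in the output.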

apply_to_both_fragments

apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]

Apply epilogue to both fragments (main entry point).

Returns:

Tuple: Updated (upper_frag, lower_frag) tuple.

add_residual_to_fragment

add_residual_to_fragment[epilogue_dtype: DType, frag_size: Int, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut frag: SIMD[epilogue_dtype, frag_size], local_row: UInt32, local_col: UInt32, is_upper: Bool, src_ptr: UnsafePointer[Scalar[c_type], origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype])

Add beta * C to fragment elements by loading C from swizzled SMEM.

Uses the same per-lane coordinate mapping as apply_to_fragment, but instead of applying a lambda, loads source C values from SMEM at the matching swizzled addresses and adds beta * C to each element.

Args:

  • frag (SIMD): Fragment register values to modify in-place.
  • local_row (UInt32): Tile-local row offset (warp offset within tile).
  • local_col (UInt32): Tile-local column offset (stage offset within tile).
  • is_upper (Bool): Whether this is the upper (rows 0-15) or lower (rows 16-31) fragment half.
  • src_ptr (UnsafePointer): Pointer to source C SMEM tile (same TMA swizzle as output).
  • beta (Scalar): Residual scale factor.
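
The residual update can be sketched as a host-side Python model. This is a sketch under stated assumptions rather than the library's code: `toy_swizzle` is a hypothetical XOR swizzle standing in for the `swizzle` parameter, the `(row, col)` list stands in for the per-lane coordinate mapping, and a flat list stands in for the SMEM tile.

```python
def toy_swizzle(offset, bits=2, base=0, shift=3):
    """Hypothetical XOR swizzle; the real Swizzle parameter may differ."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def add_residual(frag, coords, smem, stride, beta, swizzle=toy_swizzle):
    """Return frag with beta * C added, reading C at swizzled addresses."""
    out = []
    for value, (row, col) in zip(frag, coords):
        c = smem[swizzle(row * stride + col)]  # same swizzle as the writer used
        out.append(value + beta * c)
    return out


# Build a 4x8 source tile whose logical value at (r, c) is r * 10 + c,
# stored at swizzled addresses as a swizzled copy into SMEM would leave it.
ROWS, STRIDE = 4, 8
smem = [0.0] * (ROWS * STRIDE)
for r in range(ROWS):
    for c in range(STRIDE):
        smem[toy_swizzle(r * STRIDE + c)] = float(r * 10 + c)

frag = [1.0, 2.0]
coords = [(1, 2), (3, 5)]  # illustrative tile-local coordinates
updated = add_residual(frag, coords, smem, STRIDE, beta=0.5)
print(updated)
```

Reading through the same swizzle that was used when the tile was written recovers the logical element at `(row, col)`, which is why the source tile must share the output's swizzle.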

add_residual_to_both_fragments

add_residual_to_both_fragments[epilogue_dtype: DType, frag_size: Int, is_lower_frag_required: Bool, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, src_ptr: UnsafePointer[Scalar[c_type], origin, address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype]) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]

Add beta * C to both fragment halves from swizzled SMEM.

Computes tile-local coordinates from stage and warp ID, then loads source C from SMEM and adds beta * C to each fragment element.

Args:

  • upper_frag (SIMD): Upper fragment (rows 0-15 within warp tile).
  • lower_frag (SIMD): Lower fragment (rows 16-31 within warp tile).
  • stage (UInt32): Output stage index (for column offset computation).
  • src_ptr (UnsafePointer): Pointer to source C SMEM tile.
  • beta (Scalar): Residual scale factor.

Returns:

Tuple: Updated (upper_frag, lower_frag) tuple.
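
The coordinate step can be illustrated with a small Python model. The formulas here are assumptions, not taken from the implementation: each warp is assumed to own 32 consecutive rows (suggested by the rows 0-15 and 16-31 split above), and each stage is assumed to span `stage_n` columns. The actual layout may also depend on `MMA_M`, `cta_group`, and `transpose_c`.

```python
ROWS_PER_WARP = 32  # assumption: upper half rows 0-15, lower half rows 16-31
HALF_ROWS = 16


def tile_local_coords(warp_id, stage, stage_n):
    """Hypothetical tile-local offsets: warp picks the rows, stage the columns."""
    local_row = warp_id * ROWS_PER_WARP
    local_col = stage * stage_n
    return local_row, local_col


def half_row_ranges(local_row):
    """Row ranges covered by the upper and lower fragment halves."""
    upper = range(local_row, local_row + HALF_ROWS)
    lower = range(local_row + HALF_ROWS, local_row + ROWS_PER_WARP)
    return upper, lower


row, col = tile_local_coords(warp_id=2, stage=1, stage_n=32)
upper_rows, lower_rows = half_row_ranges(row)
print(row, col, upper_rows, lower_rows)
```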
