Mojo struct

EpilogueApplier

struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]

Apply element-wise epilogue lambda to register fragments.

Fields

  • coords (EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords)
  • warp_id (UInt32)
  • lane_id (UInt32)
  • M (UInt32)
  • N (UInt32)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

Coords

comptime Coords = FragmentCoords[stageN, repeats]

Methods

__init__

def __init__(warp_id: UInt32, lane_id: UInt32, c_shape: Tuple[UInt32, UInt32]) -> Self

compute_staged_coords

def compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]

Compute global coords with warp and stage offsets (layout-dependent).

Returns:

Tuple[UInt32, UInt32]

apply_to_fragment

def apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width], is_in_bounds: Bool = False](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply epilogue lambda to fragment elements with global coords.

is_in_bounds=True: the caller asserts that the whole tile fits within (self.M, self.N), so per-position bounds checks are elided. The default, False, keeps them: TMA masks the out-of-bounds (OOB) write, but the lambda may still dereference operands at the supplied (row, col).

transpose_c=True: bounds are checked per position, because top.row and bot.row in a (top, bot) pair are 8 apart (upper +0/+8 or lower +16/+24) and can straddle self.N. transpose_c=False: an early return on top_col is safe, because top.col == bot.col and self.N is alignment-bound.
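The effect of is_in_bounds can be illustrated with a small Python model. This is only a sketch of the behavior described above, not the Mojo implementation: `apply_with_bounds`, the flat `coords` list, and the `add_bias` lambda are hypothetical names, and the real per-lane (top, bot) coordinate mapping is not reproduced here.

```python
from typing import Callable, List, Tuple

Coord = Tuple[int, int]


def apply_with_bounds(
    frag: List[float],
    coords: List[Coord],
    fn: Callable[[Coord, float], float],
    shape: Coord,
    is_in_bounds: bool = False,
) -> List[float]:
    """Apply fn to each element, skipping positions outside shape
    unless the caller has asserted that the whole tile is in bounds."""
    m, n = shape
    out = list(frag)
    for i, (row, col) in enumerate(coords):
        if not is_in_bounds and (row >= m or col >= n):
            continue  # element left untouched; fn is never called at an OOB coord
        out[i] = fn((row, col), frag[i])
    return out


# Hypothetical epilogue: add a per-column bias. Indexing bias at col >= 4
# stands in for "the lambda may dereference operands at (row, col)".
bias = [10.0, 20.0, 30.0, 40.0]


def add_bias(coord: Coord, value: float) -> float:
    return value + bias[coord[1]]


frag = [1.0, 2.0, 3.0, 4.0]
coords = [(0, 2), (0, 3), (0, 4), (0, 5)]  # last two columns lie past N = 4
checked = apply_with_bounds(frag, coords, add_bias, shape=(8, 4))
```

With the default is_in_bounds=False, the two out-of-range positions are skipped and the lambda never touches `bias[4]` or `bias[5]`. Passing is_in_bounds=True for this same tile would call the lambda at those positions and fail, which is why the flag is a caller-side assertion.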

apply_to_both_fragments

def apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width], is_lower_frag_required: Bool, is_in_bounds: Bool = False](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

Apply epilogue to both fragments (main entry point).

Returns:

Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

apply_elementwise_epilogue_to_fragment

def apply_elementwise_epilogue_to_fragment[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None](self, frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)

Apply elementwise epilogue lambda to fragment elements with global coords.

Unlike apply_to_fragment, which uses a compute lambda that returns modified values, this method calls an elementwise epilogue lambda (which returns None) that stores directly to global memory.
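The difference between the two lambda styles can be modeled in a few lines of Python. This is an illustrative sketch only: `apply_compute`, `apply_elementwise`, and the dictionary standing in for global memory are hypothetical, and the coordinates are supplied directly rather than derived from warp and lane IDs.

```python
from typing import Callable, Dict, List, Tuple

Coord = Tuple[int, int]


def apply_compute(
    frag: List[float],
    coords: List[Coord],
    compute_fn: Callable[[Coord, float], float],
) -> List[float]:
    """Compute-lambda style: the lambda returns the new value,
    which is written back into the fragment."""
    return [compute_fn(c, v) for c, v in zip(coords, frag)]


def apply_elementwise(
    frag: List[float],
    coords: List[Coord],
    store_fn: Callable[[Coord, float], None],
) -> None:
    """Elementwise-epilogue style: the lambda returns None and is
    responsible for storing its result itself."""
    for c, v in zip(coords, frag):
        store_fn(c, v)


global_mem: Dict[Coord, float] = {}  # stand-in for global memory


def relu_compute(coord: Coord, value: float) -> float:
    return max(value, 0.0)


def relu_store(coord: Coord, value: float) -> None:
    global_mem[coord] = max(value, 0.0)


frag = [-1.0, 2.0]
coords = [(0, 0), (0, 1)]
updated = apply_compute(frag, coords, relu_compute)
apply_elementwise(frag, coords, relu_store)
```

In the compute style the result lives in the returned fragment (`updated`), so a later step still has to write it out. In the elementwise style the fragment is read-only and the result appears only in `global_mem`.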

apply_elementwise_epilogue_to_both_fragments

def apply_elementwise_epilogue_to_both_fragments[epilogue_dtype: DType, frag_size: Int, elementwise_lambda_fn: def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None, is_lower_frag_required: Bool](self, upper_frag: SIMD[epilogue_dtype, frag_size], lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32)

Apply elementwise epilogue to both fragments.

Similar to apply_to_both_fragments, but uses an elementwise_epilogue_type lambda, which writes directly to global memory and returns None.

add_residual_to_fragment

def add_residual_to_fragment[epilogue_dtype: DType, frag_size: Int, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut frag: InlineArray[Scalar[epilogue_dtype], frag_size], local_row: UInt32, local_col: UInt32, is_upper: Bool, src_ptr: UnsafePointer[Scalar[c_type], address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype])

Add beta * C to fragment elements by loading C from swizzled SMEM.

Uses the same per-lane coordinate mapping as apply_to_fragment, but instead of applying a lambda, loads source C values from SMEM at the matching swizzled addresses and adds beta * C to each element.
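The residual update can be sketched in Python as follows. This is a model under stated assumptions, not the library's code: `add_residual` and `xor_swizzle` are hypothetical, the XOR pattern is just one invertible swizzle chosen for illustration, and coordinates are passed in directly instead of being derived per lane.

```python
from typing import Callable, List, Tuple

Coord = Tuple[int, int]


def xor_swizzle(idx: int) -> int:
    """Hypothetical swizzle: XOR the low two bits with bits 2-3.
    Bits 2-3 are unchanged, so the mapping is its own inverse."""
    return idx ^ ((idx >> 2) & 0b11)


def add_residual(
    frag: List[float],
    coords: List[Coord],
    smem: List[float],
    stride: int,
    swizzle: Callable[[int], int],
    beta: float,
) -> List[float]:
    """Return frag + beta * C, reading each C value from the swizzled
    address that matches the element's tile-local (row, col)."""
    out = []
    for value, (row, col) in zip(frag, coords):
        linear = row * stride + col       # logical offset in the C tile
        c_val = smem[swizzle(linear)]     # physical (swizzled) SMEM slot
        out.append(value + beta * c_val)
    return out


# Fill a 4x4 "shared memory" tile so that logical element k holds float(k),
# stored at its swizzled position.
stride = 4
smem = [0.0] * 16
for k in range(16):
    smem[xor_swizzle(k)] = float(k)

frag = [1.0, 1.0]
coords = [(1, 2), (3, 0)]  # logical offsets 6 and 12
result = add_residual(frag, coords, smem, stride, xor_swizzle, beta=0.5)
```

Because the read applies the same swizzle that was used when the tile was written, each fragment element picks up the C value at its own logical (row, col), regardless of where that value physically sits in shared memory.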

add_residual_to_both_fragments

def add_residual_to_both_fragments[epilogue_dtype: DType, frag_size: Int, is_lower_frag_required: Bool, c_type: DType, c_smem_stride: Int, swizzle: Swizzle](self, mut upper_frag: InlineArray[Scalar[epilogue_dtype], frag_size], mut lower_frag: InlineArray[Scalar[epilogue_dtype], frag_size], stage: UInt32, src_ptr: UnsafePointer[Scalar[c_type], address_space=AddressSpace.SHARED], beta: Scalar[epilogue_dtype]) -> Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]

Add beta * C to both fragment halves from swizzled SMEM.

Computes tile-local coordinates from stage and warp ID, then loads source C from SMEM and adds beta * C to each fragment element.

Returns:

Tuple[InlineArray[Scalar[epilogue_dtype], frag_size], InlineArray[Scalar[epilogue_dtype], frag_size]]: Updated (upper_frag, lower_frag) tuple.