IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

conv2d_fprop_with_residual

def conv2d_fprop_with_residual[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(), elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, has_residual: Bool = False](output: TileTensor[out_type, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], activation: TileTensor[act_type, Storage=activation.Storage, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size], filter: TileTensor[filter_type, Storage=filter.Storage, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], source: TileTensor[out_type, Storage=source.Storage, address_space=source.address_space, linear_idx_type=source.linear_idx_type, element_size=source.element_size], beta: Float32, problem: Conv2dProblemShape, ctx: DeviceContext)

Launch Conv2D fprop with residual add.

Computes D = Conv(A,B) + beta*C, where A is the activation, B is the filter, C is the source tensor, and D is the output. This function extends conv2d_fprop with residual add support. The epilogue load warp prefetches C via TMA, so the load overlaps with the MMA computation and its latency is hidden.
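To check results against the kernel, the math of D = Conv(A,B) + beta*C can be written as a direct loop nest. The sketch below is a minimal pure-Python reference, not the Mojo implementation. It assumes an NHWC activation, a KRSC filter, unit stride, and no padding or dilation, and it uses nested lists in place of TileTensor. The function names `conv2d_ref`, `add_scaled`, and `conv2d_fprop_with_residual_ref` are hypothetical.

```python
# Pure-Python reference for D = Conv(A, B) + beta * C, for checking results.
# Assumptions (illustrative, not taken from the Mojo API): nested lists stand
# in for TileTensors, A is NHWC, B is KRSC, stride 1, no padding or dilation.


def conv2d_ref(act, filt):
    """Direct convolution (cross-correlation) of A [N][H][W][C] with B [K][R][S][C].

    Returns a nested list of shape [N][P][Q][K] with P = H - R + 1, Q = W - S + 1.
    """
    n_dim, h_dim, w_dim, c_dim = len(act), len(act[0]), len(act[0][0]), len(act[0][0][0])
    k_dim, r_dim, s_dim = len(filt), len(filt[0]), len(filt[0][0])
    p_dim, q_dim = h_dim - r_dim + 1, w_dim - s_dim + 1
    out = [[[[0.0] * k_dim for _ in range(q_dim)] for _ in range(p_dim)] for _ in range(n_dim)]
    for n in range(n_dim):
        for p in range(p_dim):
            for q in range(q_dim):
                for k in range(k_dim):
                    acc = 0.0
                    for r in range(r_dim):
                        for s in range(s_dim):
                            for c in range(c_dim):
                                acc += act[n][p + r][q + s][c] * filt[k][r][s][c]
                    out[n][p][q][k] = acc
    return out


def add_scaled(x, y, beta):
    """Elementwise x + beta * y over nested lists; shapes must match exactly."""
    if isinstance(x, list):
        if not isinstance(y, list) or len(x) != len(y):
            raise ValueError("source shape does not match output shape")
        return [add_scaled(a, b, beta) for a, b in zip(x, y)]
    return x + beta * y


def conv2d_fprop_with_residual_ref(act, filt, source, beta):
    """D = Conv(A, B) + beta * C."""
    return add_scaled(conv2d_ref(act, filt), source, beta)


A = [[[[float(3 * h + w + 1)] for w in range(3)] for h in range(3)]]  # 1x3x3x1, values 1..9
B = [[[[1.0] for _ in range(2)] for _ in range(2)]]                    # 1x2x2x1, all ones
C = [[[[1.0] for _ in range(2)] for _ in range(2)]]                    # 1x2x2x1 residual
D = conv2d_fprop_with_residual_ref(A, B, C, beta=0.5)
print(D)  # [[[[12.5], [16.5]], [[24.5], [28.5]]]]
```

Each output element is a 2x2 window sum of A (12, 16, 24, 28) plus 0.5 times the matching element of C.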

The residual add is applied after the optional epilogue lambda, so D = lambda(Conv(A,B)) + beta * C.

This supports common patterns like:

  • Skip connections: D = Conv(A,B) + C (beta=1.0)
  • Residual scaling: D = Conv(A,B) + 0.5*C (beta=0.5)
  • Fused residual+activation: D = ReLU(Conv(A,B)) + C
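The ordering rule and the three patterns above can be modeled per output element. The following is a minimal sketch, assuming a scalar accumulator `acc` that stands for one element of Conv(A,B). The helpers `epilogue` and `relu` are hypothetical and only illustrate the stated order: the lambda first, then the residual add.

```python
# Hypothetical per-element epilogue: D = lam(acc) + beta * c, where acc is the
# convolution accumulator Conv(A, B) for one output element. Names are illustrative.


def epilogue(acc, c, beta, lam=None):
    """Apply the optional lambda first, then the residual add."""
    value = lam(acc) if lam is not None else acc
    return value + beta * c


def relu(x):
    return max(x, 0.0)


accs = [-3.0, 2.0]      # example Conv(A, B) outputs
residual = [1.0, 4.0]   # example source C values

skip = [epilogue(a, c, 1.0) for a, c in zip(accs, residual)]         # [-2.0, 6.0]
scaled = [epilogue(a, c, 0.5) for a, c in zip(accs, residual)]       # [-2.5, 4.0]
fused = [epilogue(a, c, 1.0, relu) for a, c in zip(accs, residual)]  # [1.0, 6.0]

# Because the lambda runs before the add, the fused pattern is ReLU(acc) + C,
# not ReLU(acc + C): for the first element the latter would be max(-2.0, 0.0) = 0.0.
```

The first element shows why the order matters: a negative accumulator is clamped to zero before C is added, so the residual still contributes 1.0.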

Note: When has_residual is True, the epilogue load warp (warp ID 7) loads C. When has_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d_fprop.
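The early-exit behavior amounts to a simple branch: when the residual cannot contribute, C is never fetched and the output is the plain convolution result. The sketch below models that control flow on the host in Python. `residual_epilogue` and `fake_load` are hypothetical. `load_source` stands in for the load warp's TMA fetch and does not describe how the Mojo kernel is structured internally.

```python
# Illustrative model of the residual epilogue's control flow. load_source is a
# hypothetical stand-in for the epilogue load warp's TMA fetch of C.


def residual_epilogue(acc_tile, load_source, beta, has_residual):
    """Return the output tile; skip reading C when the residual cannot contribute."""
    if not has_residual or beta == 0.0:
        # Mirrors the load warp exiting early: C is not fetched and the
        # result equals the plain conv2d_fprop output.
        return list(acc_tile)
    source_tile = load_source()
    return [a + beta * c for a, c in zip(acc_tile, source_tile)]


loads = []


def fake_load():
    loads.append(1)  # record that C was fetched
    return [10.0, 20.0]


acc = [1.0, 2.0]
print(residual_epilogue(acc, fake_load, 0.0, True), len(loads))   # [1.0, 2.0] 0
print(residual_epilogue(acc, fake_load, 1.0, False), len(loads))  # [1.0, 2.0] 0
print(residual_epilogue(acc, fake_load, 0.5, True), len(loads))   # [6.0, 12.0] 1
```

Only the last call, with has_residual set and a nonzero beta, triggers a load of C.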

Parameters:

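Args:

  • output: Output tensor D, of dtype out_type.
  • activation: Activation tensor A, of dtype act_type.
  • filter: Filter tensor B, of dtype filter_type.
  • source: Source tensor C for the residual add, of dtype out_type. Must have the same shape as output.
  • beta: Scale factor applied to C before it is added.
  • problem: Convolution problem shape.
  • ctx: Device context used to launch the kernel.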

Raises:

Error if kernel launch fails, constraints are violated, or source tensor shape doesn't match output shape.
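The shape error can be understood as a host-side check that C matches the output shape (N, P, Q, K), where P and Q follow the standard convolution output-size formula. This is a sketch under stated assumptions. `ProblemShape` and `check_residual_shapes` are hypothetical, and the field names and NHWC layout may not match the real Conv2dProblemShape. Python's ValueError stands in for Mojo's Error.

```python
# Hypothetical host-side check behind the shape error: the source C must have
# exactly the output shape (N, P, Q, K). Field names and the NHWC layout are
# assumptions; Conv2dProblemShape's actual fields may differ.
from dataclasses import dataclass


@dataclass(frozen=True)
class ProblemShape:
    n: int
    h: int
    w: int
    c: int
    k: int
    r: int
    s: int
    pad_h: int = 0
    pad_w: int = 0
    stride_h: int = 1
    stride_w: int = 1
    dilation_h: int = 1
    dilation_w: int = 1

    def output_shape(self):
        """Standard convolution output size for each spatial dimension."""
        p = (self.h + 2 * self.pad_h - self.dilation_h * (self.r - 1) - 1) // self.stride_h + 1
        q = (self.w + 2 * self.pad_w - self.dilation_w * (self.s - 1) - 1) // self.stride_w + 1
        return (self.n, p, q, self.k)


def check_residual_shapes(problem, output_shape, source_shape):
    """Raise ValueError if output or source disagrees with the problem shape."""
    expected = problem.output_shape()
    if tuple(output_shape) != expected:
        raise ValueError(f"output shape {tuple(output_shape)} != expected {expected}")
    if tuple(source_shape) != tuple(output_shape):
        raise ValueError(f"source shape {tuple(source_shape)} != output shape {tuple(output_shape)}")


prob = ProblemShape(n=1, h=56, w=56, c=64, k=64, r=3, s=3, pad_h=1, pad_w=1)
check_residual_shapes(prob, (1, 56, 56, 64), (1, 56, 56, 64))  # passes silently
```

With a 3x3 filter, padding 1, and stride 1, the spatial size is preserved, so a residual taken from the block input lines up with the output.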