Mojo function

conv2d_fprop_with_residual

conv2d_fprop_with_residual[
    act_type: DType,
    filter_type: DType,
    out_type: DType,
    *,
    config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(),
    elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None,
    elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None,
    register_based_epilogue: Bool = True,
    has_residual: Bool = False
](
    output: TileTensor[out_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    activation: TileTensor[act_type, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size],
    filter: TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size],
    source: TileTensor[out_type, address_space=source.address_space, linear_idx_type=source.linear_idx_type, element_size=source.element_size],
    beta: Float32,
    problem: Conv2dProblemShape,
    ctx: DeviceContext
)

Launch Conv2D fprop with residual add.

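Computes D = Conv(A,B) + beta*C, where A is the activation, B the filter, C the source tensor, and D the output. This function extends conv2d_fprop with residual-add support. The epilogue load warp prefetches the source tensor C via TMA while the MMA computation is still running, so the load of C overlaps with the convolution instead of adding latency after it.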

The residual add is applied after the optional epilogue lambda, so the lambda sees only the convolution result and never the residual: D = lambda(Conv(A,B)) + beta * C.

This supports common patterns like:

  • Skip connections: D = Conv(A,B) + C (beta=1.0)
  • Residual scaling: D = Conv(A,B) + 0.5*C (beta=0.5)
  • Fused residual+activation: D = ReLU(Conv(A,B)) + C
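The three patterns above can be checked against a small reference model. The sketch below is plain Python, not the Mojo kernel: `residual_epilogue` and `relu` are hypothetical helper names, the tensors are flattened to lists, and `conv_out` stands in for an already computed Conv(A,B). It only illustrates the order of operations (lambda first, then the scaled residual add).

```python
from typing import Callable, List, Optional, Sequence


def residual_epilogue(
    conv_out: Sequence[float],
    source: Sequence[float],
    beta: float,
    compute_lambda: Optional[Callable[[float], float]] = None,
) -> List[float]:
    """Reference model of D = lambda(Conv(A,B)) + beta * C on flattened data."""
    if len(conv_out) != len(source):
        # Mirrors the documented requirement that source matches the output shape.
        raise ValueError("source must have the same number of elements as the output")
    result = []
    for acc, c in zip(conv_out, source):
        # The optional epilogue lambda runs first, on the raw convolution result.
        value = compute_lambda(acc) if compute_lambda is not None else acc
        # The residual add follows the lambda and is scaled by beta.
        result.append(value + beta * c)
    return result


def relu(x: float) -> float:
    return max(x, 0.0)


conv_out = [-2.0, 0.5, 3.0]  # stand-in for Conv(A,B)
residual = [1.0, 1.0, 1.0]   # stand-in for the source tensor C

skip = residual_epilogue(conv_out, residual, beta=1.0)    # skip connection
scaled = residual_epilogue(conv_out, residual, beta=0.5)  # residual scaling
fused = residual_epilogue(conv_out, residual, beta=1.0, compute_lambda=relu)
```

In the fused case the first element is ReLU(-2.0) + 1.0 = 1.0 rather than ReLU(-2.0 + 1.0) = 0.0, which is the observable difference between applying the activation before and after the residual add.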

Note: The epilogue load warp (warp ID 7) handles C loading when residual is enabled. When has_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d_fprop.
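The early-exit behavior in the note can be modeled the same way. This is a minimal Python sketch under stated assumptions: `epilogue_with_gate` is a hypothetical name, and `has_residual` is an ordinary runtime argument here, whereas in the Mojo signature it is a compile-time parameter. The model skips the residual path entirely when it is disabled or when beta is 0, so the result equals the plain convolution output.

```python
from typing import List, Sequence


def epilogue_with_gate(
    conv_out: Sequence[float],
    source: Sequence[float],
    beta: float,
    has_residual: bool,
) -> List[float]:
    """Model of the residual gate: C is only read when it can affect D."""
    if not has_residual or beta == 0.0:
        # Residual path disabled: the output is the convolution result alone,
        # and the source values are never touched.
        return list(conv_out)
    return [acc + beta * c for acc, c in zip(conv_out, source)]


conv_out = [1.0, 2.0]
source = [10.0, 20.0]

disabled = epilogue_with_gate(conv_out, source, beta=1.0, has_residual=False)
zero_beta = epilogue_with_gate(conv_out, source, beta=0.0, has_residual=True)
enabled = epilogue_with_gate(conv_out, source, beta=1.0, has_residual=True)
```

In this model both `disabled` and `zero_beta` equal `conv_out`, while `enabled` adds the residual. Whether the real kernel still requires a valid source tensor on the gated paths is not stated in this page.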

Parameters:

Args:

Raises:

Error if kernel launch fails, constraints are violated, or source tensor shape doesn't match output shape.