
Mojo function

conv2d_fprop_with_residual

conv2d_fprop_with_residual[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, has_residual: Bool = False](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], source: NDBuffer[out_type, 4, origin], beta: Float32, problem: Conv2dProblemShape, ctx: DeviceContext)

Launch a Conv2D forward-propagation (fprop) kernel with a fused residual add.

Computes D = Conv(A,B) + beta*C. This function extends conv2d_fprop with residual add support. The epilogue load warp pre-fetches the source tensor C via TMA, overlapping those loads with the MMA computation so that reading C does not stall the epilogue.
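To pin down what D = Conv(A,B) + beta*C means element by element, here is a slow, plain-Python reference model. It is not Mojo and says nothing about how the GPU kernel is implemented. It assumes stride 1, no padding, and no dilation, which are illustrative choices and may differ from what a given `Conv2dProblemShape` specifies. The function name `conv2d_fprop_residual_ref` is hypothetical.

```python
# Hypothetical reference model of D = Conv(A, B) + beta * C.
# Assumptions (not from the API docs): stride 1, no padding, no dilation,
# tensors stored as nested Python lists in the documented layouts.

def conv2d_fprop_residual_ref(act, filt, source, beta):
    """act: [N][H][W][C] (NHWC), filt: [K][R][S][C] (KRSC),
    source: [N][H_out][W_out][K]. Returns D with the same shape as source."""
    n_batch, h, w, c = len(act), len(act[0]), len(act[0][0]), len(act[0][0][0])
    k_out, r, s = len(filt), len(filt[0]), len(filt[0][0])
    h_out, w_out = h - r + 1, w - s + 1  # "valid" convolution with stride 1
    out = [[[[0.0] * k_out for _ in range(w_out)] for _ in range(h_out)]
           for _ in range(n_batch)]
    for n in range(n_batch):
        for p in range(h_out):
            for q in range(w_out):
                for k in range(k_out):
                    # Reduce over the filter window and input channels.
                    acc = 0.0
                    for rr in range(r):
                        for ss in range(s):
                            for ci in range(c):
                                acc += act[n][p + rr][q + ss][ci] * filt[k][rr][ss][ci]
                    # Residual add: D = Conv(A, B) + beta * C.
                    out[n][p][q][k] = acc + beta * source[n][p][q][k]
    return out
```

For a 1x3x3x1 activation holding 1..9, a single 2x2 all-ones filter, an all-ones source, and beta = 0.5, the four outputs are the 2x2 window sums (12, 16, 24, 28) plus 0.5 each.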

The residual add is applied after the optional epilogue lambda: D = lambda(Conv(A,B)) + beta * C

This supports common patterns like:

  • Skip connections: D = Conv(A,B) + C (beta=1.0)
  • Residual scaling: D = Conv(A,B) + 0.5*C (beta=0.5)
  • Fused residual+activation: D = ReLU(Conv(A,B)) + C
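The three patterns differ only in beta and in whether a lambda is supplied. The plain-Python sketch below models the per-element epilogue in the documented order (lambda first, then + beta * C). The helper names and sample values are made up for illustration, and ReLU stands in for an arbitrary epilogue lambda. The last line shows why the order matters for a nonlinear lambda.

```python
def relu(x: float) -> float:
    return max(x, 0.0)

def epilogue(acc, c, beta, fn=None):
    """Hypothetical per-element epilogue: apply the optional lambda, then add beta * C."""
    y = fn(acc) if fn is not None else acc
    return y + beta * c

conv = [-1.0, 0.5, 2.0]  # made-up Conv(A,B) values
src = [1.0, 1.0, 4.0]    # made-up residual source C

skip = [epilogue(a, r, 1.0) for a, r in zip(conv, src)]         # [0.0, 1.5, 6.0]
scaled = [epilogue(a, r, 0.5) for a, r in zip(conv, src)]       # [-0.5, 1.0, 4.0]
fused = [epilogue(a, r, 1.0, relu) for a, r in zip(conv, src)]  # [1.0, 1.5, 6.0]

# Applying ReLU *after* the residual gives a different result at index 0,
# so it is not what the fused pattern computes.
relu_after = [relu(a + r) for a, r in zip(conv, src)]           # [0.0, 1.5, 6.0]
```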

Note: The epilogue load warp (warp ID 7) handles C loading when residual is enabled. When has_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d_fprop.
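The early-exit rule can be modeled as a single predicate. The plain-Python sketch below captures only the observable result described in the note, not the warp specialization itself. `residual_active` and `epilogue_value` are hypothetical names. Passing `None` for `c` in the inactive case shows that the source value is never read.

```python
def residual_active(has_residual: bool, beta: float) -> bool:
    """True only when C must actually be loaded and added."""
    return has_residual and beta != 0.0

def epilogue_value(acc: float, c, beta: float, has_residual: bool) -> float:
    """Hypothetical per-element result; `c` is untouched when the residual is inactive."""
    if not residual_active(has_residual, beta):
        return acc  # same value plain conv2d_fprop would produce
    return acc + beta * c

print(epilogue_value(3.0, None, 0.0, True))   # 3.0 (beta == 0, source never read)
print(epilogue_value(3.0, 2.0, 0.5, True))    # 4.0
```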

Parameters:

  • act_type (DType): Data type of the input activation tensor.
  • filter_type (DType): Data type of the filter weights tensor.
  • out_type (DType): Data type of the output tensor.
  • config (Conv2dConfig): Kernel configuration (tile sizes, pipeline stages, etc.).
  • elementwise_compute_lambda_fn (Optional): Optional element-wise lambda for epilogue fusion (for example, bias add or activation). Applied before the residual add.
  • register_based_epilogue (Bool): If True, apply the lambda in registers (the faster path).
  • has_residual (Bool): If True, apply the residual add. If False, source is ignored.

Args:

  • output (NDBuffer): Output tensor [N, H_out, W_out, C_out] in NHWC layout (D).
  • activation (NDBuffer): Input activation [N, H, W, C] in NHWC layout (A).
  • filter (NDBuffer): Filter weights [K, R, S, C] in KRSC layout (B), where K equals C_out.
  • source (NDBuffer): Source tensor [N, H_out, W_out, C_out] for residual (C).
  • beta (Float32): Residual scale factor. If 0.0, no residual is applied.
  • problem (Conv2dProblemShape): Convolution problem shape specification.
  • ctx (DeviceContext): Device context for kernel launch.
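To make the layouts concrete, the sketch below computes flat element offsets for NHWC and KRSC tensors in plain Python. It assumes dense, row-major storage with no padding between rows. An `NDBuffer` may carry its own strides, so treat this as illustrative only. The function names are hypothetical. Because output and source share the same [N, H_out, W_out, C_out] shape, one offset addresses both D[n, p, q, k] and C[n, p, q, k], which is why the residual add is purely element-wise.

```python
def nhwc_offset(n, h, w, c, H, W, C):
    """Flat offset of element (n, h, w, c) in a dense, row-major NHWC tensor."""
    return ((n * H + h) * W + w) * C + c

def krsc_offset(k, r, s, c, R, S, C):
    """Flat offset of element (k, r, s, c) in a dense, row-major KRSC filter."""
    return ((k * R + r) * S + s) * C + c

# Output D and source C have identical shapes, so the same offset indexes both.
H_out, W_out, C_out = 4, 4, 8
off = nhwc_offset(1, 2, 3, 5, H_out, W_out, C_out)
print(off)  # 221
```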

Raises:

An error if the kernel launch fails, a kernel constraint is violated, or the source tensor shape doesn't match the output shape.
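The shape condition can be illustrated with a host-side check written in plain Python. Only the source-versus-output comparison is documented above. The channel and batch checks follow from the documented layouts of a standard ungrouped convolution and are assumptions about what "constraints" covers, not a description of the function's actual validation. `check_residual_conv_shapes` is a hypothetical name.

```python
def check_residual_conv_shapes(act_shape, filter_shape, out_shape, source_shape):
    """Hypothetical checks on NHWC activation/output/source and KRSC filter shapes."""
    n, _, _, c = act_shape
    k, _, _, c_filter = filter_shape
    n_out, _, _, c_out = out_shape
    if c_filter != c:  # assumes an ungrouped convolution
        raise ValueError(f"filter C={c_filter} != activation C={c}")
    if c_out != k:
        raise ValueError(f"output C_out={c_out} != filter K={k}")
    if n_out != n:
        raise ValueError(f"output N={n_out} != activation N={n}")
    if tuple(source_shape) != tuple(out_shape):  # the documented source check
        raise ValueError(
            f"source shape {tuple(source_shape)} != output shape {tuple(out_shape)}")

# Consistent shapes pass silently; a mismatched source raises ValueError.
check_residual_conv_shapes((1, 8, 8, 16), (32, 3, 3, 16), (1, 6, 6, 32), (1, 6, 6, 32))
```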
