Mojo function
conv2d_fprop_with_residual
conv2d_fprop_with_residual[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, has_residual: Bool = False](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], source: NDBuffer[out_type, 4, origin], beta: Float32, problem: Conv2dProblemShape, ctx: DeviceContext)
Launch Conv2D fprop with residual add.
Computes D = Conv(A,B) + beta*C. This function extends conv2d_fprop with residual add support. The epilogue load warp pre-fetches source tensor C via TMA, overlapping with MMA computation for better performance.
The residual add is applied after the optional epilogue lambda, so with a lambda present the full expression is D = lambda(Conv(A,B)) + beta * C.
This supports common patterns like:
- Skip connections: D = Conv(A,B) + C (beta=1.0)
- Residual scaling: D = Conv(A,B) + 0.5*C (beta=0.5)
- Fused residual+activation: D = ReLU(Conv(A,B)) + C
Note: The epilogue load warp (warp ID 7) handles C loading when residual is enabled. When has_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d_fprop.
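The ordering described above (lambda first, then the scaled residual) can be illustrated with a small CPU-side reference. This is a minimal sketch in plain Python, not the kernel's implementation: `residual_epilogue`, `relu`, and the flat-list representation are hypothetical names chosen for illustration, and the early-exit condition simply mirrors the note that the residual path is skipped when has_residual is False or beta is 0.

```python
from typing import Callable, List, Optional


def residual_epilogue(
    conv_out: List[float],
    source: List[float],
    beta: float,
    elementwise_fn: Optional[Callable[[float], float]] = None,
    has_residual: bool = True,
) -> List[float]:
    """Reference semantics: D = lambda(Conv(A,B)) + beta * C, element by element.

    `conv_out` stands in for the already-computed Conv(A,B) values and
    `source` for the matching elements of C (hypothetical flat layout).
    """
    result = []
    for acc, c in zip(conv_out, source):
        # The optional lambda (bias add, activation, ...) runs first.
        val = elementwise_fn(acc) if elementwise_fn is not None else acc
        # The residual is added afterwards, and only when it is enabled.
        if has_residual and beta != 0.0:
            val = val + beta * c
        result.append(val)
    return result


def relu(x: float) -> float:
    return max(x, 0.0)


# Fused residual + activation: D = ReLU(Conv(A,B)) + C
fused = residual_epilogue([-2.0, 3.0], [1.0, 1.0], beta=1.0, elementwise_fn=relu)
# Residual scaling: D = Conv(A,B) + 0.5 * C
scaled = residual_epilogue([-2.0, 3.0], [4.0, 4.0], beta=0.5)
```

Note that the activation is applied to the convolution result only; the residual is added to the activated value and is not itself passed through the lambda.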
Parameters:
- act_type (DType): Data type of the input activation tensor.
- filter_type (DType): Data type of the filter weights tensor.
- out_type (DType): Data type of the output tensor.
- config (Conv2dConfig): Kernel configuration (tile sizes, pipeline stages, etc.).
- elementwise_compute_lambda_fn (Optional): Optional element-wise lambda function for epilogue fusion (bias add, activation). Applied before residual.
- register_based_epilogue (Bool): If True, apply lambda in registers (faster).
- has_residual (Bool): If True, apply residual add. If False, source is ignored.
Args:
- output (NDBuffer): Output tensor [N, H_out, W_out, C_out] in NHWC layout (D).
- activation (NDBuffer): Input activation [N, H, W, C] in NHWC layout (A).
- filter (NDBuffer): Filter weights [K, R, S, C] in KRSC layout (B).
- source (NDBuffer): Source tensor [N, H_out, W_out, C_out] for residual (C).
- beta (Float32): Residual scale factor. If 0.0, no residual is applied.
- problem (Conv2dProblemShape): Convolution problem shape specification.
- ctx (DeviceContext): Device context for kernel launch.
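To make the tensor layouts concrete, the following is a naive CPU reference for D = Conv(A,B) + beta*C using nested Python lists indexed as NHWC (activation, source, output) and KRSC (filter). It is a sketch under stated assumptions, not the GPU kernel: it assumes stride 1, no padding, and no dilation (so H_out = H - R + 1 and W_out = W - S + 1), whereas the real geometry comes from the Conv2dProblemShape argument, whose fields are not described here. The function name `conv2d_nhwc_krsc_residual` is hypothetical.

```python
def conv2d_nhwc_krsc_residual(activation, filt, source, beta):
    """Naive reference: output[n][p][q][k] = conv + beta * source[n][p][q][k].

    activation: [N][H][W][C]            (NHWC)
    filt:       [K][R][S][C]            (KRSC)
    source:     [N][H_out][W_out][K]    (same shape as the output)
    Assumes stride 1, no padding, no dilation (illustrative only).
    """
    n_batch = len(activation)
    height, width = len(activation[0]), len(activation[0][0])
    channels = len(activation[0][0][0])
    k_out, r_size, s_size = len(filt), len(filt[0]), len(filt[0][0])
    h_out = height - r_size + 1
    w_out = width - s_size + 1

    output = [
        [[[0.0] * k_out for _ in range(w_out)] for _ in range(h_out)]
        for _ in range(n_batch)
    ]
    for n in range(n_batch):
        for p in range(h_out):
            for q in range(w_out):
                for k in range(k_out):
                    acc = 0.0
                    # Accumulate over the filter window and input channels.
                    for r in range(r_size):
                        for s in range(s_size):
                            for c in range(channels):
                                acc += activation[n][p + r][q + s][c] * filt[k][r][s][c]
                    # Residual add: scaled source element at the same output index.
                    output[n][p][q][k] = acc + beta * source[n][p][q][k]
    return output


# 1x3x3x1 activation holding 1..9, one 2x2 all-ones filter, source filled with 10.
act = [[[[float(3 * h + w + 1)] for w in range(3)] for h in range(3)]]
flt = [[[[1.0], [1.0]], [[1.0], [1.0]]]]
src = [[[[10.0], [10.0]], [[10.0], [10.0]]]]
out = conv2d_nhwc_krsc_residual(act, flt, src, beta=0.5)
```

A reference like this is mainly useful for checking a GPU result on small shapes; it makes no attempt to model tiling, TMA loads, or the epilogue warp.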
Raises:
Error if kernel launch fails, constraints are violated, or source tensor shape doesn't match output shape.