Mojo function
conv2d_fprop_with_residual
conv2d_fprop_with_residual[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, has_residual: Bool = False](output: TileTensor[out_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], activation: TileTensor[act_type, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size], filter: TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], source: TileTensor[out_type, address_space=source.address_space, linear_idx_type=source.linear_idx_type, element_size=source.element_size], beta: Float32, problem: Conv2dProblemShape, ctx: DeviceContext)
Launch Conv2D fprop with residual add.
Computes D = Conv(A,B) + beta*C. This function extends conv2d_fprop with residual add support. The epilogue load warp pre-fetches source tensor C via TMA, overlapping with MMA computation for better performance.
The residual add is applied after the optional epilogue lambda: D = lambda(Conv(A,B)) + beta * C
This supports common patterns like:
- Skip connections: D = Conv(A,B) + C (beta=1.0)
- Residual scaling: D = Conv(A,B) + 0.5*C (beta=0.5)
- Fused residual+activation: D = ReLU(Conv(A,B)) + C
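The ordering above (epilogue lambda first, then the scaled residual) can be sketched numerically. This is a plain-Python illustration of the math, not the Mojo API; the names `conv_out`, `source_c`, and `relu` are hypothetical stand-ins:

```python
# Hypothetical illustration (plain Python, not the Mojo API) of the epilogue
# ordering: D = elementwise_lambda(Conv(A, B)) + beta * C.
conv_out = [-2.0, 3.0, -1.0, 4.0]   # stand-in for Conv(A, B) values
source_c = [1.0, 1.0, 1.0, 1.0]     # stand-in for the residual tensor C
beta = 0.5

def relu(x):
    # Example epilogue lambda: applied BEFORE the residual add.
    return max(x, 0.0)

d = [relu(v) + beta * c for v, c in zip(conv_out, source_c)]
print(d)  # [0.5, 3.5, 0.5, 4.5]
```

Note that `relu` clamps the negative conv outputs to zero before the `beta * C` term is added, matching the fused residual+activation pattern above.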
Note: The epilogue load warp (warp ID 7) handles C loading when residual is enabled. When has_residual is False or beta is 0, this warp exits early and the kernel behaves identically to conv2d_fprop.
Parameters:
- act_type (DType): Data type of the input activation tensor.
- filter_type (DType): Data type of the filter weights tensor.
- out_type (DType): Data type of the output tensor.
- config (Conv2dConfig[act_type, filter_type, out_type]): Kernel configuration (tile sizes, pipeline stages, etc.).
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional void epilogue lambda applied after output write. Signature: def(IndexList[2], SIMD) -> None.
- elementwise_compute_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Optional element-wise lambda function for epilogue fusion (bias add, activation). Applied before the residual.
- register_based_epilogue (Bool): If True, apply the lambda in registers (faster).
- has_residual (Bool): If True, apply the residual add. If False, source is ignored.
Args:
- output (TileTensor[out_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor [N, H_out, W_out, C_out] in NHWC layout (D).
- activation (TileTensor[act_type, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size]): Input activation [N, H, W, C] in NHWC layout (A).
- filter (TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size]): Filter weights [K, R, S, C] in KRSC layout (B).
- source (TileTensor[out_type, address_space=source.address_space, linear_idx_type=source.linear_idx_type, element_size=source.element_size]): Source tensor [N, H_out, W_out, C_out] for the residual (C).
- beta (Float32): Residual scale factor. If 0.0, no residual is applied.
- problem (Conv2dProblemShape): Convolution problem shape specification.
- ctx (DeviceContext): Device context for kernel launch.
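Since output and source must share the [N, H_out, W_out, C_out] shape, the spatial output dimensions follow the standard cross-correlation formula. The helper below is a hypothetical sketch of that arithmetic (symmetric padding, optional dilation); it is not part of this API, and the actual Conv2dProblemShape fields are not shown here:

```python
def conv2d_output_hw(h, w, r, s, pad_h, pad_w, stride_h, stride_w,
                     dil_h=1, dil_w=1):
    # Standard conv output-size formula with symmetric padding and dilation:
    # out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    h_out = (h + 2 * pad_h - dil_h * (r - 1) - 1) // stride_h + 1
    w_out = (w + 2 * pad_w - dil_w * (s - 1) - 1) // stride_w + 1
    return h_out, w_out

# 224x224 input, 3x3 filter, pad 1, stride 1 -> same spatial size
print(conv2d_output_hw(224, 224, 3, 3, 1, 1, 1, 1))  # (224, 224)
```

A shape mismatch between the source tensor and this computed output shape is one of the conditions that raises an error at launch.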
Raises:
An error if the kernel launch fails, a constraint is violated, or the source tensor's shape doesn't match the output shape.