Mojo function
conv2d_fprop
conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], problem: Conv2dProblemShape, ctx: DeviceContext)
Launch Conv2D forward propagation with 4D NHWC API and implicit im2col.
This function provides a 4D tensor API for conv2d forward propagation using hardware TMA im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs coordinate transformation on-the-fly, eliminating the need for explicit im2col buffers.
The convolution is implemented as implicit GEMM:
- Activation matrix A[M, K] where M = batch * H_out * W_out, K = C * R * S
- Filter matrix B[K, N] where N = out_channels (transposed)
- Output matrix C[M, N]
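The implicit GEMM view above can be checked on a tiny problem. The following is a minimal Python sketch, not the Mojo kernel: it materializes A and B explicitly (which the real kernel avoids), assumes stride 1 with no padding or dilation, and assumes K is ordered as (filter_r, filter_s, channel) with channel varying fastest, to match the KRSC filter layout. All function names are hypothetical.

```python
def shape4(t):
    """Return the four extents of a nested-list 4D tensor."""
    return len(t), len(t[0]), len(t[0][0]), len(t[0][0][0])


def conv2d_direct(act, filt):
    """Direct NHWC x KRSC convolution; returns rows ordered as (n, h_out, w_out)."""
    nb, h, w, c = shape4(act)
    n_out, r, s, _ = shape4(filt)
    h_out, w_out = h - r + 1, w - s + 1  # stride 1, no padding (assumption)
    rows = []
    for n in range(nb):
        for p in range(h_out):
            for q in range(w_out):
                rows.append([
                    sum(act[n][p + fr][q + fs][ch] * filt[j][fr][fs][ch]
                        for fr in range(r) for fs in range(s) for ch in range(c))
                    for j in range(n_out)
                ])
    return rows


def conv2d_implicit_gemm(act, filt):
    """Same convolution expressed as C[M, N] = A[M, K] @ B[K, N]."""
    nb, h, w, c = shape4(act)
    n_out, r, s, _ = shape4(filt)
    h_out, w_out = h - r + 1, w - s + 1
    M, N, K = nb * h_out * w_out, n_out, c * r * s

    def split_k(k):
        # Assumed K ordering: channel fastest, then filter_s, then filter_r.
        rs, ch = divmod(k, c)
        fr, fs = divmod(rs, s)
        return fr, fs, ch

    taps = [split_k(k) for k in range(K)]
    A = []
    for m in range(M):
        n, rem = divmod(m, h_out * w_out)
        p, q = divmod(rem, w_out)
        A.append([act[n][p + fr][q + fs][ch] for fr, fs, ch in taps])
    # B is the KRSC filter viewed as [K, N], i.e. transposed.
    B = [[filt[j][fr][fs][ch] for j in range(N)] for fr, fs, ch in taps]
    C = [[sum(A[m][k] * B[k][j] for k in range(K)) for j in range(N)]
         for m in range(M)]
    return C, (M, N, K)


# Tiny example: batch=1, H=W=3, C=2, two 2x2 filters.
act = [[[[(y * 3 + x) * 2 + ch for ch in range(2)]
         for x in range(3)] for y in range(3)]]
filt = [[[[j + fr - fs + ch for ch in range(2)]
          for fs in range(2)] for fr in range(2)] for j in range(2)]

gemm_out, gemm_dims = conv2d_implicit_gemm(act, filt)
direct_out = conv2d_direct(act, filt)
```

Here `gemm_dims` is `(M, N, K)` and `gemm_out` matches `direct_out` row for row. In the actual kernel the rows of A are never stored; according to the description above, the TMA im2col descriptor produces them during the load.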
The TMA im2col descriptor handles the linear K iteration by decomposing k_coord into (channel, filter_r, filter_s) using the corner parameters:
- lower_corner defines the starting filter offset (negative for padding)
- upper_corner defines the ending filter offset
- channels_per_pixel is the number of input channels (C)
- pixels_per_column is the output spatial tile size (BM)
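The k_coord decomposition can be modeled in a few lines. This Python sketch is an illustration under stated assumptions, not the TMA hardware behavior: it assumes channel is the fastest-varying component of k_coord, uses the same stride, padding, and dilation for both spatial axes, and treats out-of-range input pixels as zero padding. The helper names are hypothetical.

```python
def decompose_k(k_coord, channels, filter_s):
    """Split a linear K index into (channel, filter_r, filter_s)."""
    rs, channel = divmod(k_coord, channels)
    filter_r, filter_s_idx = divmod(rs, filter_s)
    return channel, filter_r, filter_s_idx


def input_pixel(h_out, w_out, filter_r, filter_s_idx,
                stride=1, pad=0, dilation=1):
    """Map an output pixel and a filter tap to the input pixel it reads."""
    h_in = h_out * stride - pad + filter_r * dilation
    w_in = w_out * stride - pad + filter_s_idx * dilation
    return h_in, w_in


def in_bounds(h_in, w_in, height, width):
    """True if the input pixel exists; otherwise it counts as padding."""
    return 0 <= h_in < height and 0 <= w_in < width


# Example: C=4 channels, 3x3 filter, so K = 4 * 3 * 3 = 36.
taps = [decompose_k(k, channels=4, filter_s=3) for k in range(36)]

# With pad=1, the first tap of output pixel (0, 0) starts at offset -1,
# which is the "negative for padding" case described for lower_corner.
corner = input_pixel(0, 0, 0, 0, pad=1)
```

Every k in the range maps to a distinct (channel, filter_r, filter_s) triple, so iterating k linearly visits each filter tap and channel exactly once.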
Parameters:
- act_type (DType): Data type of the input activation tensor.
- filter_type (DType): Data type of the filter weights tensor.
- out_type (DType): Data type of the output tensor.
- config (Conv2dConfig): Kernel configuration (tile sizes, pipeline stages, etc.).
- elementwise_compute_lambda_fn (Optional): Optional element-wise lambda function for epilogue fusion (bias add, activation, residual connection). Signature: fn(coords: IndexList[2], val: SIMD) -> SIMD.
- register_based_epilogue (Bool): If True, apply lambda in registers (faster). If False, apply lambda after SMEM write (more flexible).
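To illustrate what an epilogue lambda of this shape does, here is a scalar Python model of a fused bias add plus ReLU. It is only a sketch: it assumes coords is the (row, column) position in the output matrix C[M, N], processes one value at a time instead of a SIMD vector, and uses hypothetical helper names that are not part of the Mojo API.

```python
def make_bias_relu(bias):
    """Build an epilogue that adds a per-output-channel bias, then applies ReLU."""
    def epilogue(coords, val):
        _, n = coords  # n indexes the output channel (column of C)
        return max(val + bias[n], 0.0)
    return epilogue


def apply_epilogue(c_matrix, fn):
    """Apply fn to every element of C, passing its (m, n) coordinates."""
    return [[fn((m, n), v) for n, v in enumerate(row)]
            for m, row in enumerate(c_matrix)]


fused = apply_epilogue([[1.0, -5.0]], make_bias_relu([0.5, 1.0]))
```

Because the lambda receives coordinates as well as the value, it can look up per-channel data such as a bias vector or a residual tensor without a separate kernel pass.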
Args:
- output (NDBuffer): Output tensor [N, H_out, W_out, C_out] in NHWC layout.
- activation (NDBuffer): Input activation [N, H, W, C] in NHWC layout.
- filter (NDBuffer): Filter weights [K, R, S, C] in KRSC layout.
- problem (Conv2dProblemShape): Convolution problem shape specification.
- ctx (DeviceContext): Device context for kernel launch.
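The three tensor shapes must agree with the problem description. The Python sketch below derives the expected shapes using the standard convolution output-size formula. `ProblemShape` is a hypothetical stand-in whose field names are assumptions; the real fields of Conv2dProblemShape are not documented here. Note that K in the KRSC filter layout is the number of output channels, not the GEMM K dimension.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ProblemShape:
    """Hypothetical stand-in for a conv2d problem description."""
    batch: int
    in_h: int
    in_w: int
    in_c: int
    out_c: int
    filter_r: int
    filter_s: int
    pad: int = 0
    stride: int = 1
    dilation: int = 1


def out_extent(in_extent, filt, pad, stride, dilation):
    """Standard convolution output size along one spatial axis."""
    return (in_extent + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def expected_shapes(p):
    """Return (output NHWC, activation NHWC, filter KRSC) shapes."""
    h_out = out_extent(p.in_h, p.filter_r, p.pad, p.stride, p.dilation)
    w_out = out_extent(p.in_w, p.filter_s, p.pad, p.stride, p.dilation)
    return (
        (p.batch, h_out, w_out, p.out_c),
        (p.batch, p.in_h, p.in_w, p.in_c),
        (p.out_c, p.filter_r, p.filter_s, p.in_c),
    )


same_pad = expected_shapes(ProblemShape(2, 8, 8, 16, 32, 3, 3, pad=1))
strided = expected_shapes(ProblemShape(2, 8, 8, 16, 32, 3, 3, pad=1, stride=2))
```

A check like this before launch makes shape mismatches easier to diagnose than a kernel-side constraint failure.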
Raises:
Error if kernel launch fails or constraints are violated.