
Mojo function

conv2d_fprop

conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16[act_type, filter_type, out_type](), elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](output: NDBuffer[out_type, 4, origin], activation: NDBuffer[act_type, 4, origin], filter: NDBuffer[filter_type, 4, origin], problem: Conv2dProblemShape, ctx: DeviceContext)

Launch Conv2D forward propagation with a 4D NHWC API and implicit im2col.

This function provides a 4D tensor API for conv2d forward propagation using the hardware TMA im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs the coordinate transformation on the fly, eliminating the need for explicit im2col buffers.

The convolution is implemented as implicit GEMM:

  • Activation matrix A[M, K] where M = batch * H_out * W_out, K = C * R * S
  • Filter matrix B[K, N] where N = out_channels (the KRSC filter is stored as [N, K] and used transposed)
  • Output matrix C[M, N]
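
As a rough illustration of the mapping above, the following sketch derives the implicit-GEMM dimensions from a convolution shape. It is written in Python only so the arithmetic can be run standalone; it is not the Mojo API. The helper names `conv_out_size` and `implicit_gemm_dims` are hypothetical, and the standard output-size formula with symmetric padding is an assumption, since this page does not state how Conv2dProblemShape derives H_out and W_out.

```python
def conv_out_size(in_size, filter_size, pad, stride, dilation):
    """Standard conv output-size formula (assumes symmetric padding)."""
    effective_filter = dilation * (filter_size - 1) + 1
    return (in_size + 2 * pad - effective_filter) // stride + 1


def implicit_gemm_dims(batch, h, w, c, out_channels, r, s,
                       pad=0, stride=1, dilation=1):
    """Return (M, N, K) for the implicit GEMM described above."""
    h_out = conv_out_size(h, r, pad, stride, dilation)
    w_out = conv_out_size(w, s, pad, stride, dilation)
    m = batch * h_out * w_out   # rows of A and C
    n = out_channels            # columns of B and C
    k = c * r * s               # reduction dimension
    return m, n, k


# 3x3 filter, padding 1, stride 1 keeps the 8x8 spatial size.
print(implicit_gemm_dims(batch=2, h=8, w=8, c=16, out_channels=32,
                         r=3, s=3, pad=1))  # (128, 32, 144)
```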

The TMA im2col descriptor handles the linear K iteration by decomposing k_coord into (channel, filter_r, filter_s) using the corner parameters:

  • lower_corner defines the starting filter offset (negative for padding)
  • upper_corner defines the ending filter offset
  • channels_per_pixel is the number of input channels (C)
  • pixels_per_column is the output spatial tile size (BM)
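
The decomposition of k_coord can be sketched in a few lines. This Python sketch is illustrative only and does not reproduce the TMA descriptor: it assumes that the channel is the fastest-varying component of K (consistent with a KRSC filter), that the starting offset equals the negated padding, and that out-of-range pixels are treated as zero. The function names `decompose_k` and `input_pixel` are hypothetical.

```python
def decompose_k(k_coord, channels, filter_s):
    """Split a linear K index into (channel, filter_r, filter_s).

    Assumption: K is ordered as (r, s, c) with c fastest-varying.
    """
    channel = k_coord % channels
    rs = k_coord // channels
    return channel, rs // filter_s, rs % filter_s


def input_pixel(h_out, w_out, fr, fs, pad, stride=1, dilation=1):
    """Map an output pixel and filter tap to input (h, w).

    Assumption: the starting filter offset is -pad, so taps can land at
    negative coordinates, which correspond to the padded (zero) region.
    """
    lower_corner = -pad
    h = h_out * stride + lower_corner + fr * dilation
    w = w_out * stride + lower_corner + fs * dilation
    return h, w


c, fr, fs = decompose_k(16 * 4 + 5, channels=16, filter_s=3)
print(c, fr, fs)                          # 5 1 1
print(input_pixel(0, 0, 0, 0, pad=1))     # (-1, -1): inside the padding
```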

Parameters:

  • act_type (DType): Data type of the input activation tensor.
  • filter_type (DType): Data type of the filter weights tensor.
  • out_type (DType): Data type of the output tensor.
  • config (Conv2dConfig): Kernel configuration (tile sizes, pipeline stages, etc.).
  • elementwise_compute_lambda_fn (Optional): Optional element-wise lambda function for epilogue fusion (bias add, activation, residual connection). Signature: fn(coords: IndexList[2], val: SIMD) -> SIMD.
  • register_based_epilogue (Bool): If True, apply lambda in registers (faster). If False, apply lambda after SMEM write (more flexible).
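
To make the epilogue lambda concrete, here is a conceptual Python sketch of a fused bias-add plus ReLU. It is not the Mojo kernel: the real lambda receives SIMD vectors, while this sketch uses scalars. It also assumes the two coordinates are (row, column) in the GEMM output C[M, N], so the column selects the output channel. The names `apply_epilogue` and `bias_relu` are hypothetical.

```python
def apply_epilogue(tile, row0, col0, fn):
    """Apply fn((row, col), value) to each element of a 2D output tile."""
    return [
        [fn((row0 + i, col0 + j), val) for j, val in enumerate(row)]
        for i, row in enumerate(tile)
    ]


bias = [0.5, -1.0, 2.0]  # one bias value per output channel (illustrative)


def bias_relu(coords, val):
    """Bias add followed by ReLU; coords[1] is assumed to be the channel."""
    _, col = coords
    return max(val + bias[col], 0.0)


tile = [[1.0, 0.5, -3.0],
        [-1.0, 2.0, 0.0]]
result = apply_epilogue(tile, 0, 0, bias_relu)
print(result)  # [[1.5, 0.0, 0.0], [0.0, 1.0, 2.0]]
```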

Args:

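  • output (NDBuffer): Output tensor [N, H_out, W_out, C_out] in NHWC layout, where N is the batch size (not the GEMM N dimension).
  • activation (NDBuffer): Input activation [N, H, W, C] in NHWC layout.
  • filter (NDBuffer): Filter weights [K, R, S, C] in KRSC layout, where K is the number of output channels (not the GEMM K dimension).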
  • problem (Conv2dProblemShape): Convolution problem shape specification.
  • ctx (DeviceContext): Device context for kernel launch.
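
The NHWC and KRSC layouts named above both place the channel dimension innermost. The Python sketch below shows the corresponding flat offsets, assuming densely packed row-major storage; an actual NDBuffer may carry its own strides, and the helper names are hypothetical.

```python
def nhwc_offset(n, h, w, c, H, W, C):
    """Flat offset of element (n, h, w, c) in a dense NHWC tensor."""
    return ((n * H + h) * W + w) * C + c


def krsc_offset(k, r, s, c, R, S, C):
    """Flat offset of element (k, r, s, c) in a dense KRSC filter."""
    return ((k * R + r) * S + s) * C + c


# Adjacent channels are adjacent in memory; the next pixel is C elements away.
print(nhwc_offset(0, 0, 0, 1, H=4, W=4, C=8))  # 1
print(nhwc_offset(0, 0, 1, 0, H=4, W=4, C=8))  # 8
# Each output channel of the filter spans R * S * C contiguous elements.
print(krsc_offset(1, 0, 0, 0, R=3, S=3, C=8))  # 72
```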

Raises:

Error if kernel launch fails or constraints are violated.
