Mojo function

conv2d_fprop

conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True](output: TileTensor[out_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], activation: TileTensor[act_type, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size], filter: TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], problem: Conv2dProblemShape, ctx: DeviceContext)

Launch Conv2D forward propagation with a 4D NHWC API and implicit im2col.

This function provides a 4D tensor API for conv2d forward propagation using hardware TMA im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs coordinate transformation on-the-fly, eliminating the need for explicit im2col buffers.

The convolution is implemented as implicit GEMM:

  • Activation matrix A[M, K] where M = batch * H_out * W_out, K = C * R * S
  • Filter matrix B[K, N] where N = out_channels (transposed)
  • Output matrix C[M, N]
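
As a rough illustration of the mapping above, the following plain-Python sketch computes the implicit GEMM dimensions from a convolution problem. It is not part of the Mojo API: the function names and the scalar `pad`, `stride`, and `dilation` arguments are hypothetical, and the output-size formula is the standard one for convolution, assumed here to match what `Conv2dProblemShape` encodes.

```python
def conv_out_size(in_size, filter_size, pad, stride, dilation):
    """Standard convolution output size for one spatial dimension."""
    # Extent covered by the dilated filter along this dimension.
    effective = dilation * (filter_size - 1) + 1
    return (in_size + 2 * pad - effective) // stride + 1


def implicit_gemm_shape(batch, h, w, c, out_channels, r, s,
                        pad=0, stride=1, dilation=1):
    """Return (M, N, K) of the GEMM that the convolution maps to.

    Hypothetical helper: symmetric padding and equal stride/dilation
    in both spatial dimensions are assumed for brevity.
    """
    h_out = conv_out_size(h, r, pad, stride, dilation)
    w_out = conv_out_size(w, s, pad, stride, dilation)
    m = batch * h_out * w_out   # one GEMM row per output pixel
    k = c * r * s               # reduction over channels and filter taps
    n = out_channels            # one GEMM column per output channel
    return m, n, k


# Example: batch 2, 56x56 input, 64 -> 128 channels, 3x3 filter, pad 1.
print(implicit_gemm_shape(2, 56, 56, 64, 128, 3, 3, pad=1))
```

With these inputs the spatial size is preserved (56x56), so M = 2 * 56 * 56, K = 64 * 3 * 3, and N = 128.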

The TMA im2col descriptor handles the linear K iteration by decomposing k_coord into (channel, filter_r, filter_s) using the following descriptor parameters:

  • lower_corner defines the starting filter offset (negative for padding)
  • upper_corner defines the ending filter offset
  • channels_per_pixel is the number of input channels (C)
  • pixels_per_column is the output spatial tile size (BM)
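
The decomposition described above can be modeled in a few lines of plain Python. This is only a sketch under stated assumptions, not the hardware's or the library's actual implementation: it assumes the channel index varies fastest within K (natural for NHWC, but not confirmed by this page), that `lower_corner` equals the negated padding, and the helper names `decompose_k` and `input_coord` are hypothetical.

```python
def decompose_k(k_coord, channels_per_pixel, filter_s_size):
    """Split a linear K index into (channel, filter_r, filter_s).

    Assumed ordering: k = (r * S + s) * C + c, i.e. channel fastest.
    """
    channel = k_coord % channels_per_pixel
    tap = k_coord // channels_per_pixel      # linear filter-tap index
    filter_r = tap // filter_s_size
    filter_s = tap % filter_s_size
    return channel, filter_r, filter_s


def input_coord(out_pos, filter_pos, stride, dilation, lower_corner):
    """Input coordinate read for one output position and one filter tap.

    lower_corner is the starting filter offset; with padding p it is
    assumed to be -p, so results below 0 fall in the padded region.
    """
    return out_pos * stride + lower_corner + filter_pos * dilation


# Example: C = 64, 3x3 filter, padding 1, stride 1, dilation 1.
c, r, s = decompose_k(200, channels_per_pixel=64, filter_s_size=3)
h_in = input_coord(out_pos=0, filter_pos=r, stride=1, dilation=1, lower_corner=-1)
print((c, r, s), h_in)
```

Coordinates that fall outside the input extent correspond to padding; in the scheme described above the descriptor is what resolves them, so no explicit im2col buffer needs to be materialized.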

Parameters:

Args:

Raises:

Error if kernel launch fails or constraints are violated.