
Mojo function

conv2d_fprop

conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(), elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](output: TileTensor[out_type, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], activation: TileTensor[act_type, activation.LayoutType, activation.origin, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size], filter: TileTensor[filter_type, filter.LayoutType, filter.origin, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], problem: Conv2dProblemShape, ctx: DeviceContext)

Launch Conv2D forward propagation with a 4D NHWC API and implicit im2col.

This function provides a 4D tensor API for Conv2D forward propagation that uses the hardware TMA (Tensor Memory Accelerator) im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and transforms coordinates on the fly, so no explicit im2col buffer is needed.
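The memory that implicit im2col saves is easy to quantify. An explicit im2col matrix has one row per output pixel (N·H_out·W_out rows) and one column per (channel, filter_r, filter_s) triple (C·R·S columns), so for a stride-1, same-padding convolution it is exactly R·S times larger than the activation. The plain Python arithmetic below is an illustration only: the helper name `im2col_cost` and the example shape are hypothetical, not taken from this API.

```python
def im2col_cost(n, h_out, w_out, c, r, s, bytes_per_elem=2):
    """Size of an explicit im2col matrix A[M, K] (the buffer TMA im2col avoids).

    Hypothetical helper for illustration; bytes_per_elem=2 assumes bf16.
    """
    m = n * h_out * w_out  # one row per output pixel
    k = c * r * s          # one column per (channel, filter_r, filter_s)
    return m, k, m * k * bytes_per_elem


# 3x3 conv, stride 1, padding 1 on a 1x56x56x64 bf16 activation,
# so H_out = W_out = 56 (example shape chosen for illustration).
m, k, nbytes = im2col_cost(n=1, h_out=56, w_out=56, c=64, r=3, s=3)
act_bytes = 1 * 56 * 56 * 64 * 2
print(m, k, nbytes, nbytes // act_bytes)  # 3136 576 3612672 9
```

For this shape, an explicit im2col buffer would occupy about 3.6 MB against a 0.4 MB activation, a 9× (R·S) expansion that the TMA path never materializes.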

The convolution is implemented as implicit GEMM:

  • Activation matrix A[M, K], where M = batch × H_out × W_out and K = C × R × S
  • Filter matrix B[K, N], where N = out_channels (the KRSC filter viewed as [N, K] and transposed)
  • Output matrix C[M, N]
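The GEMM mapping above can be checked with a small pure-Python reference model. This is not the Mojo kernel and not this library's API: it computes the same NHWC × KRSC convolution twice, once directly and once as C = A·B with A gathered on the fly (no im2col buffer), and compares the results. The function names are hypothetical, and the ordering of K (channel fastest, then filter_s, then filter_r) is an assumption chosen to match the KRSC filter layout.

```python
import random


def out_extent(size, filt, pad, stride, dil):
    # Output extent of a convolution along one spatial axis.
    return (size + 2 * pad - dil * (filt - 1) - 1) // stride + 1


def dims(act, flt):
    # act is NHWC nested lists; flt is KRSC, where K = output channels here.
    return (len(act), len(act[0]), len(act[0][0]), len(act[0][0][0]),
            len(flt), len(flt[0]), len(flt[0][0]))


def conv2d_direct(act, flt, pad, stride, dil):
    """Reference: direct NHWC x KRSC convolution with zero padding."""
    batch, h, w, c_in, c_out, r_f, s_f = dims(act, flt)
    h_out = out_extent(h, r_f, pad, stride, dil)
    w_out = out_extent(w, s_f, pad, stride, dil)
    out = [[[[0] * c_out for _ in range(w_out)] for _ in range(h_out)]
           for _ in range(batch)]
    for b in range(batch):
        for p in range(h_out):
            for q in range(w_out):
                for k in range(c_out):
                    for r in range(r_f):
                        for s in range(s_f):
                            y = p * stride - pad + r * dil
                            x = q * stride - pad + s * dil
                            if 0 <= y < h and 0 <= x < w:
                                for c in range(c_in):
                                    out[b][p][q][k] += act[b][y][x][c] * flt[k][r][s][c]
    return out


def conv2d_implicit_gemm(act, flt, pad, stride, dil):
    """Same convolution as C[M, N] = A[M, K] @ B[K, N], where A is
    gathered on the fly instead of being built as an im2col buffer."""
    batch, h, w, c_in, c_out, r_f, s_f = dims(act, flt)
    h_out = out_extent(h, r_f, pad, stride, dil)
    w_out = out_extent(w, s_f, pad, stride, dil)
    M, K, N = batch * h_out * w_out, c_in * r_f * s_f, c_out

    def split_k(kk):
        # Assumed K ordering: channel fastest, then s, then r (matches KRSC).
        rs, c = divmod(kk, c_in)
        r, s = divmod(rs, s_f)
        return r, s, c

    def a_elem(m, kk):
        b, rem = divmod(m, h_out * w_out)  # m -> (batch, p, q)
        p, q = divmod(rem, w_out)
        r, s, c = split_k(kk)
        y = p * stride - pad + r * dil
        x = q * stride - pad + s * dil
        return act[b][y][x][c] if 0 <= y < h and 0 <= x < w else 0  # zero padding

    def b_elem(kk, n):
        r, s, c = split_k(kk)
        return flt[n][r][s][c]  # B = KRSC filter viewed as [N, K], transposed

    C = [[sum(a_elem(m, kk) * b_elem(kk, n) for kk in range(K)) for n in range(N)]
         for m in range(M)]
    # Row m of C holds every output channel of output pixel (b, p, q).
    return [[[C[(b * h_out + p) * w_out + q] for q in range(w_out)]
             for p in range(h_out)] for b in range(batch)]


random.seed(0)
act = [[[[random.randint(-3, 3) for _ in range(3)] for _ in range(5)]
        for _ in range(6)] for _ in range(2)]   # N=2, H=6, W=5, C=3
flt = [[[[random.randint(-3, 3) for _ in range(3)] for _ in range(3)]
        for _ in range(2)] for _ in range(4)]   # K=4, R=2, S=3, C=3
for pad, stride, dil in [(0, 1, 1), (1, 1, 1), (1, 2, 1), (2, 1, 2)]:
    assert conv2d_implicit_gemm(act, flt, pad, stride, dil) == \
        conv2d_direct(act, flt, pad, stride, dil)
print("implicit GEMM matches direct convolution")
```

Integer inputs keep the comparison exact. A real kernel tiles M, N, and K and feeds A through TMA loads, but each element it reads follows the same index arithmetic as `a_elem` under the K ordering assumed here.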

The TMA im2col descriptor handles the linear K iteration by decomposing k_coord into (channel, filter_r, filter_s) using the corner parameters:

  • lower_corner defines the starting filter offset (negative for padding)
  • upper_corner defines the ending filter offset
  • channels_per_pixel is the number of input channels (C)
  • pixels_per_column is the output spatial tile size (BM)
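The decomposition described above can be sketched in plain Python under two stated assumptions: channels vary fastest along K (matching the KRSC filter layout), and the K tile size BK divides C, so each K tile falls within a single filter tap. The helper names `k_tile_to_tap` and `tile_input_pixels` are hypothetical, and the sketch does not model the TMA descriptor itself. The variable `lower` only mirrors the role the text gives `lower_corner` (a starting offset that is negative when there is padding).

```python
def k_tile_to_tap(k_tile, bk, c_in, s_f):
    """Map a K-tile index to (channel_start, filter_r, filter_s).

    Assumes channels vary fastest along K (KRSC order) and that bk
    divides c_in, so one tile never spans two filter taps.
    """
    rs, c0 = divmod(k_tile * bk, c_in)
    r, s = divmod(rs, s_f)
    return c0, r, s


def tile_input_pixels(m0, bm, h_out, w_out, h, w, r, s, pad, stride=1, dil=1):
    """Input pixel (batch, y, x) read by each of the bm output pixels of an
    M-tile at filter tap (r, s); None marks a zero-filled padding position."""
    lower = -pad  # starting offset, negative when there is padding (illustrative)
    pixels = []
    for m in range(m0, m0 + bm):
        b, rem = divmod(m, h_out * w_out)  # linear output pixel -> (batch, p, q)
        p, q = divmod(rem, w_out)
        y = p * stride + lower + r * dil
        x = q * stride + lower + s * dil
        pixels.append((b, y, x) if 0 <= y < h and 0 <= x < w else None)
    return pixels


# C = 64, BK = 64, 3x3 filter: K = 576 splits into 9 tiles, one per tap.
print([k_tile_to_tap(t, 64, 64, 3) for t in (0, 4, 8)])
# [(0, 0, 0), (0, 1, 1), (0, 2, 2)]

# 4x4 input, 3x3 filter, pad 1: first BM = 4 output pixels at tap (1, 0).
print(tile_input_pixels(0, 4, 4, 4, 4, 4, r=1, s=0, pad=1))
# [None, (0, 0, 0), (0, 0, 1), (0, 0, 2)]
```

The first output pixel at tap (1, 0) reads column x = -1, which lies in the padding, so it contributes zeros. This is the case that a negative starting offset lets the hardware handle without an explicit padded copy of the input.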

Parameters:

  • act_type (DType): Data type of the input activation tensor.
  • filter_type (DType): Data type of the filter weights tensor.
  • out_type (DType): Data type of the output tensor.
  • config (Conv2dConfig): Kernel configuration (tile sizes, pipeline stages, etc.).
  • elementwise_lambda_fn (Optional): Optional epilogue lambda with no return value, applied after the output is written. Signature: def(IndexList[2], SIMD) -> None.
  • elementwise_compute_lambda_fn (Optional): Optional element-wise lambda function for epilogue fusion (bias add, activation, residual connection). Signature: def(coords: IndexList[2], val: SIMD) -> SIMD.
  • register_based_epilogue (Bool): If True, apply the lambda in registers (faster). If False, apply it after the SMEM write (more flexible).
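The two lambda hooks can be modeled in plain Python to show how they differ: a compute lambda replaces each value (for example, a fused bias add followed by ReLU), while a void lambda only observes it. This is a hypothetical sketch, not the Mojo kernel. It assumes that the `IndexList[2]` coordinates are (m, n) positions in the output matrix C[M, N], that both hooks may be supplied together, and that the compute lambda runs first. It ignores `register_based_epilogue`, which per its description changes where the lambda runs, not what it computes.

```python
def apply_epilogue(c_tile, m0, n0, compute_fn=None, void_fn=None):
    """Model of the epilogue over a tile of the output matrix C[M, N].

    compute_fn(coords, val) -> val : replaces each value (fusion).
    void_fn(coords, val) -> None   : observes the final value.
    coords = (m, n), i.e. (output pixel row, output channel); this
    convention and the call order are assumptions of the sketch.
    """
    result = []
    for i, row in enumerate(c_tile):
        new_row = []
        for j, val in enumerate(row):
            coords = (m0 + i, n0 + j)
            if compute_fn is not None:
                val = compute_fn(coords, val)
            if void_fn is not None:
                void_fn(coords, val)
            new_row.append(val)
        result.append(new_row)
    return result


bias = [1, 2]  # one bias value per output channel (column n of C)


def bias_relu(coords, val):
    # Fused bias add + ReLU, the kind of op a compute lambda is meant for.
    return max(val + bias[coords[1]], 0)


seen = {}


def record(coords, val):
    # Void-style hook: observes values without changing them.
    seen[coords] = val


out = apply_epilogue([[-3, 1], [2, -1]], m0=0, n0=0,
                     compute_fn=bias_relu, void_fn=record)
print(out)  # [[0, 3], [3, 1]]
```

Because the bias is indexed by the column coordinate n, it lines up with out_channels, consistent with N = out_channels in the GEMM view.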

Args:

  • output (TileTensor): Output tensor [N, H_out, W_out, C_out] in NHWC layout.
  • activation (TileTensor): Input activation [N, H, W, C] in NHWC layout.
  • filter (TileTensor): Filter weights [K, R, S, C] in KRSC layout, where K here is the number of output channels (C_out), not the GEMM K dimension.
  • problem (Conv2dProblemShape): Convolution problem shape specification.
  • ctx (DeviceContext): Device context for kernel launch.
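The shapes expected for `output`, `activation`, and `filter` follow from the problem shape through the standard convolution output-size formula. The plain Python sketch below uses a hypothetical `ProblemShape` dataclass in place of `Conv2dProblemShape`, whose actual fields are not documented here. It also assumes symmetric padding and equal stride and dilation on both spatial axes. Its `ValueError` only illustrates the kind of constraint violation for which the real function raises an Error.

```python
from dataclasses import dataclass


@dataclass
class ProblemShape:
    """Hypothetical stand-in for Conv2dProblemShape (field names assumed)."""
    batch: int
    h: int
    w: int
    c_in: int
    c_out: int
    r: int
    s: int
    pad: int = 0       # symmetric padding on both spatial axes (assumed)
    stride: int = 1
    dilation: int = 1


def expected_shapes(p: ProblemShape) -> dict:
    """Tensor shapes implied by the Args (NHWC activation, KRSC filter, NHWC output)."""
    h_out = (p.h + 2 * p.pad - p.dilation * (p.r - 1) - 1) // p.stride + 1
    w_out = (p.w + 2 * p.pad - p.dilation * (p.s - 1) - 1) // p.stride + 1
    if h_out <= 0 or w_out <= 0:
        raise ValueError("filter does not fit in the padded input")
    return {
        "activation": (p.batch, p.h, p.w, p.c_in),
        "filter": (p.c_out, p.r, p.s, p.c_in),
        "output": (p.batch, h_out, w_out, p.c_out),
    }


# 7x7 stride-2 convolution on a batch of 224x224 RGB images.
print(expected_shapes(ProblemShape(batch=8, h=224, w=224, c_in=3,
                                   c_out=64, r=7, s=7, pad=3, stride=2)))
# {'activation': (8, 224, 224, 3), 'filter': (64, 7, 7, 3), 'output': (8, 112, 112, 64)}
```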

Raises:

Error if the kernel launch fails or a constraint is violated.
