IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

conv2d_fprop

def conv2d_fprop[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type] = Conv2dConfig.default_bf16(), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True](output: TileTensor[out_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], activation: TileTensor[act_type, address_space=activation.address_space, linear_idx_type=activation.linear_idx_type, element_size=activation.element_size], filter: TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], problem: Conv2dProblemShape, ctx: DeviceContext)

Launch Conv2D forward propagation with a 4D NHWC API and implicit im2col.

This function provides a 4D tensor API for Conv2D forward propagation that uses the hardware TMA im2col transformation. The TMA descriptor encodes the convolution geometry (padding, stride, dilation) and performs the coordinate transformation on the fly, eliminating the need for an explicit im2col buffer.
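To give a sense of what "no explicit im2col buffer" saves, the plain-Python sketch below computes the output extent from padding, stride, and dilation, and the size an explicit im2col matrix would need: (batch × H_out × W_out) rows by (C × R × S) columns. The helper names, the symmetric padding, and the example shape are hypothetical and chosen only for illustration; they are not part of the Mojo API.

```python
# Hypothetical helpers (not part of the Mojo API) for sizing the im2col buffer
# that implicit im2col avoids. Assumes symmetric padding on each spatial dim.


def conv_out_size(in_size: int, filt: int, pad: int, stride: int, dilation: int) -> int:
    """Output extent along one spatial dimension (standard convolution arithmetic)."""
    return (in_size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def explicit_im2col_bytes(batch, h, w, c, r, s, pad, stride, dilation, elem_bytes=2):
    """Bytes an explicit im2col matrix of shape (batch*H_out*W_out, C*R*S) would occupy."""
    h_out = conv_out_size(h, r, pad, stride, dilation)
    w_out = conv_out_size(w, s, pad, stride, dilation)
    m = batch * h_out * w_out  # GEMM M dimension
    k = c * r * s              # GEMM K dimension
    return m * k * elem_bytes


if __name__ == "__main__":
    # Hypothetical shape: batch 8, 56x56x64 NHWC input, 3x3 filter, pad 1, stride 1, bf16.
    print(explicit_im2col_bytes(8, 56, 56, 64, 3, 3, pad=1, stride=1, dilation=1))
```

For that hypothetical shape the buffer would hold 25,088 × 576 bf16 elements, or 28,901,376 bytes, which the TMA im2col path never materializes.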

The convolution is implemented as implicit GEMM:

  • Activation matrix A[M, K] where M = batch × H_out × W_out, K = C × R × S
  • Filter matrix B[K, N] where N = out_channels (transposed)
  • Output matrix C[M, N]
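The implicit-GEMM mapping can be checked against a direct convolution with a small plain-Python sketch. This is not the Mojo kernel; it only models the index arithmetic. It assumes a filter stored as [N_out][R][S][C], a K ordering with the channel index fastest, and equal padding, stride, and dilation on both spatial dimensions. Those layout choices are assumptions for illustration, not the kernel's documented contract.

```python
# Plain-Python sketch of conv2d fprop as an implicit GEMM over NHWC data.
# Assumptions (illustrative only, not the Mojo kernel's layout contract):
#   * activation x is indexed [batch][H][W][C]; filter f is [N_out][R][S][C]
#   * K is ordered (filter_r, filter_s, channel) with the channel index fastest
#   * padding, stride and dilation are equal on both spatial dimensions


def out_size(size, filt, pad, stride, dil):
    return (size + 2 * pad - dil * (filt - 1) - 1) // stride + 1


def conv2d_direct(x, f, pad, stride, dil):
    """Reference direct convolution; returns y[batch][H_out][W_out][N_out]."""
    NB, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    NO, R, S = len(f), len(f[0]), len(f[0][0])
    P, Q = out_size(H, R, pad, stride, dil), out_size(W, S, pad, stride, dil)
    y = [[[[0] * NO for _ in range(Q)] for _ in range(P)] for _ in range(NB)]
    for b in range(NB):
        for p in range(P):
            for q in range(Q):
                for o in range(NO):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            h = p * stride - pad + r * dil
                            v = q * stride - pad + s * dil
                            if 0 <= h < H and 0 <= v < W:
                                for c in range(C):
                                    acc += x[b][h][v][c] * f[o][r][s][c]
                    y[b][p][q][o] = acc
    return y


def conv2d_implicit_gemm(x, f, pad, stride, dil):
    """C[M, N] = A[M, K] @ B[K, N]; A's elements are gathered on the fly, never stored."""
    NB, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    NO, R, S = len(f), len(f[0]), len(f[0][0])
    P, Q = out_size(H, R, pad, stride, dil), out_size(W, S, pad, stride, dil)
    M, K = NB * P * Q, C * R * S

    def split_k(k):
        rs, c = divmod(k, C)  # k -> (filter_r, filter_s, channel), channel fastest
        r, s = divmod(rs, S)
        return r, s, c

    def a_elem(m, k):
        b, pq = divmod(m, P * Q)  # m -> (batch, out_row, out_col)
        p, q = divmod(pq, Q)
        r, s, c = split_k(k)
        h = p * stride - pad + r * dil
        v = q * stride - pad + s * dil
        return x[b][h][v][c] if 0 <= h < H and 0 <= v < W else 0  # implicit zero padding

    def b_elem(k, n):
        r, s, c = split_k(k)
        return f[n][r][s][c]  # B[k, n] reads the [N_out][R][S][C] filter transposed

    return [[sum(a_elem(m, k) * b_elem(k, n) for k in range(K)) for n in range(NO)]
            for m in range(M)]
```

Because row m = (b × H_out + p) × W_out + q of C holds output pixel (b, p, q), the row-major C[M, N] is already the NHWC output tensor, so no reshuffle is needed after the GEMM.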

The TMA im2col descriptor handles iteration along the linear K dimension by decomposing k_coord into (channel, filter_r, filter_s) using the corner parameters:

  • lower_corner defines the starting filter offset (negative for padding)
  • upper_corner defines the ending filter offset
  • channels_per_pixel is the number of input channels (C)
  • pixels_per_column is the output spatial tile size (BM)
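The bookkeeping of the K loop can be modeled in a few lines of plain Python. This is a hypothetical sketch, not what the TMA hardware or the Mojo kernel literally computes: it assumes K = C × R × S ordered with the channel index fastest, a lower corner of -pad on each spatial dimension, and a K-tile width BK that divides C. It does not model upper_corner. Under these assumptions, each tile would pair BM output pixels (pixels_per_column) with BK consecutive channels of a single (filter_r, filter_s) tap.

```python
# Hypothetical model of the K-loop bookkeeping described above. decompose_k and
# k_tiles are illustrative names, not part of the Mojo API. Assumes
# K = C * R * S ordered with the channel index fastest, lower_corner = -pad on
# each spatial dim, and a K-tile width BK that divides C.


def decompose_k(k_coord, C, S):
    """Split a linear K coordinate into (channel, filter_r, filter_s)."""
    rs, channel = divmod(k_coord, C)
    filter_r, filter_s = divmod(rs, S)
    return channel, filter_r, filter_s


def k_tiles(C, R, S, BK, pad_h, pad_w, dil_h=1, dil_w=1):
    """Yield (k_start, filter_r, filter_s, channel_start, h_offset, w_offset) per K tile.

    h_offset / w_offset give the input offset of this filter tap relative to
    the strided output position: lower_corner (-pad) + tap * dilation.
    """
    if C % BK != 0:
        raise ValueError("this sketch requires each K tile to stay within one filter tap")
    lower_h, lower_w = -pad_h, -pad_w  # negative lower corner encodes the padding
    for k_start in range(0, C * R * S, BK):
        c0, r, s = decompose_k(k_start, C, S)
        yield k_start, r, s, c0, lower_h + r * dil_h, lower_w + s * dil_w


if __name__ == "__main__":
    # Hypothetical shape: 64 channels, 3x3 filter, BK = 32, padding 1.
    for tile in k_tiles(C=64, R=3, S=3, BK=32, pad_h=1, pad_w=1):
        print(tile)
```

Keeping the channel index fastest means consecutive K coordinates within a tile read consecutive NHWC channels of the same input pixel, which is the contiguous access pattern a per-pixel TMA load favors.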

Raises:

Error if kernel launch fails or constraints are violated.