IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

conv2d_kernel_rdna

def conv2d_kernel_rdna[out_type: DType, in_type: DType, filter_type: DType, out_layout: TensorLayout, filter_nk_layout: TensorLayout, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[out_type](), BLOCK_K: Int = 32, BLOCK_M: Int = 128, BLOCK_N: Int = 128, WARPS_M: Int = 8, WARPS_N: Int = 2, WARP_TILE_M: Int = 1, WARP_TILE_N: Int = 4](output: TileTensor[out_type, out_layout, MutAnyOrigin], input_ptr: UnsafePointer[Scalar[in_type], ImmutAnyOrigin], filter_nk: TileTensor[filter_type, filter_nk_layout, MutAnyOrigin], M: Int, N: Int, K: Int, HW_out: Int, W_out: Int, H_in: Int, W_in: Int, C_in: Int, R: Int, S: Int, pad_h: Int, pad_w: Int)

Conv2D implicit GEMM kernel for RDNA 3+ GPUs.
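In an implicit GEMM, each output pixel becomes one GEMM row, each output channel one GEMM column, and the filter window (R × S × C_in) the reduction dimension. The sketch below shows how the scalar arguments M, N, K, HW_out and W_out could be derived from convolution shapes. It is a Python model, not the dispatch code. The helper name `implicit_gemm_dims` is hypothetical, and the H_out/W_out formulas assume symmetric padding together with the stride-1, dilation-1 constraint listed below.

```python
# Hypothetical helper: derive the scalar GEMM arguments of conv2d_kernel_rdna
# from NHWC input and filter shapes. Assumes stride = 1, dilation = 1 and
# symmetric padding (pad_h rows top and bottom, pad_w columns left and right).
from typing import NamedTuple


class GemmDims(NamedTuple):
    M: int       # batch * H_out * W_out: one GEMM row per output pixel
    N: int       # C_out: one GEMM column per output channel
    K: int       # R * S * C_in: reduction over the filter window
    HW_out: int  # H_out * W_out
    W_out: int


def implicit_gemm_dims(batch, H_in, W_in, C_in, R, S, C_out, pad_h, pad_w):
    # With stride 1 and dilation 1, each spatial dim shrinks by (filter - 1)
    # and grows by the padding applied on both sides.
    H_out = H_in + 2 * pad_h - R + 1
    W_out = W_in + 2 * pad_w - S + 1
    return GemmDims(
        M=batch * H_out * W_out,
        N=C_out,
        K=R * S * C_in,
        HW_out=H_out * W_out,
        W_out=W_out,
    )


# A 3x3 "same" convolution over a 2x56x56x64 input with 128 filters.
dims = implicit_gemm_dims(batch=2, H_in=56, W_in=56, C_in=64, R=3, S=3,
                          C_out=128, pad_h=1, pad_w=1)
print(dims)  # GemmDims(M=6272, N=128, K=576, HW_out=3136, W_out=56)
```

With the default BLOCK_K = 32, this example satisfies both divisibility constraints (576 and 64 are multiples of 32).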

Identical to the RDNA WMMA matmul kernel, except that the A-tile is loaded directly from the NHWC input, with im2col coordinates computed on the fly. This eliminates the intermediate im2col buffer.
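The scalar Python model below is a reference sketch, not the kernel's code. It shows how a single A element (row m, column k) can be fetched straight from a flat NHWC buffer instead of from a materialized im2col matrix. Assumptions: stride 1 and dilation 1, K ordered as (r, s, c_in) with c_in varying fastest (matching the [C_out, R*S*C_in] filter layout), and out-of-bounds taps reading as zero padding. The function name `im2col_a_element` is hypothetical.

```python
# Reference model (not the GPU kernel) of on-the-fly im2col addressing.
# Assumes stride = 1, dilation = 1, K ordered as (r, s, c_in) with c_in fastest,
# and zero padding for taps that fall outside the input.

def im2col_a_element(inp, m, k, HW_out, W_out, H_in, W_in, C_in, S, pad_h, pad_w):
    # Row m of A is one output pixel: split it into (batch, oh, ow).
    b, hw = divmod(m, HW_out)
    oh, ow = divmod(hw, W_out)
    # Column k of A is one filter tap and input channel: split it into (r, s, c).
    rs, c = divmod(k, C_in)
    r, s = divmod(rs, S)
    # Input coordinate read by this tap (stride 1, dilation 1).
    ih = oh + r - pad_h
    iw = ow + s - pad_w
    if not (0 <= ih < H_in and 0 <= iw < W_in):
        return 0.0  # padding region: nothing is loaded
    # Flat NHWC offset of input[b, ih, iw, c].
    return inp[((b * H_in + ih) * W_in + iw) * C_in + c]


# Tiny example: batch 1, 3x3x2 input, 3x3 filter, pad 1 -> 3x3 output (HW_out = 9).
inp = [float(i) for i in range(1 * 3 * 3 * 2)]
A = [[im2col_a_element(inp, m, k, HW_out=9, W_out=3, H_in=3, W_in=3,
                       C_in=2, S=3, pad_h=1, pad_w=1)
      for k in range(3 * 3 * 2)]
     for m in range(9)]
print(A[4] == inp)  # True: the center pixel's 3x3 window covers the whole input
```

Because no A matrix is stored, each kernel load pays for a few integer divisions and a bounds test instead of an extra global-memory round trip.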

The B-tile (filter) must be pre-transposed to an [N, K] = [C_out, R*S*C_in] layout so that it can be loaded with vectorized reads under transpose_b=True semantics.
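A minimal sketch of the pre-transpose, assuming the source filter is stored in RSCF order ([R, S, C_in, C_out]); the source does not state the original filter layout. Viewed as a row-major [K, N] matrix, RSCF already uses k = (r*S + s)*C_in + c, so the required [N, K] layout is a plain matrix transpose. After it, each output channel's K weights are contiguous. The function name is hypothetical.

```python
# Hypothetical host-side pre-transpose from an assumed RSCF filter to [N, K].
# RSCF flattened row-major is a [K, N] matrix with k = (r*S + s)*C_in + c,
# so producing [N, K] = [C_out, R*S*C_in] is an ordinary transpose.

def transpose_rscf_to_nk(filt_rscf, R, S, C_in, C_out):
    K = R * S * C_in
    out = [0] * (C_out * K)
    for k in range(K):
        for n in range(C_out):
            # Row n of the result holds all K weights of output channel n.
            out[n * K + k] = filt_rscf[k * C_out + n]
    return out


# R=1, S=2, C_in=2, C_out=3, so K=4. Input rows (k) are [0,1,2], [3,4,5], ...
nk = transpose_rscf_to_nk(list(range(12)), R=1, S=2, C_in=2, C_out=3)
print(nk)  # [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
```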

The [M, N] output is stored row-major and maps directly onto the NHWC output tensor, because M = batch × H_out × W_out and N = C_out.
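The check below illustrates why no output permutation is needed: splitting m into (batch, oh, ow) with HW_out and W_out (as the kernel's scalar arguments suggest) makes the row-major GEMM offset m*N + n equal to the NHWC offset of output[batch, oh, ow, n]. It is a Python sketch with hypothetical helper names, verified on one small shape.

```python
# Sketch: the row-major [M, N] GEMM offset equals the NHWC output offset.

def gemm_to_nhwc(m, n, HW_out, W_out):
    # Row m is an output pixel; column n is the output channel.
    b, hw = divmod(m, HW_out)
    oh, ow = divmod(hw, W_out)
    return b, oh, ow, n


def nhwc_offset(b, oh, ow, c, H_out, W_out, C_out):
    return ((b * H_out + oh) * W_out + ow) * C_out + c


batch, H_out, W_out, C_out = 2, 3, 4, 5
M, N = batch * H_out * W_out, C_out
all_match = all(
    m * N + n == nhwc_offset(*gemm_to_nhwc(m, n, H_out * W_out, W_out),
                             H_out, W_out, C_out)
    for m in range(M)
    for n in range(N)
)
print(all_match)  # True
```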

Constraints (enforced by the dispatch layer):

  • stride = (1, 1) and dilation = (1, 1)
  • K % BLOCK_K == 0 and C_in % BLOCK_K == 0
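The sketch below mirrors these constraints as a host-side predicate; the function names are hypothetical and are not the dispatch layer's API. It also illustrates one plausible motivation for the C_in condition, which is an inference rather than something the source states. When BLOCK_K divides C_in, every BLOCK_K-wide slice of K stays inside a single (r, s) filter tap. The im2col coordinates and padding test are then uniform across the slice, and its channels are contiguous in NHWC memory.

```python
# Hypothetical predicate mirroring the documented constraints, plus a helper
# that lists which (r, s) filter taps each BLOCK_K-wide K-tile touches.

def satisfies_rdna_conv_constraints(stride, dilation, K, C_in, BLOCK_K=32):
    return (stride == (1, 1) and dilation == (1, 1)
            and K % BLOCK_K == 0 and C_in % BLOCK_K == 0)


def taps_per_k_tile(R, S, C_in, BLOCK_K=32):
    K = R * S * C_in
    tiles = []
    for k0 in range(0, K, BLOCK_K):
        # k // C_in = r*S + s, so divmod(..., S) recovers the (r, s) tap.
        tiles.append({divmod(k // C_in, S)
                      for k in range(k0, min(k0 + BLOCK_K, K))})
    return tiles


print(all(len(t) == 1 for t in taps_per_k_tile(3, 3, 64)))  # True
print(taps_per_k_tile(3, 3, 48)[1])  # second tile straddles taps (0, 0) and (0, 1)
```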