
Mojo function

conv2d_kernel_rdna

conv2d_kernel_rdna[out_type: DType, in_type: DType, filter_type: DType, out_layout: TensorLayout, filter_nk_layout: TensorLayout, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[out_type](), BLOCK_K: Int = 32, BLOCK_M: Int = 128, BLOCK_N: Int = 128, WARPS_M: Int = 8, WARPS_N: Int = 2, WARP_TILE_M: Int = 1, WARP_TILE_N: Int = 4](output: TileTensor[out_type, out_layout, MutAnyOrigin], input_ptr: UnsafePointer[Scalar[in_type], ImmutAnyOrigin], filter_nk: TileTensor[filter_type, filter_nk_layout, MutAnyOrigin], M: Int, N: Int, K: Int, HW_out: Int, W_out: Int, H_in: Int, W_in: Int, C_in: Int, R: Int, S: Int, pad_h: Int, pad_w: Int)
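The default tile parameters in the signature fit together arithmetically. The sketch below works through that arithmetic. It assumes 16x16 WMMA fragments, 32-lane waves, that WARP_TILE_M and WARP_TILE_N count WMMA fragments per warp, and that one thread block covers one BLOCK_M x BLOCK_N output tile. None of these assumptions is stated on this page, and `tiling` and `tile_counts` are hypothetical helpers.

```python
# Tiling arithmetic for the default template parameters of conv2d_kernel_rdna.
# Assumptions (not stated on this page): WMMA fragments are 16x16, waves are
# 32 lanes wide, and WARP_TILE_M / WARP_TILE_N count fragments per warp.
WMMA_M = 16
WMMA_N = 16
WAVE_SIZE = 32


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def tiling(BLOCK_M=128, BLOCK_N=128, WARPS_M=8, WARPS_N=2,
           WARP_TILE_M=1, WARP_TILE_N=4):
    warp_rows = BLOCK_M // WARPS_M  # output rows owned by one warp
    warp_cols = BLOCK_N // WARPS_N  # output columns owned by one warp
    # Under the assumptions above, each warp's sub-tile is exactly
    # WARP_TILE_M x WARP_TILE_N WMMA fragments.
    assert warp_rows == WARP_TILE_M * WMMA_M
    assert warp_cols == WARP_TILE_N * WMMA_N
    warps = WARPS_M * WARPS_N
    return {"warp_rows": warp_rows, "warp_cols": warp_cols,
            "warps_per_block": warps, "threads_per_block": warps * WAVE_SIZE}


def tile_counts(M: int, N: int, BLOCK_M=128, BLOCK_N=128):
    # Number of BLOCK_M x BLOCK_N output tiles along M and along N.
    return ceildiv(M, BLOCK_M), ceildiv(N, BLOCK_N)


print(tiling())                       # 16 warps of 16x64 outputs, 512 threads
print(tile_counts(M=56 * 56, N=64))   # (25, 1)
```

With the defaults, each of the 16 warps owns a 16x64 slice of the 128x128 block tile, which is exactly one row of four 16x16 fragments.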

Conv2D implicit GEMM kernel for RDNA 3+ GPUs.
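In implicit GEMM form, the convolution becomes a single matrix product with one row per output pixel, one column per output channel, and a reduction over filter taps and input channels. The sketch below shows how the `M`, `N`, `K`, `HW_out`, and `W_out` arguments could be derived from the convolution shape. It assumes stride 1, dilation 1, and padding of pad_h / pad_w applied on both sides. `implicit_gemm_dims` is a hypothetical helper and is not part of the library.

```python
def implicit_gemm_dims(batch, H_in, W_in, C_in, R, S, C_out, pad_h, pad_w):
    # Output spatial size for stride = 1, dilation = 1, and pad_h / pad_w
    # applied on both sides (assumed symmetric padding).
    H_out = H_in + 2 * pad_h - R + 1
    W_out = W_in + 2 * pad_w - S + 1
    return {
        "M": batch * H_out * W_out,  # one GEMM row per output pixel
        "N": C_out,                  # one GEMM column per output channel
        "K": R * S * C_in,           # reduction over filter taps and input channels
        "HW_out": H_out * W_out,
        "W_out": W_out,
    }


# A 3x3 "same" convolution on a batch of two 56x56x64 images, 128 output channels.
print(implicit_gemm_dims(batch=2, H_in=56, W_in=56, C_in=64, R=3, S=3,
                         C_out=128, pad_h=1, pad_w=1))
# {'M': 6272, 'N': 128, 'K': 576, 'HW_out': 3136, 'W_out': 56}
```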

Identical to the RDNA WMMA matmul kernel, except that the A-tile is loaded directly from the NHWC input and the im2col coordinates are computed on the fly. This eliminates the intermediate im2col buffer.
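The on-the-fly im2col step maps each A-tile element A[m, k] to an element of the NHWC input, or to an implicit zero in the padding border. The sketch below computes that mapping. It assumes the K index is ordered (r, s, c) with the channel varying fastest, which is suggested by the [C_out, RSC_in] filter layout but not stated explicitly. `a_tile_offset` is a hypothetical helper. The argument names mirror the kernel's parameters.

```python
def a_tile_offset(m, k, HW_out, W_out, H_in, W_in, C_in, S, pad_h, pad_w):
    """Flat NHWC input offset for A[m, k], or None if it falls in the padding."""
    # Decompose the GEMM row into (batch, output row, output column).
    b, hw = divmod(m, HW_out)
    oh, ow = divmod(hw, W_out)
    # Decompose the GEMM column into (filter row, filter column, channel),
    # assuming K is ordered (r, s, c) with c varying fastest.
    r, rem = divmod(k, S * C_in)
    s, c = divmod(rem, C_in)
    # With stride = dilation = 1 the input pixel is the output pixel shifted
    # by the filter tap, minus the padding.
    ih = oh + r - pad_h
    iw = ow + s - pad_w
    if not (0 <= ih < H_in and 0 <= iw < W_in):
        return None  # zero-padding border: the kernel would load 0 here
    return ((b * H_in + ih) * W_in + iw) * C_in + c


# 4x4x2 input, 3x3 filter, padding 1 -> 4x4 output (HW_out = 16).
shape = dict(HW_out=16, W_out=4, H_in=4, W_in=4, C_in=2, S=3, pad_h=1, pad_w=1)
print(a_tile_offset(0, 0, **shape))   # None: top-left tap of pixel (0, 0) is padding
print(a_tile_offset(0, 9, **shape))   # 1: centre tap (r=1, s=1), channel 1
```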

The B-tile (filter) must be pre-transposed to the [N, K] = [C_out, R*S*C_in] layout so that it can be loaded with vectorized reads under transpose_b=True semantics.
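The pre-transpose can be pictured as follows. If the filter starts in RSCF order [R, S, C_in, C_out], its flat storage is already a row-major [K, N] matrix, so producing [N, K] is a plain 2-D transpose. After the transpose, each output channel's K values are contiguous, which is what allows a BLOCK_K slice to be read as one vector. The RSCF source layout is an assumption, and `filter_rscf_to_nk` is a hypothetical helper, not the library's transpose routine.

```python
def filter_rscf_to_nk(w, R, S, C_in, C_out):
    """Transpose a flat RSCF filter (assumed source layout) to row-major [C_out, K]."""
    K = R * S * C_in
    out = [0] * (C_out * K)
    for k in range(K):            # k = (r * S + s) * C_in + c, matching the A-tile's K order
        for f in range(C_out):
            # RSCF flattened is row-major [K, C_out]; swap the two indices.
            out[f * K + k] = w[k * C_out + f]
    return out


# 1x1 filter, C_in = 2, C_out = 3: [K=2, N=3] -> [N=3, K=2].
print(filter_rscf_to_nk([0, 1, 2, 3, 4, 5], R=1, S=1, C_in=2, C_out=3))
# [0, 3, 1, 4, 2, 5]
```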

The row-major [M, N] GEMM output maps directly onto the NHWC output tensor: row m is one output pixel and column n is one output channel.
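The reference sketch below shows why no output reshuffle is needed. It computes C[M, N] = A[M, K] @ B[N, K]^T, gathering A from the NHWC input on the fly, and compares the flat result with a direct NHWC convolution. It assumes an RSCF source filter, (r, s, c) K ordering, stride 1, dilation 1, and symmetric padding. The function names are hypothetical, and the loops are scalar rather than tiled.

```python
def conv2d_direct_nhwc(x, w, batch, H_in, W_in, C_in, R, S, C_out, pad_h, pad_w):
    """Direct stride-1 convolution: x is flat NHWC, w is flat RSCF (assumed)."""
    H_out, W_out = H_in + 2 * pad_h - R + 1, W_in + 2 * pad_w - S + 1
    y = [0] * (batch * H_out * W_out * C_out)
    for b in range(batch):
        for oh in range(H_out):
            for ow in range(W_out):
                for f in range(C_out):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            ih, iw = oh + r - pad_h, ow + s - pad_w
                            if 0 <= ih < H_in and 0 <= iw < W_in:
                                for c in range(C_in):
                                    acc += (x[((b * H_in + ih) * W_in + iw) * C_in + c]
                                            * w[((r * S + s) * C_in + c) * C_out + f])
                    y[((b * H_out + oh) * W_out + ow) * C_out + f] = acc
    return y


def conv2d_implicit_gemm(x, w_nk, batch, H_in, W_in, C_in, R, S, C_out, pad_h, pad_w):
    """C[M, N] = A[M, K] @ B[N, K]^T with A gathered from NHWC x on the fly."""
    H_out, W_out = H_in + 2 * pad_h - R + 1, W_in + 2 * pad_w - S + 1
    HW_out = H_out * W_out
    M, N, K = batch * HW_out, C_out, R * S * C_in
    y = [0] * (M * N)
    for m in range(M):
        b, hw = divmod(m, HW_out)
        oh, ow = divmod(hw, W_out)
        for n in range(N):
            acc = 0
            for k in range(K):
                r, rem = divmod(k, S * C_in)
                s, c = divmod(rem, C_in)
                ih, iw = oh + r - pad_h, ow + s - pad_w
                if 0 <= ih < H_in and 0 <= iw < W_in:
                    acc += x[((b * H_in + ih) * W_in + iw) * C_in + c] * w_nk[n * K + k]
            y[m * N + n] = acc  # row-major [M, N] is already the NHWC output
    return y
```

Because m enumerates (batch, oh, ow) in NHWC order and n is the channel, the flat index m * N + n equals the NHWC offset ((b * H_out + oh) * W_out + ow) * C_out + n.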

Constraints (enforced by the dispatch layer):

  • stride = (1, 1) and dilation = (1, 1)
  • K % BLOCK_K == 0 and C_in % BLOCK_K == 0
