Mojo function
conv2d_kernel_rdna
conv2d_kernel_rdna[out_type: DType, in_type: DType, filter_type: DType, out_layout: TensorLayout, filter_nk_layout: TensorLayout, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, s_type: DType = get_accum_type[out_type](), BLOCK_K: Int = 32, BLOCK_M: Int = 128, BLOCK_N: Int = 128, WARPS_M: Int = 8, WARPS_N: Int = 2, WARP_TILE_M: Int = 1, WARP_TILE_N: Int = 4](output: TileTensor[out_type, out_layout, MutAnyOrigin], input_ptr: UnsafePointer[Scalar[in_type], ImmutAnyOrigin], filter_nk: TileTensor[filter_type, filter_nk_layout, MutAnyOrigin], M: Int, N: Int, K: Int, HW_out: Int, W_out: Int, H_in: Int, W_in: Int, C_in: Int, R: Int, S: Int, pad_h: Int, pad_w: Int)
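The default tiling parameters in the signature can be sanity-checked with a little arithmetic. The Python sketch below is a hypothetical helper (tile_plan is not part of the Mojo API). It assumes that each workgroup computes one BLOCK_M x BLOCK_N output tile, that this tile is split evenly across WARPS_M x WARPS_N warps, and that each warp holds WARP_TILE_M x WARP_TILE_N accumulator fragments. The documentation does not state these assumptions.

```python
# Hypothetical helper that decomposes the kernel's default tile parameters.
# Assumptions (not stated in the docs): one BLOCK_M x BLOCK_N output tile per
# workgroup, split evenly over WARPS_M x WARPS_N warps, each warp holding
# WARP_TILE_M x WARP_TILE_N accumulator fragments.


def tile_plan(M, N, BLOCK_M=128, BLOCK_N=128, WARPS_M=8, WARPS_N=2,
              WARP_TILE_M=1, WARP_TILE_N=4):
    if BLOCK_M % (WARPS_M * WARP_TILE_M) or BLOCK_N % (WARPS_N * WARP_TILE_N):
        raise ValueError("block tile must split evenly into warp fragments")
    return {
        "warps_per_block": WARPS_M * WARPS_N,
        # Sub-tile of the output owned by a single warp.
        "warp_tile": (BLOCK_M // WARPS_M, BLOCK_N // WARPS_N),
        # Shape of each accumulator fragment inside a warp's sub-tile.
        "fragment": (BLOCK_M // (WARPS_M * WARP_TILE_M),
                     BLOCK_N // (WARPS_N * WARP_TILE_N)),
        # Output tiles needed to cover the [M, N] GEMM (ceiling division).
        "blocks": (-(-M // BLOCK_M)) * (-(-N // BLOCK_N)),
    }


# Example: batch 2, 56x56 output, 64 output channels.
print(tile_plan(M=2 * 56 * 56, N=64))
```

With the defaults, each of the 16 warps owns a 16x64 sub-tile made of four 16x16 fragments. That fragment shape matches the 16x16x16 WMMA instructions on RDNA 3, which suggests, but does not confirm, why these defaults were chosen.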
Conv2D implicit GEMM kernel for RDNA 3+ GPUs.
Identical to the RDNA WMMA matmul kernel, except that the A-tile is loaded directly from the NHWC input using on-the-fly im2col coordinate computation, which eliminates the intermediate im2col buffer.
The B-tile (filter) must be pre-transposed to the [N, K] = [C_out, RSC_in] layout so that it can be loaded with vectorized reads under transpose_b=True semantics.
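The row-major [M, N] output maps directly to the NHWC output tensor: each row is one output pixel (batch index and spatial position) and each column is one output channel.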
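The indexing described above can be modeled on the CPU to show why no im2col buffer is needed. The Python sketch below is a hypothetical reference model, not the Mojo kernel: im2col_a computes each A[m, k] element directly from the flat NHWC input, the filter is read as [N, K] (transpose_b=True), and the result is checked against a direct convolution. It assumes that K is ordered (r, s, c_in) with c_in fastest, as the RSC_in notation suggests, and that the original filter is stored in RSCF layout. Both are assumptions, and the helper names (im2col_a, rscf_to_nk, implicit_gemm_conv, direct_conv_nhwc) are illustrative only.

```python
# Hypothetical CPU reference model of the implicit-GEMM indexing (not the Mojo
# kernel). Assumptions: stride 1, dilation 1, K ordered as (r, s, c_in) with
# c_in fastest, filter originally in RSCF layout, flat row-major Python lists.


def im2col_a(x, m, k, H_in, W_in, C_in, S, HW_out, W_out, pad_h, pad_w):
    """Compute A[m, k] on the fly from the flat NHWC input x."""
    n, hw = divmod(m, HW_out)          # GEMM row -> (batch, output pixel)
    h_out, w_out = divmod(hw, W_out)   # output pixel -> (row, column)
    rs, c = divmod(k, C_in)            # GEMM column -> (filter tap, channel)
    r, s = divmod(rs, S)               # filter tap -> (filter row, column)
    h_in, w_in = h_out + r - pad_h, w_out + s - pad_w  # stride/dilation = 1
    if 0 <= h_in < H_in and 0 <= w_in < W_in:
        return x[((n * H_in + h_in) * W_in + w_in) * C_in + c]
    return 0.0  # Padding region reads as zero.


def rscf_to_nk(w_rscf, K, C_out):
    """RSCF flattens to a row-major [K, C_out] matrix; transpose it to [C_out, K]."""
    return [w_rscf[k * C_out + f] for f in range(C_out) for k in range(K)]


def implicit_gemm_conv(x, filter_nk, batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w):
    H_out, W_out = H_in + 2 * pad_h - R + 1, W_in + 2 * pad_w - S + 1
    HW_out = H_out * W_out
    M, N, K = batch * HW_out, C_out, R * S * C_in
    out = [0.0] * (M * N)  # Row-major [M, N], which is the NHWC output.
    for m in range(M):
        for col in range(N):
            acc = 0.0
            for k in range(K):
                a = im2col_a(x, m, k, H_in, W_in, C_in, S, HW_out, W_out, pad_h, pad_w)
                acc += a * filter_nk[col * K + k]  # transpose_b: B stored as [N, K]
            out[m * N + col] = acc
    return out


def direct_conv_nhwc(x, w_rscf, batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w):
    H_out, W_out = H_in + 2 * pad_h - R + 1, W_in + 2 * pad_w - S + 1
    out = [0.0] * (batch * H_out * W_out * C_out)
    for n in range(batch):
        for ho in range(H_out):
            for wo in range(W_out):
                for f in range(C_out):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            hi, wi = ho + r - pad_h, wo + s - pad_w
                            if not (0 <= hi < H_in and 0 <= wi < W_in):
                                continue
                            for c in range(C_in):
                                acc += (x[((n * H_in + hi) * W_in + wi) * C_in + c]
                                        * w_rscf[((r * S + s) * C_in + c) * C_out + f])
                    out[((n * H_out + ho) * W_out + wo) * C_out + f] = acc
    return out


batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w = 1, 4, 4, 2, 3, 3, 3, 1, 1
x = [float((i * 7) % 5 - 2) for i in range(batch * H_in * W_in * C_in)]
w_rscf = [float((i * 3) % 4 - 1) for i in range(R * S * C_in * C_out)]
ref = direct_conv_nhwc(x, w_rscf, batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w)
got = implicit_gemm_conv(x, rscf_to_nk(w_rscf, R * S * C_in, C_out),
                         batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w)
print(max(abs(a - b) for a, b in zip(ref, got)))
```

Under these assumptions, an RSCF filter flattens to a row-major [K, C_out] matrix, so the required pre-transpose to [C_out, K] is an ordinary matrix transpose.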
Constraints (enforced by the dispatch layer):
- stride = (1, 1) and dilation = (1, 1)
- K % BLOCK_K == 0 and C_in % BLOCK_K == 0
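The dispatch constraints can be expressed as a host-side check. The sketch below uses hypothetical helper names (conv_gemm_args, k_block_taps) and is not the library's dispatch code. It derives the M, N, K, HW_out, and W_out arguments for a stride-1, dilation-1 convolution, assuming K = R * S * C_in ordered (r, s, c_in). It then illustrates one plausible reason for the C_in % BLOCK_K == 0 requirement: every BLOCK_K-aligned slice of K then falls within a single (r, s) filter tap, so for a given output pixel the slice is a contiguous run of channels at one input pixel in NHWC memory. Under the same assumption, C_in % BLOCK_K == 0 already implies K % BLOCK_K == 0.

```python
# Hypothetical host-side helpers (not the library's dispatch code).
# Assumptions: K = R * S * C_in ordered (r, s, c_in) with c_in fastest,
# output size computed for stride 1 and dilation 1.


def conv_gemm_args(batch, H_in, W_in, C_in, C_out, R, S, pad_h, pad_w,
                   stride=(1, 1), dilation=(1, 1), BLOCK_K=32):
    """Derive the kernel's GEMM-size arguments and enforce the constraints."""
    if stride != (1, 1) or dilation != (1, 1):
        raise ValueError("implicit-GEMM path needs stride (1, 1) and dilation (1, 1)")
    H_out = H_in + 2 * pad_h - R + 1
    W_out = W_in + 2 * pad_w - S + 1
    M, N, K = batch * H_out * W_out, C_out, R * S * C_in
    if K % BLOCK_K != 0 or C_in % BLOCK_K != 0:
        raise ValueError("needs K % BLOCK_K == 0 and C_in % BLOCK_K == 0")
    return dict(M=M, N=N, K=K, HW_out=H_out * W_out, W_out=W_out)


def k_block_taps(k0, BLOCK_K, C_in, S):
    """Set of (r, s) filter taps covered by GEMM columns [k0, k0 + BLOCK_K)."""
    return {divmod(k // C_in, S) for k in range(k0, k0 + BLOCK_K)}


args = conv_gemm_args(batch=2, H_in=8, W_in=8, C_in=64, C_out=128,
                      R=3, S=3, pad_h=1, pad_w=1)
print(args)  # {'M': 128, 'N': 128, 'K': 576, 'HW_out': 64, 'W_out': 8}

# C_in = 64 is a multiple of BLOCK_K = 32: every aligned K-block hits one tap.
assert all(len(k_block_taps(k0, 32, 64, 3)) == 1 for k0 in range(0, args["K"], 32))
# C_in = 48 is not: the block starting at k = 32 straddles taps (0, 0) and (0, 1).
assert k_block_taps(32, 32, 48, 3) == {(0, 0), (0, 1)}
```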