Mojo module
conv2d_kernel
RDNA Conv2D implicit GEMM kernel with WMMA.
Fuses im2col into the RDNA WMMA matmul kernel's A-tile shared memory loader, eliminating the large intermediate im2col buffer. The B-tile (filter) is loaded normally from a pre-transposed [N, K] layout.
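The index mapping behind this fusion can be sketched on the host in plain Python. This is an illustration only, not the kernel's code: the function names, the K ordering `k = (r*S + s)*C_in + c` (channel fastest), unit stride, and zero padding are all assumptions chosen to match an NHWC input and an [N, K] filter layout.

```python
def a_element(x, m, k, H, W, C, S, H_out, W_out, pad=0):
    """Return A[m, k] of the virtual im2col matrix, read directly from flat NHWC x.

    Assumed layouts (hypothetical, for illustration):
      m = (n * H_out + ho) * W_out + wo
      k = (r * S + s) * C + c        # channel index c varies fastest
    """
    n, rem = divmod(m, H_out * W_out)
    ho, wo = divmod(rem, W_out)
    rs, c = divmod(k, C)
    r, s = divmod(rs, S)
    h, w = ho + r - pad, wo + s - pad  # unit stride assumed
    if 0 <= h < H and 0 <= w < W:
        return x[((n * H + h) * W + w) * C + c]
    return 0  # out-of-bounds taps read as zero padding


def implicit_gemm_conv2d(x, filt, batch, H, W, C, C_out, R, S, pad=0):
    """Conv2D as a GEMM whose A matrix is never materialized.

    filt is indexed filt[c_out][k], i.e. an [N, K] layout.
    Returns a flat list of shape [M, C_out], M = batch * H_out * W_out.
    """
    H_out = H + 2 * pad - R + 1
    W_out = W + 2 * pad - S + 1
    M, K = batch * H_out * W_out, R * S * C
    out = [0] * (M * C_out)
    for m in range(M):
        for co in range(C_out):
            out[m * C_out + co] = sum(
                a_element(x, m, k, H, W, C, S, H_out, W_out, pad) * filt[co][k]
                for k in range(K)
            )
    return out
```

In a real kernel, the body of `a_element` would run inside the A-tile shared memory loader, so each element is gathered from the input on the fly instead of being read from an im2col buffer.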
For common VAE decoder shapes (C_in=128/256/512), C_in is always a multiple of BLOCK_K, so consecutive K positions within a tile share the same (r,s) filter position. This enables vectorized 8-wide loads from the NHWC input, the same load width as the standard matmul kernel's A-tile loader.
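The divisibility argument can be checked with a short Python sketch. The value `BLOCK_K = 32` used in the usage note, the 3x3 filter size, and the K ordering `k = (r*S + s)*C_in + c` are assumptions for illustration; this page does not state them.

```python
def filter_positions_in_tile(k0, k_end, c_in, s_dim):
    """Set of (r, s) filter positions touched by K indices k0 .. k_end-1."""
    return {divmod(k // c_in, s_dim) for k in range(k0, k_end)}


def all_tiles_share_rs(c_in, block_k, r_dim=3, s_dim=3):
    """True if every BLOCK_K-wide K tile stays within a single (r, s) position."""
    k_total = r_dim * s_dim * c_in
    return all(
        len(filter_positions_in_tile(k0, min(k0 + block_k, k_total), c_in, s_dim)) == 1
        for k0 in range(0, k_total, block_k)
    )
```

With an assumed `block_k` of 32, `all_tiles_share_rs` holds for C_in of 128, 256, and 512, but not for C_in of 48, where the tile starting at k=32 straddles two filter positions. When the property holds, the channels covered by a tile are contiguous in the NHWC input, which is what makes wide vector loads possible.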
comptime values
AB_FRAG_SIZE
comptime AB_FRAG_SIZE = 16
CD_FRAG_SIZE
comptime CD_FRAG_SIZE = 8
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
SMEM_PAD
comptime SMEM_PAD = 8
Functions
- conv2d_kernel_rdna: Conv2D implicit GEMM kernel for RDNA 3+ GPUs.