Mojo module
conv2d_kernel
RDNA Conv2D implicit GEMM kernel with WMMA.
Fuses im2col into the A-tile shared-memory loader of the RDNA WMMA matmul kernel, so the large intermediate im2col buffer is never materialized. The B tile (filter) is loaded the same way as in the standard matmul kernel, from a filter pre-transposed to an [N, K] layout.
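To make the fused im2col concrete, here is a minimal scalar sketch of the implicit-GEMM index mapping, written in Python rather than Mojo. It views the convolution as a GEMM with M = batch * H_out * W_out, N = C_out, and K = R * S * C_in. Each A element is computed on the fly from the NHWC input, and B is the filter pre-transposed to [C_out, K]. The K ordering k = (r * S + s) * C_in + c (channel fastest) is an assumption, and all function names are hypothetical. This sketch builds one A row at a time for clarity; a real kernel would fill A tiles in shared memory instead.

```python
import random

def conv2d_direct(x, w, stride=1, pad=0):
    """Reference conv: x is NHWC [B][H][W][C_in], w is [C_out][R][S][C_in]."""
    B, H, W, C_in = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    C_out, R, S = len(w), len(w[0]), len(w[0][0])
    H_out = (H + 2 * pad - R) // stride + 1
    W_out = (W + 2 * pad - S) // stride + 1
    out = [[[[0] * C_out for _ in range(W_out)] for _ in range(H_out)] for _ in range(B)]
    for b in range(B):
        for oh in range(H_out):
            for ow in range(W_out):
                for co in range(C_out):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            ih, iw = oh * stride - pad + r, ow * stride - pad + s
                            if 0 <= ih < H and 0 <= iw < W:
                                for c in range(C_in):
                                    acc += x[b][ih][iw][c] * w[co][r][s][c]
                    out[b][oh][ow][co] = acc
    return out

def a_element(x, m, k, S, C_in, H_out, W_out, stride, pad):
    """Hypothetical on-the-fly im2col: row m is an output pixel, column k a (r, s, c) tap."""
    b, rem = divmod(m, H_out * W_out)
    oh, ow = divmod(rem, W_out)
    rs, c = divmod(k, C_in)  # assumed ordering: channel varies fastest along K
    r, s = divmod(rs, S)
    ih, iw = oh * stride - pad + r, ow * stride - pad + s
    if 0 <= ih < len(x[0]) and 0 <= iw < len(x[0][0]):
        return x[b][ih][iw][c]
    return 0  # zero padding outside the input

def conv2d_implicit_gemm(x, w, stride=1, pad=0):
    B, H, W, C_in = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    C_out, R, S = len(w), len(w[0]), len(w[0][0])
    H_out = (H + 2 * pad - R) // stride + 1
    W_out = (W + 2 * pad - S) // stride + 1
    K = R * S * C_in
    # Pre-transposed filter: GEMM B operand of shape [C_out, K].
    w_t = [[w[co][r][s][c] for r in range(R) for s in range(S) for c in range(C_in)]
           for co in range(C_out)]
    out = [[[[0] * C_out for _ in range(W_out)] for _ in range(H_out)] for _ in range(B)]
    for m in range(B * H_out * W_out):
        a_row = [a_element(x, m, k, S, C_in, H_out, W_out, stride, pad) for k in range(K)]
        b, rem = divmod(m, H_out * W_out)
        oh, ow = divmod(rem, W_out)
        for co in range(C_out):
            out[b][oh][ow][co] = sum(a * bk for a, bk in zip(a_row, w_t[co]))
    return out

def rand_tensor(shape, rng):
    """Nested lists of small random integers with the given shape."""
    if len(shape) == 1:
        return [rng.randint(-3, 3) for _ in range(shape[0])]
    return [rand_tensor(shape[1:], rng) for _ in range(shape[0])]
```

For example, with `rng = random.Random(0)`, `x = rand_tensor([2, 5, 6, 4], rng)`, and `w = rand_tensor([3, 3, 3, 4], rng)`, `conv2d_implicit_gemm(x, w)` matches `conv2d_direct(x, w)`. Integer inputs keep the comparison exact.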
For common VAE decoder shapes (C_in = 128, 256, or 512), C_in is a multiple of BLOCK_K. Because the channel index varies fastest along K, every K position within a BLOCK_K-wide tile then shares the same (r, s) filter position and differs only in channel. Those channels are contiguous in the NHWC input, which enables vectorized 8-wide loads, the same load width used by the standard matmul kernel's A-tile loader.
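The tile-sharing property can be checked numerically. The following Python sketch assumes 3x3 filters, the K ordering k = (r * S + s) * C_in + c, and a hypothetical BLOCK_K of 32 (the module docs do not state BLOCK_K). It also assumes stride 1 and an interior output pixel, so padding never applies. It reports which (r, s) positions each K tile touches, and whether every 8-wide group of K positions maps to consecutive NHWC offsets. All names are illustrative.

```python
def decode_k(k, C_in, S):
    """Split a GEMM K index into (r, s, c), assuming channel varies fastest."""
    rs, c = divmod(k, C_in)
    r, s = divmod(rs, S)
    return r, s, c

def tile_filter_positions(C_in, R, S, block_k):
    """Set of (r, s) filter positions touched by each block_k-wide K tile."""
    K = R * S * C_in
    return [
        {decode_k(k, C_in, S)[:2] for k in range(k0, min(k0 + block_k, K))}
        for k0 in range(0, K, block_k)
    ]

def vec_groups_contiguous(C_in, R, S, vec=8, W=16, oh=8, ow=8):
    """True if every vec-wide group of K positions reads consecutive NHWC offsets
    (batch 0, stride 1, no padding, output pixel (oh, ow) away from the border)."""
    K = R * S * C_in
    for k0 in range(0, K - vec + 1, vec):
        offsets = []
        for k in range(k0, k0 + vec):
            r, s, c = decode_k(k, C_in, S)
            ih, iw = oh + r, ow + s
            offsets.append((ih * W + iw) * C_in + c)
        if any(b - a != 1 for a, b in zip(offsets, offsets[1:])):
            return False
    return True
```

With C_in = 128, 256, or 512, every tile from `tile_filter_positions(C_in, 3, 3, 32)` contains a single (r, s) pair, and `vec_groups_contiguous(128, 3, 3)` is `True`. With C_in = 3, as in an RGB input layer, a single tile spans several filter positions. An 8-wide group can also straddle two filter rows, so the contiguity check fails.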
comptime values
AB_FRAG_SIZE
comptime AB_FRAG_SIZE = 16
CD_FRAG_SIZE
comptime CD_FRAG_SIZE = 8
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
SMEM_PAD
comptime SMEM_PAD = 8
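The fragment sizes are consistent with a 16x16x16 WMMA tile executed by a 32-lane wave. The 256 accumulator elements spread over 32 lanes give 8 per lane (CD_FRAG_SIZE). Holding 16 A/B elements per lane across 32 lanes stores twice the 256 elements of one operand tile, which matches RDNA 3 wave32 WMMA, where the two half-waves hold duplicate operand copies. The Python arithmetic below assumes wave32 and a hypothetical BLOCK_K of 32. SMEM_PAD is assumed to widen each shared-memory row from BLOCK_K to BLOCK_K + SMEM_PAD elements, a common way to stagger rows across LDS banks. The module docs do not state this purpose.

```python
# Worked arithmetic for the comptime values above (illustrative only).
MMA_M = MMA_N = MMA_K = 16
AB_FRAG_SIZE = 16
CD_FRAG_SIZE = 8
SMEM_PAD = 8
WAVE_SIZE = 32  # assumption: RDNA 3 WMMA running in wave32 mode
BLOCK_K = 32    # hypothetical tile depth, not taken from the module

def fragment_budget():
    """Return (accumulator elements per lane, A/B operand replication factor)."""
    c_tile = MMA_M * MMA_N              # 256 accumulator elements per MMA tile
    c_per_lane = c_tile // WAVE_SIZE    # spread evenly over the wave's lanes
    a_tile = MMA_M * MMA_K              # 256 A-operand elements per MMA tile
    a_held = AB_FRAG_SIZE * WAVE_SIZE   # elements held across all lanes
    return c_per_lane, a_held // a_tile

def padded_offset(row, col, block_k=BLOCK_K, pad=SMEM_PAD):
    """Element offset of (row, col) in a shared-memory tile whose rows are
    padded from block_k to block_k + pad elements (assumed layout)."""
    return row * (block_k + pad) + col
```

Under these assumptions, `fragment_budget()` returns `(8, 2)`: 8 accumulator elements per lane, matching CD_FRAG_SIZE, and a 2x replication of the A/B operands. With the hypothetical BLOCK_K, each padded row is 40 elements long.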
Functions
- conv2d_kernel_rdna: Conv2D implicit GEMM kernel for RDNA 3+ GPUs.