
Mojo module

conv2d_kernel

RDNA Conv2D implicit GEMM kernel with WMMA.
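An implicit GEMM treats a Conv2D as a matrix multiply. Each output pixel becomes a GEMM row, each output channel a GEMM column, and the filter window times the input channels becomes the reduction dimension. The sketch below computes those GEMM dimensions and the size an explicit im2col buffer would have. The layer shape (batch 1, 128x128 output, 512 channels, 3x3 filter, 2-byte elements) is a hypothetical example, not a shape taken from the kernel. The helper names are illustrative only.

```python
def implicit_gemm_shape(batch, h_out, w_out, c_in, c_out, r, s):
    """Return (M, N, K) of the GEMM that a stride-1 Conv2D lowers to (illustrative helper)."""
    m = batch * h_out * w_out  # one GEMM row per output pixel
    n = c_out                  # one GEMM column per output channel
    k = r * s * c_in           # reduction over the filter window and input channels
    return m, n, k


def im2col_bytes(batch, h_out, w_out, c_in, r, s, elem_bytes=2):
    """Size of an explicit [M, K] im2col buffer; elem_bytes=2 assumes fp16/bf16."""
    m, _, k = implicit_gemm_shape(batch, h_out, w_out, c_in, 1, r, s)
    return m * k * elem_bytes


# Hypothetical VAE-decoder-like layer: 512 -> 512 channels, 3x3 filter, 128x128 output.
print(implicit_gemm_shape(1, 128, 128, 512, 512, 3, 3))    # (16384, 512, 4608)
print(im2col_bytes(1, 128, 128, 512, 3, 3) / 2**20, "MiB")  # 144.0 MiB
```

For this hypothetical layer, the materialized A matrix would be nine times the size of the input activation, because each input element is copied once per filter tap.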

Fuses the im2col transform into the A-tile shared-memory loader of the RDNA WMMA matmul kernel, eliminating the large intermediate im2col buffer. The B-tile (filter) is loaded as in a plain matmul, from a pre-transposed [N, K] layout.
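The fusion works because every element of the im2col matrix A can be computed on the fly from the NHWC input. The Python sketch below shows this under a few assumptions: stride 1, no dilation, zero padding, and a K ordering of `(r * S + s) * C_in + c` with the channel varying fastest. With that ordering, a filter stored as `[C_out, R, S, C_in]` is already the `[N, K]` layout. The functions `a_elem`, `conv_implicit_gemm`, and `conv_direct` are hypothetical names written for this illustration, not the kernel's API. The sketch also omits tiling and shared memory.

```python
import random


def a_elem(x, h, w, c_in, s_sz, pad, h_out, w_out, m, k):
    """Return A[m, k] by reading the NHWC input directly (implicit im2col)."""
    # GEMM row m -> output pixel (b, oh, ow).
    b, rem = divmod(m, h_out * w_out)
    oh, ow = divmod(rem, w_out)
    # GEMM column k -> filter tap (r, s) and input channel c; channels vary fastest.
    rs, c = divmod(k, c_in)
    r, s = divmod(rs, s_sz)
    ih, iw = oh + r - pad, ow + s - pad  # stride 1, dilation 1 assumed
    if 0 <= ih < h and 0 <= iw < w:
        return x[((b * h + ih) * w + iw) * c_in + c]
    return 0  # zero padding: out-of-bounds taps need no stored value


def conv_implicit_gemm(x, w_nk, batch, h, w, c_in, c_out, r_sz, s_sz, pad):
    """Conv2D as C[M, N] = A[M, K] @ B[N, K]^T without materializing A."""
    h_out, w_out = h + 2 * pad - r_sz + 1, w + 2 * pad - s_sz + 1
    M, N, K = batch * h_out * w_out, c_out, r_sz * s_sz * c_in
    out = [0] * (M * N)
    for m in range(M):
        for n in range(N):
            # B is read row-wise from the [N, K] filter layout.
            out[m * N + n] = sum(
                a_elem(x, h, w, c_in, s_sz, pad, h_out, w_out, m, k) * w_nk[n * K + k]
                for k in range(K))
    return out


def conv_direct(x, w_nk, batch, h, w, c_in, c_out, r_sz, s_sz, pad):
    """Reference NHWC convolution; filter stored as [C_out, R, S, C_in] == [N, K]."""
    h_out, w_out = h + 2 * pad - r_sz + 1, w + 2 * pad - s_sz + 1
    out = [0] * (batch * h_out * w_out * c_out)
    for b in range(batch):
        for oh in range(h_out):
            for ow in range(w_out):
                for co in range(c_out):
                    acc = 0
                    for r in range(r_sz):
                        for s in range(s_sz):
                            ih, iw = oh + r - pad, ow + s - pad
                            if not (0 <= ih < h and 0 <= iw < w):
                                continue
                            for c in range(c_in):
                                acc += (x[((b * h + ih) * w + iw) * c_in + c]
                                        * w_nk[((co * r_sz + r) * s_sz + s) * c_in + c])
                    out[((b * h_out + oh) * w_out + ow) * c_out + co] = acc
    return out


# Small hypothetical shape: 1x5x5 input, 4 -> 3 channels, 3x3 filter, padding 1.
random.seed(0)
dims = dict(batch=1, h=5, w=5, c_in=4, c_out=3, r_sz=3, s_sz=3, pad=1)
x = [random.randint(-3, 3) for _ in range(1 * 5 * 5 * 4)]
w_nk = [random.randint(-3, 3) for _ in range(3 * 3 * 3 * 4)]
assert conv_implicit_gemm(x, w_nk, **dims) == conv_direct(x, w_nk, **dims)
```

Integer data keeps the comparison exact. In the real kernel, the body of `a_elem` would run inside the A-tile loader, once per shared-memory element.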

For common VAE decoder shapes (C_in = 128, 256, or 512), C_in is a multiple of BLOCK_K. The input channel varies fastest along K, so consecutive K positions within a tile share the same (r, s) filter position and map to contiguous channels of a single NHWC input pixel. This enables vectorized 8-wide loads from the NHWC input, the same load width that the standard matmul kernel's A-tile loader uses.
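The sketch below checks the tile property in Python. It assumes the same K ordering as above (channel fastest) and a hypothetical BLOCK_K of 32, since the page does not state the value. The helper names are illustrative. When C_in is a multiple of BLOCK_K, every BLOCK_K-aligned tile touches exactly one (r, s), and its channels split into 8-wide runs that are contiguous in NHWC memory. When C_in is not a multiple, some tiles straddle two filter positions.

```python
BLOCK_K = 32  # hypothetical tile depth, chosen for illustration
VEC = 8       # 8-wide vector loads, as described above


def tile_filter_positions(c_in, s_sz, r_sz, block_k=BLOCK_K):
    """For each BLOCK_K-aligned K tile, return the set of (r, s) taps it covers."""
    k_total = r_sz * s_sz * c_in
    tiles = []
    for k0 in range(0, k_total, block_k):
        taps = set()
        for k in range(k0, min(k0 + block_k, k_total)):
            rs, _c = divmod(k, c_in)  # channel index varies fastest along K
            taps.add(divmod(rs, s_sz))
        tiles.append(taps)
    return tiles


def vector_loads(c_in, s_sz, k0, block_k=BLOCK_K, vec=VEC):
    """(r, s) tap and starting channel of each vec-wide load for the tile at k0.

    Only valid when the tile maps to a single (r, s), i.e. c_in % block_k == 0.
    """
    rs, c0 = divmod(k0, c_in)
    tap = divmod(rs, s_sz)
    return tap, list(range(c0, c0 + block_k, vec))


print(all(len(t) == 1 for t in tile_filter_positions(128, 3, 3)))  # True
print(max(len(t) for t in tile_filter_positions(24, 3, 3)))        # 2
print(vector_loads(128, 3, 160))  # ((0, 1), [32, 40, 48, 56])
```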

comptime values

AB_FRAG_SIZE

comptime AB_FRAG_SIZE = 16

CD_FRAG_SIZE

comptime CD_FRAG_SIZE = 8

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16
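One consistent reading of the fragment sizes is shown below, assuming wave32 execution and an RDNA3-style WMMA layout. In that layout, lanes 16 to 31 replicate the A and B data held by lanes 0 to 15, while the accumulator is spread across all 32 lanes. This is an illustrative cross-check of the constants, not a statement of how this module computes them.

```python
WAVE_SIZE = 32              # assumed wave32 execution
MMA_M = MMA_N = MMA_K = 16  # values from the comptime constants above

# Assumption: A/B inputs are replicated across the two half-waves,
# so only 16 lanes hold unique A (or B) data.
unique_ab_lanes = WAVE_SIZE // 2
ab_frag = (MMA_M * MMA_K) // unique_ab_lanes  # elements of A per lane
cd_frag = (MMA_M * MMA_N) // WAVE_SIZE        # accumulator elements per lane

print(ab_frag, cd_frag)  # 16 8
```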

SMEM_PAD

comptime SMEM_PAD = 8
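A shared-memory pad like `SMEM_PAD` is commonly added to each tile row so that lanes reading the same column of different rows land in different memory banks. The page does not say how this kernel applies the pad. The model below is therefore hypothetical: 32 banks of 4 bytes, 2-byte elements, a BLOCK_K of 32, and 16 lanes each reading column 0 of a different row, as a fragment load would.

```python
def bank_profile(row_elems, pad, rows=16, elem_bytes=2, banks=32, bank_bytes=4):
    """Return (distinct banks hit, worst-case lanes sharing one bank).

    Hypothetical bank model; each of `rows` lanes reads column 0 of its own row.
    """
    stride_bytes = (row_elems + pad) * elem_bytes  # padded row pitch
    hits = {}
    for row in range(rows):
        bank = (row * stride_bytes // bank_bytes) % banks
        hits[bank] = hits.get(bank, 0) + 1
    return len(hits), max(hits.values())


print(bank_profile(32, 0))  # (2, 8): unpadded rows alias onto 2 banks
print(bank_profile(32, 8))  # (8, 2): SMEM_PAD = 8 spreads rows over 8 banks
```

Under this model, the unpadded 64-byte row pitch is a multiple of half the bank span, so rows collide 8-way. Padding to an 80-byte pitch reduces this to 2-way.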

Functions
