Mojo module
matmul
AMD RDNA matmul kernel that uses WMMA on RDNA 3+ and a naive fallback on older architectures.
RDNA 3+ (gfx11xx/gfx12xx): Uses 16x16x16 WMMA instructions with Wave32. RDNA 1/2 (gfx10xx): Falls back to a per-thread naive matmul.
Block configuration (shared by both paths): 4 warps x 32 threads = 128 threads per workgroup. BLOCK_M=64, BLOCK_N=64, BLOCK_K=16.
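To make the block configuration concrete, the following host-side sketch computes how many workgroups and K-loop steps a given problem size would need with these tile sizes. It is plain Python arithmetic, not the Mojo kernel; `launch_shape` is a hypothetical helper, and the assumption that one workgroup produces one BLOCK_M x BLOCK_N output tile and steps through K in BLOCK_K chunks is an inference from the constants rather than something this page states.

```python
import math

# Values copied from the module's block configuration.
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 16
NUM_WARPS = 4
WARP_SIZE = 32  # Wave32


def launch_shape(m: int, n: int, k: int) -> dict:
    """Hypothetical helper: tiling counts for an m x n x k GEMM.

    Assumes one workgroup per BLOCK_M x BLOCK_N output tile and a K loop
    that advances BLOCK_K elements per step; partial tiles round up.
    """
    return {
        "threads_per_workgroup": NUM_WARPS * WARP_SIZE,
        "workgroups_m": math.ceil(m / BLOCK_M),
        "workgroups_n": math.ceil(n / BLOCK_N),
        "k_steps": math.ceil(k / BLOCK_K),
    }


if __name__ == "__main__":
    print(launch_shape(1000, 512, 256))
```

For a 1000 x 512 x 256 problem this gives 16 x 8 workgroups of 128 threads each and 16 K steps; the last row of workgroups covers a partial tile, which a real kernel would need to bounds-check.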
comptime values
AB_FRAG_SIZE
comptime AB_FRAG_SIZE = 16
BLOCK_K
comptime BLOCK_K = 16
BLOCK_M
comptime BLOCK_M = 64
BLOCK_N
comptime BLOCK_N = 64
CD_FRAG_SIZE
comptime CD_FRAG_SIZE = 8
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
NUM_C_TILES
comptime NUM_C_TILES = 4
NUM_M_MMAS
comptime NUM_M_MMAS = 2
NUM_N_MMAS
comptime NUM_N_MMAS = 2
NUM_WARPS
comptime NUM_WARPS = 4
WARP_M
comptime WARP_M = 32
WARP_N
comptime WARP_N = 32
WARPS_M
comptime WARPS_M = 2
WARPS_N
comptime WARPS_N = 2
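The constants above are not independent: most can be derived from the block size, the warp grid, and the MMA shape. The sketch below re-derives them in plain Python and lists the 16x16 output tiles each warp would own. The row-major warp ordering in `warp_tile_origin` is an assumption, both helper functions are hypothetical, and the page does not document how the kernel actually assigns tiles.

```python
# Independent inputs, copied from the module's comptime values.
BLOCK_M, BLOCK_N = 64, 64
WARPS_M, WARPS_N = 2, 2
MMA_M, MMA_N, MMA_K = 16, 16, 16
WARP_SIZE = 32  # Wave32

# Derived values; each matches the corresponding comptime constant.
NUM_WARPS = WARPS_M * WARPS_N              # 4
WARP_M = BLOCK_M // WARPS_M                # 32
WARP_N = BLOCK_N // WARPS_N                # 32
NUM_M_MMAS = WARP_M // MMA_M               # 2
NUM_N_MMAS = WARP_N // MMA_N               # 2
NUM_C_TILES = NUM_M_MMAS * NUM_N_MMAS      # 4
# A 16x16 accumulator tile spread evenly over 32 lanes is 8 values per lane.
CD_FRAG_SIZE = (MMA_M * MMA_N) // WARP_SIZE  # 8


def warp_tile_origin(warp_id: int) -> tuple:
    """Top-left (row, col) of a warp's sub-tile inside the workgroup tile.

    Assumes warps are laid out row-major over a WARPS_M x WARPS_N grid.
    """
    warp_row, warp_col = divmod(warp_id, WARPS_N)
    return warp_row * WARP_M, warp_col * WARP_N


def mma_tile_origins(warp_id: int) -> list:
    """Top-left corners of the NUM_C_TILES MMA tiles owned by one warp."""
    row0, col0 = warp_tile_origin(warp_id)
    return [
        (row0 + i * MMA_M, col0 + j * MMA_N)
        for i in range(NUM_M_MMAS)
        for j in range(NUM_N_MMAS)
    ]


if __name__ == "__main__":
    for w in range(NUM_WARPS):
        print(w, mma_tile_origins(w))
```

Under these assumptions the four warps together own 16 distinct 16x16 tiles, which exactly cover the 64x64 workgroup tile. AB_FRAG_SIZE equals MMA_K, which is consistent with each lane holding one K-length row or column of an input fragment, though the page does not confirm that reading.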
Functions
- gemm_kernel_rdna: GEMM kernel for AMD RDNA GPUs.
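When validating output from either code path, a small host-side reference is handy. The sketch below is a plain triple-loop matrix product in Python. It does not call `gemm_kernel_rdna`, whose signature is not shown on this page, and it assumes the kernel computes C = A x B with row-major operands and no transposition or scaling factors. `reference_gemm` is a hypothetical name.

```python
def reference_gemm(a: list, b: list) -> list:
    """Naive C = A x B for row-major nested lists (a is m x k, b is k x n)."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    if len(b) != k:
        raise ValueError("inner dimensions of a and b do not match")

    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # One output element: the dot product of row i of A and column j of B.
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            c[i][j] = acc
    return c


if __name__ == "__main__":
    print(reference_gemm([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
```

The inner dot product is the work that a per-element naive kernel would typically hand to a single thread; how the RDNA 1/2 fallback actually divides output elements among its 128 threads is not specified here. When comparing against a reduced-precision GPU result, use a tolerance rather than exact equality.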