
Mojo module

matmul

AMD RDNA matmul kernel that uses WMMA on RDNA 3 and newer, with a naive fallback for older architectures.

RDNA 3+ (gfx11xx/gfx12xx): uses 16x16x16 WMMA instructions with Wave32.

RDNA 1/2 (gfx10xx): falls back to a per-thread naive matmul.
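As a rough illustration of what the WMMA path computes, the Python sketch below treats each 16x16x16 instruction as an in-place update D = A_tile · B_tile + C_tile on one 16x16 output tile, accumulates over K in steps of 16, and checks the result against a naive reference that computes one dot product per output element (the idea behind the fallback). This is only an arithmetic model. It assumes small integer matrices whose dimensions are multiples of 16, and it ignores Wave32 lane distribution, data types, and shared-memory staging. The helper names (`naive_matmul`, `mma_16x16x16`, `tiled_matmul`) are hypothetical and not part of the module.

```python
import random

# Shape of one WMMA instruction (MMA_M x MMA_N x MMA_K in the module).
MMA_M, MMA_N, MMA_K = 16, 16, 16


def naive_matmul(a, b):
    """Reference C = A @ B: one full dot product per output element."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def mma_16x16x16(c, a, b, i0, j0, k0):
    """Emulates one D = A_tile @ B_tile + C_tile step on a 16x16 tile of c, in place."""
    for i in range(i0, i0 + MMA_M):
        for j in range(j0, j0 + MMA_N):
            acc = c[i][j]
            for p in range(k0, k0 + MMA_K):
                acc += a[i][p] * b[p][j]
            c[i][j] = acc


def tiled_matmul(a, b):
    """Walks the output in 16x16 tiles and accumulates over K in steps of MMA_K."""
    m, k, n = len(a), len(b), len(b[0])
    # Sketch assumption: every dimension is a multiple of the MMA tile.
    assert m % MMA_M == 0 and n % MMA_N == 0 and k % MMA_K == 0
    c = [[0] * n for _ in range(m)]
    for i0 in range(0, m, MMA_M):
        for j0 in range(0, n, MMA_N):
            for k0 in range(0, k, MMA_K):
                mma_16x16x16(c, a, b, i0, j0, k0)
    return c


# Integer inputs keep the comparison exact (no floating-point rounding).
random.seed(0)
A = [[random.randint(-3, 3) for _ in range(48)] for _ in range(32)]
B = [[random.randint(-3, 3) for _ in range(32)] for _ in range(48)]
assert tiled_matmul(A, B) == naive_matmul(A, B)
```

Both functions produce the same product. The tiled form only changes the order in which partial sums are accumulated, and that ordering is what lets the hardware perform a whole 16x16x16 step as one instruction.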

Block configuration (shared by both paths): 4 warps x 32 threads = 128 threads per workgroup. BLOCK_M=64, BLOCK_N=64, BLOCK_K=16.
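The comptime values listed below fit together arithmetically. The sketch that follows shows one way to derive them from the block shape, the 2x2 warp grid, and the 16x16x16 MMA shape. These relationships are inferred from the listed values, and the module may define some of them as plain literals. `THREADS_PER_BLOCK`, `K_STEPS_PER_BLOCK_K`, and `NAIVE_OUTPUTS_PER_THREAD` are hypothetical names. The last one also assumes that the naive fallback divides the 64x64 block evenly across all 128 threads.

```python
WARP_SIZE = 32  # Wave32, per the module description

BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 16
MMA_M, MMA_N, MMA_K = 16, 16, 16
WARPS_M, WARPS_N = 2, 2

NUM_WARPS = WARPS_M * WARPS_N                  # 2 * 2 = 4
THREADS_PER_BLOCK = NUM_WARPS * WARP_SIZE      # 4 * 32 = 128 (hypothetical name)
WARP_M = BLOCK_M // WARPS_M                    # 64 / 2 = 32 rows per warp
WARP_N = BLOCK_N // WARPS_N                    # 64 / 2 = 32 columns per warp
NUM_M_MMAS = WARP_M // MMA_M                   # 32 / 16 = 2
NUM_N_MMAS = WARP_N // MMA_N                   # 32 / 16 = 2
NUM_C_TILES = NUM_M_MMAS * NUM_N_MMAS          # 2 * 2 = 4 accumulator tiles per warp
K_STEPS_PER_BLOCK_K = BLOCK_K // MMA_K         # 16 / 16 = 1 (hypothetical name)

# Assumption: the naive path splits the 64x64 block evenly over all threads.
NAIVE_OUTPUTS_PER_THREAD = (BLOCK_M * BLOCK_N) // THREADS_PER_BLOCK  # 4096 / 128 = 32
```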

comptime values

AB_FRAG_SIZE

comptime AB_FRAG_SIZE = 16
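AB_FRAG_SIZE = 16 is consistent with the RDNA 3 Wave32 WMMA operand layout that AMD has published. In that layout, lane `l` holds a full 16-element row of the A tile and a full 16-element column of the B tile, and lanes 16 to 31 repeat the data of lanes 0 to 15. The sketch below models that layout as (row, column) coordinates and checks its properties. This mapping is an assumption based on AMD's public RDNA 3 material, not something read from this module. RDNA 4 (gfx12xx) may use a different register layout, and the kernel's own indexing is authoritative. `a_fragment` and `b_fragment` are hypothetical helpers.

```python
WAVE_SIZE = 32
AB_FRAG_SIZE = 16


def a_fragment(lane):
    """(row, k) coordinates of the A-tile elements held by one lane (assumed layout)."""
    row = lane % 16  # lanes 16..31 mirror lanes 0..15
    return [(row, k) for k in range(AB_FRAG_SIZE)]


def b_fragment(lane):
    """(k, col) coordinates of the B-tile elements held by one lane (assumed layout)."""
    col = lane % 16
    return [(k, col) for k in range(AB_FRAG_SIZE)]


# Lanes 0..15 together cover the whole 16x16 A tile exactly once.
covered = {rc for lane in range(16) for rc in a_fragment(lane)}
assert covered == {(r, k) for r in range(16) for k in range(16)}
```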

BLOCK_K

comptime BLOCK_K = 16

BLOCK_M

comptime BLOCK_M = 64

BLOCK_N

comptime BLOCK_N = 64

CD_FRAG_SIZE

comptime CD_FRAG_SIZE = 8
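CD_FRAG_SIZE = 8 follows from dividing a 16x16 accumulator tile (256 values) across the 32 lanes of a wave: 256 / 32 = 8 values per lane. The sketch below assumes the RDNA 3 Wave32 layout for 32-bit accumulators that AMD has published. Under that layout, lane `l` owns column `l % 16`, and the two half-waves interleave rows (even rows for lanes 0 to 15, odd rows for lanes 16 to 31). Treat this mapping as an assumption rather than this module's documented behavior. Other architectures or accumulator types may lay the tile out differently. `cd_fragment` is a hypothetical helper.

```python
WAVE_SIZE = 32
CD_FRAG_SIZE = 8


def cd_fragment(lane):
    """(row, col) coordinates of the accumulator values held by one lane (assumed layout)."""
    col = lane % 16
    half = lane // 16  # 0 for lanes 0..15, 1 for lanes 16..31
    return [(2 * v + half, col) for v in range(CD_FRAG_SIZE)]


coords = [rc for lane in range(WAVE_SIZE) for rc in cd_fragment(lane)]
# Every element of the 16x16 tile is owned by exactly one lane.
assert len(coords) == 256 and len(set(coords)) == 256
```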

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

NUM_C_TILES

comptime NUM_C_TILES = 4
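Each warp owns a 32x32 region of the 64x64 block, which it covers with NUM_C_TILES = 2 x 2 accumulator tiles of 16x16. The sketch below lists the top-left corner of every accumulator tile and checks that the four warps cover the block without overlap. The row-major placement of warps (warp id to (warp_m, warp_n) via `divmod`) is an assumption made for illustration, and `warp_origin` and `c_tile_origins` are hypothetical helpers.

```python
WARPS_M, WARPS_N = 2, 2
WARP_M, WARP_N = 32, 32
MMA_M, MMA_N = 16, 16
NUM_M_MMAS, NUM_N_MMAS = WARP_M // MMA_M, WARP_N // MMA_N
NUM_C_TILES = NUM_M_MMAS * NUM_N_MMAS


def warp_origin(warp_id):
    """Top-left (row, col) of a warp's 32x32 region; row-major placement is assumed."""
    warp_m, warp_n = divmod(warp_id, WARPS_N)
    return warp_m * WARP_M, warp_n * WARP_N


def c_tile_origins(warp_id):
    """Top-left corners of the warp's NUM_C_TILES 16x16 accumulator tiles."""
    r0, c0 = warp_origin(warp_id)
    return [(r0 + mi * MMA_M, c0 + ni * MMA_N)
            for mi in range(NUM_M_MMAS) for ni in range(NUM_N_MMAS)]


all_tiles = [t for w in range(WARPS_M * WARPS_N) for t in c_tile_origins(w)]
# 4 warps x 4 tiles = 16 distinct 16x16 tiles, i.e. the full 64x64 block.
assert len(set(all_tiles)) == 16
```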

NUM_M_MMAS

comptime NUM_M_MMAS = 2

NUM_N_MMAS

comptime NUM_N_MMAS = 2

NUM_WARPS

comptime NUM_WARPS = 4

WARP_M

comptime WARP_M = 32

WARP_N

comptime WARP_N = 32

WARPS_M

comptime WARPS_M = 2

WARPS_N

comptime WARPS_N = 2

Functions
