Mojo module
matmul_kernel
Simdgroup-tiled Apple M5 matmul kernel built on MmaOpApple.
64x64 output tile per threadgroup; four simdgroups (128 threads) each own a 32x32 subtile (2x2 MmaOpApple). A per-simdgroup runtime branch picks between an unbounded fast path and a bounded path that handles ragged M/N edges and partial K tails.
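The tile split and path selection described above can be modeled on the host. The following is a minimal Python sketch, not the kernel's Mojo code. The helper names `simdgroup_origin` and `pick_path`, the row-major numbering of simdgroups within the 2x2 layout, and the exact fast-path condition (subtile fully inside the matrix and K a multiple of BK) are all assumptions for illustration.

```python
BM = BN = 64      # output tile per threadgroup (from this page)
SG_M = SG_N = 32  # subtile per simdgroup (from this page)
BK = 16           # K step (from this page)


def simdgroup_origin(tile_m, tile_n, sg_id):
    """Top-left (row, col) of the 32x32 subtile owned by simdgroup sg_id.

    Assumption: simdgroups are numbered row-major within the 2x2 layout.
    """
    sg_row, sg_col = divmod(sg_id, BN // SG_N)
    return tile_m * BM + sg_row * SG_M, tile_n * BN + sg_col * SG_N


def pick_path(row0, col0, M, N, K):
    """Hypothetical per-simdgroup branch: 'fast' only when no bounds checks are needed."""
    fully_inside = row0 + SG_M <= M and col0 + SG_N <= N
    no_k_tail = K % BK == 0
    return "fast" if fully_inside and no_k_tail else "bounded"


# Example: M=100, N=64, K=48, threadgroup tile (1, 0) covers rows 64..127.
for sg_id in range(4):
    row0, col0 = simdgroup_origin(1, 0, sg_id)
    print(sg_id, (row0, col0), pick_path(row0, col0, 100, 64, 48))
```

Under these assumptions, simdgroups 0 and 1 (rows 64..95) take the fast path, while simdgroups 2 and 3 (rows 96..127) cross the M edge at row 100 and take the bounded path.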
Block-to-tile: each threadgroup decodes block_idx.x via
morton_decode_2d_rect over a side_m x side_n grid (each axis padded
to the next power of two). Threadgroups whose decoded tile falls
outside (grid_m, grid_n) return early.
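To make the padding and early-return behavior concrete, here is a small Python model of the launch grid. It is a sketch under assumptions: `decode_rect` is a hypothetical stand-in for morton_decode_2d_rect (it interleaves the low bits shared by both axes and gives the leftover high bits to the longer axis), and the bit order (even bits to tile_n) is a guess, not taken from the module.

```python
def next_pow2_log2(n):
    """Smallest e such that (1 << e) >= n, for n >= 1."""
    return (n - 1).bit_length()


def decode_rect(flat_idx, log2_m, log2_n):
    """Hypothetical rectangular Z-order decode (not the module's actual code)."""
    shared = min(log2_m, log2_n)
    tile_m = tile_n = 0
    for b in range(shared):
        tile_n |= ((flat_idx >> (2 * b)) & 1) << b      # even bits -> n (assumed)
        tile_m |= ((flat_idx >> (2 * b + 1)) & 1) << b  # odd bits -> m (assumed)
    rest = flat_idx >> (2 * shared)  # leftover bits index the longer axis
    if log2_m > log2_n:
        tile_m |= rest << shared
    else:
        tile_n |= rest << shared
    return tile_m, tile_n


def active_tiles(M, N, BM=64, BN=64):
    """Tiles that do real work; padded-out threadgroups are skipped."""
    grid_m = -(-M // BM)  # ceil(M / BM)
    grid_n = -(-N // BN)
    log2_m, log2_n = next_pow2_log2(grid_m), next_pow2_log2(grid_n)
    tiles = []
    for flat_idx in range(1 << (log2_m + log2_n)):  # side_m * side_n launches
        tile_m, tile_n = decode_rect(flat_idx, log2_m, log2_n)
        if tile_m >= grid_m or tile_n >= grid_n:
            continue  # models the kernel's early return
        tiles.append((tile_m, tile_n))
    return tiles
```

For M=192, N=320 this model gives grid_m=3 and grid_n=5, padded to a 4 x 8 grid: 32 threadgroups would be launched and 15 of them would compute a tile.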
comptime values
BK
comptime BK = 16
BM
comptime BM = 64
BN
comptime BN = 64
NUM_SG
comptime NUM_SG = (NUM_SG_M * NUM_SG_N)
NUM_SG_M
comptime NUM_SG_M = (BM / SG_M)
NUM_SG_N
comptime NUM_SG_N = (BN / SG_N)
SG_M
comptime SG_M = 32
SG_N
comptime SG_N = 32
THREADS_PER_BLOCK
comptime THREADS_PER_BLOCK = (NUM_SG * Int[Int](WARP_SIZE))
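The derived values follow from the base constants by simple arithmetic. The Python snippet below mirrors the definitions above as a quick check. WARP_SIZE = 32 is an assumption here, inferred from the module description of "four simdgroups (128 threads)".

```python
BM, BN, BK = 64, 64, 16
SG_M, SG_N = 32, 32
WARP_SIZE = 32  # assumption: 128 threads / 4 simdgroups

NUM_SG_M = BM // SG_M            # simdgroups along M
NUM_SG_N = BN // SG_N            # simdgroups along N
NUM_SG = NUM_SG_M * NUM_SG_N     # simdgroups per threadgroup
THREADS_PER_BLOCK = NUM_SG * WARP_SIZE

print(NUM_SG_M, NUM_SG_N, NUM_SG, THREADS_PER_BLOCK)
```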
Functions
- apple_matmul_kernel: Apple M5 simdgroup-tiled GEMM: one 64x64 tile per threadgroup.
- enqueue_apple_matmul: Enqueue the Apple M5 matmul kernel on the given device context.
- morton_decode_2d: Decode a linear index to (tile_m, tile_n) via Morton Z-order.
- morton_decode_2d_rect: Decode flat_idx to (tile_m, tile_n) over a (1<<log2_m) x (1<<log2_n) grid.
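For readers unfamiliar with Morton Z-order, the square decode can be sketched in Python as a bit de-interleave. The name `morton_decode_2d_model` is hypothetical, and which bit parity maps to tile_m versus tile_n is an assumption; the module's own convention may be the reverse.

```python
def morton_decode_2d_model(idx):
    """De-interleave idx into (tile_m, tile_n).

    Assumption: even bits of idx form tile_n, odd bits form tile_m.
    """
    tile_m = tile_n = 0
    bit = 0
    while idx:
        tile_n |= (idx & 1) << bit
        tile_m |= ((idx >> 1) & 1) << bit
        idx >>= 2
        bit += 1
    return tile_m, tile_n


# Consecutive indices trace a Z shape within each 2x2 group of tiles.
print([morton_decode_2d_model(i) for i in range(8)])
```

Under this convention, indices 0..3 cover the 2x2 group at the origin before index 4 moves to the next group, which keeps neighboring threadgroups working on nearby tiles.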