
Mojo function

apple_matmul_kernel

apple_matmul_kernel[in_type: DType, transpose_b: Bool = False](d_ptr: UnsafePointer[Float32, MutAnyOrigin], a_ptr: UnsafePointer[Scalar[in_type], MutAnyOrigin], b_ptr: UnsafePointer[Scalar[in_type], MutAnyOrigin], m: Int, n: Int, k: Int, log2_grid_m: UInt32, log2_grid_n: UInt32)
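When checking the kernel's output, a plain-Python reference of the product it computes can be handy. This is a minimal sketch under assumptions the page does not state: A is (M, K) row-major, D is (M, N) row-major, and with transpose_b=False, B is (K, N) row-major. The transpose_b=True case follows this kernel's stated layout, where B arrives as an (N, K) row-major buffer. The name reference_gemm is hypothetical and not part of the Mojo API.

```python
def reference_gemm(a, b, m, n, k, transpose_b=False):
    """Compute D = A @ B on flat lists (hypothetical reference helper).

    Assumed layouts: A is (M, K) row-major and D is (M, N) row-major.
    B is (K, N) row-major, or an (N, K) row-major buffer when
    transpose_b=True (logically col_major(K, N)).
    """
    d = [0.0] * (m * n)
    for i in range(m):
        for j in range(n):
            # Python floats are float64 while the kernel writes Float32,
            # so compare kernel output against this with a tolerance.
            acc = 0.0
            for p in range(k):
                a_ip = a[i * k + p]
                # Logical B[p, j]: column j is contiguous in the (N, K) buffer.
                b_pj = b[j * k + p] if transpose_b else b[p * n + j]
                acc += a_ip * b_pj
            d[i * n + j] = acc
    return d
```

For example, `reference_gemm([1, 2, 3, 4], [5, 6, 7, 8], 2, 2, 2)` yields `[19.0, 22.0, 43.0, 50.0]`. Passing the same logical B as the (N, K) buffer `[5, 7, 6, 8]` with `transpose_b=True` gives the same result.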

Apple M5 simdgroup-tiled GEMM: each threadgroup computes one 64x64 output tile.
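Because each threadgroup owns one 64x64 tile and the grid extent in each dimension is a power of two (1 << log2_grid_m by 1 << log2_grid_n), a caller presumably rounds the per-dimension tile count up to the next power of two. The sketch below assumes exactly that rule; launch_params and ceil_log2 are hypothetical host-side helpers, not part of the Mojo API.

```python
TILE = 64                # output tile edge per threadgroup
THREADS_PER_GROUP = 128  # threads in each threadgroup


def ceil_log2(x: int) -> int:
    """Smallest e >= 0 such that (1 << e) >= x, for x >= 1."""
    return (x - 1).bit_length()


def launch_params(m: int, n: int):
    """Assumed sizing: round ceil(dim / 64) tiles up to a power of two."""
    assert m >= 1 and n >= 1, "empty matrices need no launch"
    tiles_m = -(-m // TILE)  # ceil(m / 64)
    tiles_n = -(-n // TILE)  # ceil(n / 64)
    log2_grid_m = ceil_log2(tiles_m)
    log2_grid_n = ceil_log2(tiles_n)
    num_groups = (1 << log2_grid_m) * (1 << log2_grid_n)
    return log2_grid_m, log2_grid_n, num_groups
```

For m=300, n=64 this gives 5 tiles along M, rounded up to 8, so 8 threadgroups are launched and the surplus ones fall outside the matrix.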

The grid has (1 << log2_grid_m) * (1 << log2_grid_n) threadgroups of 128 threads each. Each threadgroup Morton-decodes its index into a tile coordinate and returns early if that tile is out of bounds (OOB).

With transpose_b=True, B is an (N, K) row-major buffer, reinterpreted as col_major(K, N).
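A Morton (Z-order) mapping typically makes threadgroups that are adjacent in launch order cover a compact 2D block of tiles, which can improve reuse of A and B data in cache. The page does not specify the bit layout, so the interleaving here is an assumption. N takes the even bits and M the odd bits of the low 2 * min(log2_grid_m, log2_grid_n) bits, and any surplus high bits go to the larger dimension. The tile_in_bounds test, which skips a tile only when its origin lies past the matrix, is also assumed. Both helper names are hypothetical.

```python
def morton_decode(tg_id: int, log2_grid_m: int, log2_grid_n: int):
    """Split a linear threadgroup index into (tile_m, tile_n).

    Assumed layout: low bits alternate n, m, n, m, ... for the shared
    min(log2_grid_m, log2_grid_n) levels; remaining high bits all belong
    to the larger dimension.
    """
    common = min(log2_grid_m, log2_grid_n)
    tile_m = tile_n = 0
    for bit in range(common):
        tile_n |= ((tg_id >> (2 * bit)) & 1) << bit
        tile_m |= ((tg_id >> (2 * bit + 1)) & 1) << bit
    high = tg_id >> (2 * common)  # bits beyond the interleaved region
    if log2_grid_m > log2_grid_n:
        tile_m |= high << common
    else:
        tile_n |= high << common
    return tile_m, tile_n


def tile_in_bounds(tile_m: int, tile_n: int, m: int, n: int, tile: int = 64) -> bool:
    """Assumed early-return test: a tile is active if its origin is inside D."""
    return tile_m * tile < m and tile_n * tile < n
```

With m=130, n=64, log2_grid_m=2 and log2_grid_n=0, four threadgroups are launched. The one that decodes to tile_m=3 (origin row 192) would return early, leaving three active tiles. Edge tiles that only partly overlap the matrix still run, so the kernel presumably masks their loads and stores.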