Mojo function
apple_matmul_kernel
apple_matmul_kernel[in_type: DType, transpose_b: Bool = False](d_ptr: UnsafePointer[Float32, MutAnyOrigin], a_ptr: UnsafePointer[Scalar[in_type], MutAnyOrigin], b_ptr: UnsafePointer[Scalar[in_type], MutAnyOrigin], m: Int, n: Int, k: Int, log2_grid_m: UInt32, log2_grid_n: UInt32)
Apple M5 simdgroup-tiled GEMM: each threadgroup computes one 64x64 tile.
The grid is (1 << log2_grid_m) * (1 << log2_grid_n) threadgroups of 128 threads each.
Threadgroups whose Morton-decoded tile is out of bounds (OOB) return early. For transpose_b=True,
B is the (N, K) row-major buffer reinterpreted as col_major(K, N), so element (k, n) sits at offset n * K + k.
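The launch geometry described above can be modelled on the host. What follows is a minimal Python sketch, not the kernel's Mojo source. The helper names (`ceil_log2`, `grid_log2`, `morton_decode`, `active_tiles`) are hypothetical. Two things are assumptions that the docstring does not confirm: that the caller rounds each tile count up to a power of two, and the bit order of the interleave (here the low bits alternate n then m, and any leftover high bits go to the longer axis).

```python
TILE = 64  # tile edge taken from the docstring: one 64x64 tile per threadgroup


def ceil_log2(x: int) -> int:
    """Smallest e such that (1 << e) >= x, for x >= 1."""
    return (x - 1).bit_length()


def grid_log2(m: int, n: int) -> tuple[int, int]:
    """Assumed host-side choice of (log2_grid_m, log2_grid_n):
    round the tile count along each axis up to a power of two."""
    tiles_m = -(-m // TILE)  # ceil(m / TILE)
    tiles_n = -(-n // TILE)  # ceil(n / TILE)
    return ceil_log2(tiles_m), ceil_log2(tiles_n)


def morton_decode(idx: int, log2_grid_m: int, log2_grid_n: int) -> tuple[int, int]:
    """De-interleave a linear threadgroup index into (tile_m, tile_n).

    Assumed bit order: bits alternate n, m starting from the least
    significant bit; once one axis has all its bits, the remaining
    high bits all go to the other axis.
    """
    tile_m = tile_n = 0
    bits_m = bits_n = 0
    while bits_m < log2_grid_m or bits_n < log2_grid_n:
        if bits_n < log2_grid_n:
            tile_n |= (idx & 1) << bits_n
            idx >>= 1
            bits_n += 1
        if bits_m < log2_grid_m:
            tile_m |= (idx & 1) << bits_m
            idx >>= 1
            bits_m += 1
    return tile_m, tile_n


def active_tiles(m: int, n: int) -> list[tuple[int, int]]:
    """Tiles that survive the OOB check, in launch (Morton) order."""
    lg_m, lg_n = grid_log2(m, n)
    tiles = []
    for idx in range(1 << (lg_m + lg_n)):
        tile_m, tile_n = morton_decode(idx, lg_m, lg_n)
        if tile_m * TILE >= m or tile_n * TILE >= n:
            continue  # mirrors the kernel's early return for OOB threadgroups
        tiles.append((tile_m, tile_n))
    return tiles
```

For example, with m = 130 and n = 64 this model launches a 4 x 1 grid (three tiles of rows rounded up to four), and the fourth threadgroup is skipped because its tile starts at row 192.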