Mojo module
mxfp4_grouped_matmul_amd
MXFP4 grouped matmul on AMD CDNA GPUs via dequant-to-FP8 + FP8 grouped GEMM.
Dequantizes MXFP4 expert weights to FP8, casts BF16 activations to FP8, then dispatches to the AMD FP8 grouped GEMM via grouped_matmul.
The grouped matmul computes:

```
for i in range(num_active_experts):
    C[offsets[i]:offsets[i+1], :] = A[offsets[i]:offsets[i+1], :] @ B[expert_ids[i], :, :].T
```

where the `B` weights are stored as packed MXFP4 (`uint8`) with E8M0 scales and the `A` activations are BF16; both are dequantized/cast to FP8 before the GEMM.
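The loop above can be sketched as a plain NumPy reference, using `float32` as a stand-in for the FP8 compute type (the function name and array layouts here are illustrative, not the module's API):

```python
import numpy as np

def grouped_matmul_ref(A, B, offsets, expert_ids):
    """Reference grouped matmul.

    A:          [M, K] activations; rows are grouped by expert.
    B:          [num_experts, N, K] per-expert weights (multiplied transposed).
    offsets:    row boundaries of each active group, length num_active_experts + 1.
    expert_ids: which expert's weights each group uses.
    """
    M, _ = A.shape
    C = np.zeros((M, B.shape[1]), dtype=A.dtype)
    for i, e in enumerate(expert_ids):
        start, end = offsets[i], offsets[i + 1]
        C[start:end, :] = A[start:end, :] @ B[e, :, :].T
    return C
```

For example, with `offsets = [0, 2, 4]` and `expert_ids = [0, 1]`, rows 0-1 of `A` multiply expert 0's weights and rows 2-3 multiply expert 1's.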
Functions
- `mxfp4_grouped_matmul_amd`: Dequantizes the MXFP4 `B` weights, casts the `A` activations to FP8, and runs the FP8 grouped GEMM.
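The MXFP4 decode step can be illustrated with a small NumPy sketch based on the OCP Microscaling format: each 4-bit element is sign + E2M1 magnitude, and a block shares one E8M0 scale equal to 2^(e - 127). The function name and the low-nibble-first packing order are assumptions for illustration, not the module's actual layout:

```python
import numpy as np

# E2M1 magnitude table: 3-bit magnitude index -> value;
# the top bit of each nibble is the sign.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequant_mxfp4_block(packed: np.ndarray, scale_e8m0: int) -> np.ndarray:
    """Dequantize one MXFP4 block (typically 32 values / 16 packed bytes).

    packed:      uint8 array, two FP4 nibbles per byte (low nibble first,
                 assumed ordering).
    scale_e8m0:  shared E8M0 exponent byte; block scale = 2**(e - 127).
    """
    lo = packed & 0x0F
    hi = packed >> 4
    nibbles = np.empty(packed.size * 2, dtype=np.uint8)
    nibbles[0::2] = lo
    nibbles[1::2] = hi
    sign = np.where(nibbles & 0x8, -1.0, 1.0)
    mag = FP4_VALUES[nibbles & 0x7]
    return sign * mag * 2.0 ** (int(scale_e8m0) - 127)
```

In the real kernel the decoded values are rounded into FP8 rather than kept in full precision, so this sketch only shows the decode arithmetic, not the final output type.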