Mojo module

mxfp4_grouped_matmul_amd

MXFP4 grouped matmul on AMD CDNA GPUs via dequant-to-FP8 + FP8 grouped GEMM.

Dequantizes MXFP4 expert weights to FP8, casts BF16 activations to FP8, then dispatches to the AMD FP8 grouped GEMM via grouped_matmul.
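The dequantization step can be illustrated with a small reference model. This is a minimal sketch in plain Python, not the kernel's implementation: it assumes the OCP MX convention of one E8M0 scale per block of 32 FP4 (E2M1) elements, and it assumes the low nibble of each byte holds the even-indexed element. The function names are hypothetical. It stops at Python floats and does not model the final rounding to FP8.

```python
# Magnitudes representable by FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32  # ASSUMPTION: elements sharing one E8M0 scale (OCP MX convention)


def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e8m0(scale_byte: int) -> float:
    """Decode an E8M0 scale: a power of two with bias 127 (0xFF encodes NaN)."""
    if scale_byte == 0xFF:
        return float("nan")
    return 2.0 ** (scale_byte - 127)


def dequantize_mxfp4(packed: bytes, scales: bytes) -> list[float]:
    """Expand packed FP4 bytes into floats, one scale per BLOCK_SIZE elements.

    ASSUMPTION: the low nibble of each byte holds the even-indexed element.
    """
    values = []
    for byte_index, byte in enumerate(packed):
        # Each byte carries two elements, so element index = 2 * byte_index.
        scale = decode_e8m0(scales[(2 * byte_index) // BLOCK_SIZE])
        values.append(decode_e2m1(byte & 0x0F) * scale)
        values.append(decode_e2m1(byte >> 4) * scale)
    return values
```

For example, 16 bytes of `0x21` with a single scale byte of 128 (a scale of 2.0) decode to the repeating pair 1.0, 2.0 across one 32-element block.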

The grouped matmul computes:

    for i in range(num_active_experts):
        C[offsets[i]:offsets[i+1], :] = A[offsets[i]:offsets[i+1], :] @ B[expert_ids[i], :, :].T

where B holds the expert weights as packed MXFP4 (uint8) with E8M0 scales, and A holds the BF16 activations. Before the GEMM, B is dequantized to FP8 and A is cast to FP8.
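The indexing in the grouped matmul formula can be checked against a plain-Python reference. This is a sketch of the semantics only, with no quantization and no GPU dispatch. It assumes that `offsets` has `num_active_experts + 1` entries and that `B` is laid out as `[num_experts, N, K]`, which is what the formula's `offsets[i+1]` and `.T` imply. The function name `grouped_matmul_reference` is hypothetical.

```python
def grouped_matmul_reference(a, b, offsets, expert_ids):
    """Compute C one row block at a time, with one expert per block.

    a:          M x K activations (list of rows)
    b:          num_experts x N x K weights (ASSUMED layout, applied transposed)
    offsets:    row boundaries; block i covers rows offsets[i] .. offsets[i+1]-1
    expert_ids: index of the expert whose weights are used for block i
    """
    n = len(b[0])
    c = [[0.0] * n for _ in range(len(a))]
    num_active_experts = len(expert_ids)
    for i in range(num_active_experts):
        weight = b[expert_ids[i]]  # N x K matrix for this block's expert
        for row in range(offsets[i], offsets[i + 1]):
            for col in range(n):
                # Activation row dotted with weight row `col` is (A @ W.T)[row][col].
                c[row][col] = sum(x * w for x, w in zip(a[row], weight[col]))
    return c


# Three tokens, two experts: rows 0-1 go to expert 1, row 2 goes to expert 0.
a = [[1, 0], [0, 1], [1, 1]]
b = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
c = grouped_matmul_reference(a, b, offsets=[0, 2, 3], expert_ids=[1, 0])
```

Rows outside every `[offsets[i], offsets[i+1])` range are left at zero in this sketch; how the real kernel treats such rows is not stated here.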

Functions
