Mojo module
mxfp4_matmul_amd
MXFP4 matmul on AMD CDNA GPUs via dequant-to-FP8 + FP8 GEMM.
Dequantizes the MXFP4 weights to FP8, casts the BF16 activations to FP8, then dispatches to the AMD FP8 GEMM via _matmul_gpu.
MI355X (CDNA4) uses float8_e4m3fn; MI300X (CDNA3) uses float8_e4m3fnuz. The FP8 type is selected at compile time based on the target architecture.
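The architecture-based selection described above can be sketched as follows. This is an illustrative Python sketch, not the module's actual compile-time logic; the "gfx950"/"gfx942" LLVM target names for MI355X/MI300X are assumptions here.

```python
# Hedged sketch: pick the FP8 element type by target GPU architecture.
# The gfx950 (MI355X) and gfx942 (MI300X) target names are assumptions.
def fp8_dtype_for(arch: str) -> str:
    if arch == "gfx950":   # MI355X / CDNA4: OCP FP8 (finite + NaN)
        return "float8_e4m3fn"
    if arch == "gfx942":   # MI300X / CDNA3: fnuz FP8 (no -0, NaN-unsigned-zero)
        return "float8_e4m3fnuz"
    raise ValueError(f"no MXFP4 FP8 path for {arch}")
```

In the actual module this branch is resolved at compile time, so only one FP8 type exists in the generated kernel.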
Functions
- mxfp4_matmul_amd: MXFP4 matmul that dequantizes the B weights to FP8, casts A to FP8, and dispatches to the AMD FP8 GEMM.
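The dequantization step operates on MXFP4 blocks: 32 FP4 (E2M1) values sharing one 8-bit E8M0 power-of-two scale, per the OCP Microscaling (MX) spec. A minimal Python sketch of decoding one block (the low-nibble-first packing order is an assumption; the real kernel emits FP8, not Python floats):

```python
# FP4 E2M1 magnitudes for codes 0..7; bit 3 is the sign (OCP MX spec).
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_LUT = FP4_E2M1 + [-v for v in FP4_E2M1]

def e8m0_scale(code: int) -> float:
    """Decode an 8-bit E8M0 shared scale: 2**(code - 127); 0xFF is NaN."""
    return float("nan") if code == 0xFF else 2.0 ** (code - 127)

def dequant_mxfp4_block(packed: bytes, scale_code: int) -> list[float]:
    """Dequantize one MXFP4 block: two FP4 values per byte times the shared
    scale. Low-nibble-first ordering is an assumption of this sketch."""
    scale = e8m0_scale(scale_code)
    out = []
    for byte in packed:
        out.append(FP4_LUT[byte & 0xF] * scale)
        out.append(FP4_LUT[(byte >> 4) & 0xF] * scale)
    return out
```

After this per-block dequantization to FP8, the result feeds directly into the FP8 GEMM alongside the FP8-cast activations.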