Mojo module
mxfp4_matmul_amd
Native MXFP4 block-scaled matmul on AMD CDNA4 via f8f6f4 MFMA.
Computes C = (A * scale_a) @ (B * scale_b)^T, where A and B are packed MXFP4 (E2M1) values in uint8 with per-block E8M0 scale factors. Uses the CDNA4 mfma.scale.f32.16x16x128.f8f6f4 instruction, which natively consumes MXFP4 operands together with E8M0 scale words, so no dequantization step is needed.
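The arithmetic the instruction performs can be modeled on the host. The following is an illustrative pure-Python sketch (not the module's Mojo code); the low-nibble-first unpacking order and the helper names are assumptions for illustration:

```python
# Reference model of C = (A * scale_a) @ (B * scale_b)^T for MXFP4 inputs.
# E2M1 (MXFP4): 1 sign bit, 2 exponent bits, 1 mantissa bit -> 16 values,
# indexed here by the 4-bit code.
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def decode_e8m0(byte):
    """E8M0 is a pure power-of-two scale: value = 2**(byte - 127)."""
    return 2.0 ** (byte - 127)

def unpack_row(packed_row):
    """One [K//2] uint8 row -> [K] floats.
    Low nibble first is an assumption; the real order is set by the packer."""
    out = []
    for byte in packed_row:
        out.append(E2M1[byte & 0x0F])  # low nibble
        out.append(E2M1[byte >> 4])    # high nibble
    return out

def mxfp4_matmul_ref(a, scale_a, b, scale_b):
    """C[m][n] = sum_k (A[m][k]*sa[m][k//32]) * (B[n][k]*sb[n][k//32]).
    B is stored transposed: each row of b is one output column."""
    A = [unpack_row(r) for r in a]
    B = [unpack_row(r) for r in b]
    M, N, K = len(A), len(B), len(A[0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                sa = decode_e8m0(scale_a[m][k // 32])
                sb = decode_e8m0(scale_b[n][k // 32])
                acc += (A[m][k] * sa) * (B[n][k] * sb)
            C[m][n] = acc
    return C
```

The hardware never materializes the scaled floats; the MFMA applies the E8M0 scales internally while accumulating in f32.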
Structure mirrors AMDMatmul: TileTensor throughout, RegTileLoader for DRAM→regs, row-major SMEM (no blocked-product layout or swizzle, since the FP4 MFMA expects a simple row-major lane-to-data mapping, unlike BF16/FP8), and a schedule-driven pipeline.
MXFP4 data layout:
- A: [M, K//2] uint8, row-major (two MXFP4 nibbles packed per byte)
- B: [N, K//2] uint8, row-major (transposed: each row is one output column)
- scale_a: [M, K//32] float8_e8m0fnu (one scale per 32 MXFP4 elements)
- scale_b: [N, K//32] float8_e8m0fnu
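For a concrete M×N×K problem, the buffer shapes above work out as in this small sketch (the helper name is hypothetical):

```python
def mxfp4_buffer_shapes(M, N, K):
    """Shapes of the packed operand and scale buffers for an MxNxK matmul.
    K must be a multiple of 32, the MX block size."""
    assert K % 32 == 0
    return {
        "a": (M, K // 2),         # uint8, two E2M1 nibbles per byte
        "b": (N, K // 2),         # uint8, transposed: row = output column
        "scale_a": (M, K // 32),  # float8_e8m0fnu, one scale per 32 elements
        "scale_b": (N, K // 32),  # float8_e8m0fnu
    }
```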
MFMA lane-to-data mapping for 16x16x128 FP4: Each lane loads 16 contiguous bytes from its assigned matrix row. lane_row = lane_id % MMA_M, lane_chunk = lane_id / MMA_M. Offset = lane_row * row_stride + lane_chunk * 16. The 16 bytes are zero-extended to SIMD[uint8, 32] for the MFMA operand.
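The offset formula above can be sketched as follows; this is an illustrative Python model of the addressing math (the function name is hypothetical), not kernel code. With 64 lanes per wavefront, the lanes jointly cover 16 rows × 4 chunks × 16 bytes = 1024 bytes, i.e. all 16×128 FP4 elements of one A-fragment:

```python
WARP_SIZE = 64        # CDNA wavefront width
MMA_M = 16            # rows of the 16x16x128 MFMA fragment
BYTES_PER_LANE = 16   # 32 FP4 elements = 16 bytes per lane

def a_frag_byte_offset(lane_id, row_stride_bytes, k_base_bytes=0):
    """Byte offset into the packed [M, K//2] buffer of the 16 contiguous
    bytes that `lane_id` loads for one 16x16x128 FP4 MFMA fragment."""
    lane_row = lane_id % MMA_M     # which matrix row this lane serves
    lane_chunk = lane_id // MMA_M  # which 16-byte chunk of the row (0..3)
    return (lane_row * row_stride_bytes
            + k_base_bytes
            + lane_chunk * BYTES_PER_LANE)
```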
MFMA scale model (16x16x128): Each lane holds 16x128/64 = 32 FP4 elements and one E8M0 scale. This matches the MX format exactly: one scale per 32 elements. The 64 scale values (16 rows x 4 K-groups = 64) come from 64 lanes, each contributing one byte.
Lane mapping: lane_row = lane % 16 (matrix row), lane_k_group = lane / 16 (which 32-element K-group within the row, 0..3). Each lane loads scale_ptr[row * stride + base_k + lane_k_group].
The scale byte is placed in byte 0 of an Int32 word passed to the MFMA intrinsic (byte_index=0 / OPSEL=0).
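The scale-side addressing can be sketched the same way; again an illustrative Python model with hypothetical helper names, assuming the lane mapping and OPSEL=0 placement described above:

```python
MMA_M = 16  # rows of the 16x16x128 MFMA fragment

def lane_scale_index(lane_id, row_stride, k_group_base=0):
    """Index of the E8M0 scale byte this lane fetches from the [*, K//32]
    scale buffer for one 16x16x128 scaled MFMA."""
    lane_row = lane_id % MMA_M       # matrix row (0..15)
    lane_k_group = lane_id // MMA_M  # 32-element K-group in the row (0..3)
    return lane_row * row_stride + k_group_base + lane_k_group

def scale_operand_word(scale_byte):
    """Build the Int32 operand word: the E8M0 byte sits in byte 0
    (byte_index=0 / OPSEL=0); the remaining bytes are left as zero."""
    return scale_byte & 0xFF
```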
Entry point: mxfp4_block_scaled_matmul_amd()
comptime values
MX_BLOCK_SIZE
comptime MX_BLOCK_SIZE = 32
MXFP4_BK_BYTES
comptime MXFP4_BK_BYTES = 64
MXFP4_BK_ELEMS
comptime MXFP4_BK_ELEMS = 128
MXFP4_BM
comptime MXFP4_BM = 128
MXFP4_BN
comptime MXFP4_BN = 128
MXFP4_MMA_K
comptime MXFP4_MMA_K = 128
MXFP4_MMA_M
comptime MXFP4_MMA_M = 16
MXFP4_MMA_N
comptime MXFP4_MMA_N = 16
MXFP4_NUM_M_MMAS
comptime MXFP4_NUM_M_MMAS = 4
MXFP4_NUM_N_MMAS
comptime MXFP4_NUM_N_MMAS = 4
MXFP4_NUM_THREADS
comptime MXFP4_NUM_THREADS = (4 * WARP_SIZE)
MXFP4_NUM_WARPS
comptime MXFP4_NUM_WARPS = 4
MXFP4_NUM_WARPS_M
comptime MXFP4_NUM_WARPS_M = 2
MXFP4_NUM_WARPS_N
comptime MXFP4_NUM_WARPS_N = 2
MXFP4_WM
comptime MXFP4_WM = 64
MXFP4_WN
comptime MXFP4_WN = 64
Structs
- BlockScaledMmaOp: Register ownership + GMEM loading + block-scaled MFMA execution.
- MXFP4MatmulAMD: Native MXFP4 block-scaled matmul for AMD CDNA4.
Functions
- mxfp4_block_scaled_matmul_amd: Launch native MXFP4 block-scaled matmul on AMD CDNA4.