
Mojo module

mxfp4_matmul_amd

Native MXFP4 block-scaled matmul on AMD CDNA4 via f8f6f4 MFMA.

Computes C = (A * scale_a) @ (B * scale_b)^T, where A and B are packed MXFP4 (E2M1) values stored in uint8 and each 32-element block along K carries one E8M0 scaling factor. Uses the CDNA4 mfma.scale.f32.16x16x128.f8f6f4 instruction, which consumes MXFP4 operands and E8M0 scale words natively, so no separate dequantization step is needed.
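The block-scaled product above can be written as a plain reference loop, which is useful as a host-side oracle when testing the kernel. The following is a minimal Python sketch, not part of this module. The helper names are hypothetical, and it operates on already-unpacked 4-bit codes. E2M1 and E8M0 decoding follow the OCP Microscaling (MX) format: E2M1 magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6} with bit 3 as the sign, and an E8M0 byte e encodes 2^(e - 127), with 0xFF reserved for NaN.

```python
import math

# E2M1 (FP4) magnitudes for codes 0..7; bit 3 of the code is the sign bit.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code (0..15) to a Python float."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte: a pure power of two, 2**(byte - 127); 0xFF is NaN."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)


def reference_block_scaled_matmul(a_codes, a_scales, b_codes, b_scales, block=32):
    """C[m][n] = sum_k (A[m][k] * sa[m][k//block]) * (B[n][k] * sb[n][k//block]).

    a_codes:  M rows of K unpacked E2M1 codes
    b_codes:  N rows of K unpacked E2M1 codes (B is stored transposed, as in the kernel)
    a_scales, b_scales: rows of K // block E8M0 bytes
    """
    M, N, K = len(a_codes), len(b_codes), len(a_codes[0])
    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                a = decode_e2m1(a_codes[m][k]) * decode_e8m0(a_scales[m][k // block])
                b = decode_e2m1(b_codes[n][k]) * decode_e8m0(b_scales[n][k // block])
                acc += a * b
            c[m][n] = acc
    return c


# Usage: one row, one column, K = 32 (a single scale block).
# A = code 2 (1.0) scaled by 2**1; B = code 4 (2.0) scaled by 2**0.
c = reference_block_scaled_matmul([[2] * 32], [[128]], [[4] * 32], [[127]])
print(c)  # [[128.0]]
```

In the kernel, the two multiplications by the scales and the dot product are fused into the single scaled MFMA. The reference keeps them separate so each step can be checked.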

Structure mirrors AMDMatmul: TileTensor throughout, RegTileLoader for DRAM→register loads, row-major SMEM, and a schedule-driven pipeline. SMEM uses no blocked-product layout or swizzle because, unlike BF16/FP8, the FP4 MFMA expects a simple row-major lane-to-data mapping.

MXFP4 data layout:
- A: [M, K//2] uint8, row-major (two MXFP4 nibbles packed per byte)
- B: [N, K//2] uint8, row-major, stored transposed (each row holds the K values for one output column)
- scale_a: [M, K//32] float8_e8m0fnu (one scale per 32 MXFP4 elements)
- scale_b: [N, K//32] float8_e8m0fnu
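As a sanity check on these shapes, the following Python sketch computes the expected buffer dimensions and packs or unpacks one row of 4-bit codes. The function names are hypothetical. The nibble order is an assumption, and the layout description above does not specify it: element 2*i is placed in the low nibble of byte i and element 2*i+1 in the high nibble. Verify this against the kernel before relying on it.

```python
def pack_mxfp4_row(codes):
    """Pack K E2M1 codes (0..15) into K // 2 bytes.

    ASSUMPTION: low nibble holds the even-indexed element,
    high nibble holds the odd-indexed element.
    """
    assert len(codes) % 2 == 0, "K must be even to pack two nibbles per byte"
    return [(codes[2 * i] & 0xF) | ((codes[2 * i + 1] & 0xF) << 4)
            for i in range(len(codes) // 2)]


def unpack_mxfp4_row(packed):
    """Inverse of pack_mxfp4_row under the same nibble-order assumption."""
    out = []
    for byte in packed:
        out.append(byte & 0xF)  # even-indexed element
        out.append(byte >> 4)   # odd-indexed element
    return out


def mxfp4_shapes(M, N, K, block=32):
    """Expected buffer shapes for A, B, scale_a and scale_b as listed above."""
    assert K % block == 0, "K is assumed to be a multiple of the MX block size"
    return {
        "A": (M, K // 2),
        "B": (N, K // 2),
        "scale_a": (M, K // block),
        "scale_b": (N, K // block),
    }


print(mxfp4_shapes(256, 512, 1024))
```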

MFMA lane-to-data mapping for 16x16x128 FP4: each lane loads 16 contiguous bytes (32 MXFP4 elements) from its assigned matrix row:
lane_row = lane_id % MMA_M
lane_chunk = lane_id / MMA_M (integer division, 0..3)
offset = lane_row * row_stride + lane_chunk * 16
The 16 loaded bytes are zero-extended to SIMD[uint8, 32] (upper 16 bytes set to zero) to form the MFMA operand.
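The offset formula can be checked on the host. The following Python sketch models it with hypothetical helper names and assumes a 64-lane wavefront and a row stride of 64 bytes (one packed 128-element K slice, matching MXFP4_BK_BYTES). Under those assumptions, the 64 lanes read a 16-row by 64-byte tile exactly once, and each operand is padded to 32 bytes the way the SIMD[uint8, 32] operand is.

```python
MMA_M = 16           # rows per MFMA tile (MXFP4_MMA_M)
BYTES_PER_LANE = 16  # 16 bytes = 32 FP4 elements
OPERAND_BYTES = 32   # width of the SIMD[uint8, 32] MFMA operand
WAVE_SIZE = 64       # assumed CDNA wavefront size


def lane_operand_offset(lane_id, row_stride):
    """Byte offset of the 16-byte chunk a lane loads."""
    lane_row = lane_id % MMA_M
    lane_chunk = lane_id // MMA_M  # integer division, 0..3 for 64 lanes
    return lane_row * row_stride + lane_chunk * BYTES_PER_LANE


def lane_operand_bytes(tile, lane_id, row_stride):
    """Gather a lane's 16 bytes and zero-extend them to the 32-byte operand."""
    off = lane_operand_offset(lane_id, row_stride)
    chunk = list(tile[off:off + BYTES_PER_LANE])
    return chunk + [0] * (OPERAND_BYTES - BYTES_PER_LANE)


# Usage: with a 64-byte row stride, every byte of the 16 x 64 tile is read once.
row_stride = 64
covered = sorted(
    b
    for lane in range(WAVE_SIZE)
    for b in range(lane_operand_offset(lane, row_stride),
                   lane_operand_offset(lane, row_stride) + BYTES_PER_LANE)
)
print(covered == list(range(MMA_M * row_stride)))
```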

MFMA scale model (16x16x128): each lane holds 16 x 128 / 64 = 32 FP4 elements and one E8M0 scale, which matches the MX format exactly (one scale per 32 elements). The 64 scale values (16 rows x 4 K-groups) come from the 64 lanes, each contributing one byte.

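Lane mapping: lane_row = lane % 16 (matrix row) and lane_k_group = lane / 16 (which 32-element K-group within the row, 0..3). This is the same split as lane_row and lane_chunk in the data mapping, so each lane's scale covers exactly the 32 elements it loads. Each lane loads scale_ptr[row * stride + base_k + lane_k_group], where row is the lane's matrix row and base_k is the current K offset measured in scale columns (K elements / 32).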
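The following Python sketch models the scale lookup and checks that it lines up with each lane's data chunk. It assumes a 64-lane wavefront. The parameter row_base (the first matrix row of the 16-row MFMA tile) and the helper names are hypothetical, and base_k is taken in scale-column units. With K = 128 (4 scale columns per row), the 64 lanes read 64 distinct scale bytes, and each lane's 32 data elements fall inside the block described by its scale.

```python
MMA_M = 16          # MXFP4_MMA_M
MMA_K = 128         # MXFP4_MMA_K: FP4 elements along K per MFMA
WAVE_SIZE = 64      # assumed wavefront size
MX_BLOCK_SIZE = 32  # elements per E8M0 scale

# Elements per lane: 16 x 128 / 64 = 32, i.e. exactly one MX block.
ELEMS_PER_LANE = MMA_M * MMA_K // WAVE_SIZE


def lane_scale_index(lane, stride, base_k, row_base=0):
    """Flat index into the scale buffer for this lane (row-major, `stride` columns)."""
    lane_row = lane % MMA_M
    lane_k_group = lane // MMA_M  # 0..3
    return (row_base + lane_row) * stride + base_k + lane_k_group


def lane_data_k_range(lane, k_elem_base=0):
    """K elements covered by the lane's 16-byte data chunk (2 FP4 elements per byte)."""
    lane_chunk = lane // MMA_M
    start = k_elem_base + lane_chunk * 16 * 2
    return range(start, start + ELEMS_PER_LANE)


# Usage: K = 128 gives 4 scale columns per row; 64 lanes touch 64 distinct scales.
idx = [lane_scale_index(lane, stride=4, base_k=0) for lane in range(WAVE_SIZE)]
print(sorted(idx) == list(range(64)))
```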

The scale byte is placed in byte 0 of an Int32 word passed to the MFMA intrinsic (byte_index=0 / OPSEL=0).
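The following Python sketch shows where the scale byte sits in the 32-bit word. The helper names are hypothetical. It assumes, as the byte_index/OPSEL wording implies, that OPSEL selects which byte of the word the hardware reads. With byte_index=0 the scale occupies bits [7:0], and the other bytes are left at zero here.

```python
def scale_to_word(scale_byte: int, byte_index: int = 0) -> int:
    """Place an E8M0 scale byte at byte position `byte_index` of a 32-bit word."""
    assert 0 <= scale_byte <= 0xFF and 0 <= byte_index <= 3
    return scale_byte << (8 * byte_index)


def word_to_scale(word: int, byte_index: int = 0) -> int:
    """Read back the byte that OPSEL = byte_index would select (modeled)."""
    return (word >> (8 * byte_index)) & 0xFF


# Usage: with byte_index=0 the word's value equals the scale byte itself.
print(hex(scale_to_word(0x7F)))
```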

Entry point: mxfp4_block_scaled_matmul_amd()

comptime values

MX_BLOCK_SIZE

comptime MX_BLOCK_SIZE = 32

MXFP4_BK_BYTES

comptime MXFP4_BK_BYTES = 64

MXFP4_BK_ELEMS

comptime MXFP4_BK_ELEMS = 128

MXFP4_BM

comptime MXFP4_BM = 128

MXFP4_BN

comptime MXFP4_BN = 128

MXFP4_MMA_K

comptime MXFP4_MMA_K = 128

MXFP4_MMA_M

comptime MXFP4_MMA_M = 16

MXFP4_MMA_N

comptime MXFP4_MMA_N = 16

MXFP4_NUM_M_MMAS

comptime MXFP4_NUM_M_MMAS = 4

MXFP4_NUM_N_MMAS

comptime MXFP4_NUM_N_MMAS = 4

MXFP4_NUM_THREADS

comptime MXFP4_NUM_THREADS = (4 * WARP_SIZE)

MXFP4_NUM_WARPS

comptime MXFP4_NUM_WARPS = 4

MXFP4_NUM_WARPS_M

comptime MXFP4_NUM_WARPS_M = 2

MXFP4_NUM_WARPS_N

comptime MXFP4_NUM_WARPS_N = 2

MXFP4_WM

comptime MXFP4_WM = 64

MXFP4_WN

comptime MXFP4_WN = 64
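These values are tied together by the tiling: two FP4 values per byte, two warps per block dimension, and four 16-wide MFMAs per warp dimension. The following Python sketch restates them and checks those relationships. It assumes WARP_SIZE is 64, the CDNA wavefront size. The equality of MXFP4_BK_ELEMS and MXFP4_MMA_K suggests one MFMA step along K per K tile, but that is an inference from the values, not a documented guarantee. The helper names are hypothetical.

```python
WARP_SIZE = 64  # assumed: CDNA wavefront size

MX_BLOCK_SIZE = 32
MXFP4_BK_BYTES = 64
MXFP4_BK_ELEMS = 128
MXFP4_BM, MXFP4_BN = 128, 128
MXFP4_MMA_M, MXFP4_MMA_N, MXFP4_MMA_K = 16, 16, 128
MXFP4_NUM_M_MMAS, MXFP4_NUM_N_MMAS = 4, 4
MXFP4_NUM_WARPS = 4
MXFP4_NUM_WARPS_M, MXFP4_NUM_WARPS_N = 2, 2
MXFP4_NUM_THREADS = 4 * WARP_SIZE
MXFP4_WM, MXFP4_WN = 64, 64


def check_tiling() -> bool:
    assert MXFP4_BK_BYTES * 2 == MXFP4_BK_ELEMS          # two FP4 values per byte
    assert MXFP4_BK_ELEMS == MXFP4_MMA_K                  # K tile width == MFMA K
    assert MXFP4_WM * MXFP4_NUM_WARPS_M == MXFP4_BM       # warps tile the block in M
    assert MXFP4_WN * MXFP4_NUM_WARPS_N == MXFP4_BN       # warps tile the block in N
    assert MXFP4_NUM_M_MMAS * MXFP4_MMA_M == MXFP4_WM     # MFMAs tile the warp in M
    assert MXFP4_NUM_N_MMAS * MXFP4_MMA_N == MXFP4_WN     # MFMAs tile the warp in N
    assert MXFP4_NUM_WARPS == MXFP4_NUM_WARPS_M * MXFP4_NUM_WARPS_N
    assert MXFP4_NUM_THREADS == MXFP4_NUM_WARPS * WARP_SIZE
    return True


def tile_stats() -> dict:
    """Derived per-tile quantities (all computed from the constants above)."""
    return {
        "scales_per_row_per_k_tile": MXFP4_BK_ELEMS // MX_BLOCK_SIZE,
        "mfmas_per_warp_per_k_tile": MXFP4_NUM_M_MMAS * MXFP4_NUM_N_MMAS,
        "a_tile_bytes": MXFP4_BM * MXFP4_BK_BYTES,
        "threads": MXFP4_NUM_THREADS,
    }


print(check_tiling(), tile_stats())
```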

Structs

Functions
