
Mojo module

mxfp4_matmul_amd

Native MXFP4 block-scaled matmul on AMD CDNA4 via f8f6f4 MFMA.

Computes C = (A * scale_a) @ (B * scale_b)^T, where A and B are packed MXFP4 (E2M1) values stored in uint8 with per-block E8M0 scaling factors. Uses the CDNA4 mfma.scale.f32.16x16x128.f8f6f4 instruction, which natively consumes MXFP4 operands together with E8M0 scale words, so no dequantization step is needed.
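To make the two encodings concrete, here is a minimal Python sketch that decodes one E2M1 element code and one E8M0 scale byte. It follows the OCP Microscaling (MX) definitions: E2M1 has a sign bit, two exponent bits with bias 1 and one mantissa bit, and an E8M0 byte e encodes 2^(e - 127), with 0xFF reserved for NaN. The helper names are hypothetical and not part of this module.

```python
import math

# Magnitudes of the 8 non-negative E2M1 codes (exponent bias 1, one mantissa bit).
# Bit 3 of the 4-bit code is the sign.
_E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit MXFP4 (E2M1) code to a Python float."""
    if not 0 <= nibble <= 0xF:
        raise ValueError("E2M1 codes are 4 bits")
    magnitude = _E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0 scale byte: 2**(byte - 127), with 0xFF reserved for NaN."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 scales are 8 bits")
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - 127)


def mx_value(nibble: int, scale_byte: int) -> float:
    """Effective value of one MXFP4 element: decoded element times its block scale."""
    return decode_e2m1(nibble) * decode_e8m0(scale_byte)


print(mx_value(0b0111, 128))  # 6.0 * 2**1 -> 12.0
```

Because the scale is a pure power of two, applying it only shifts the exponent, which is part of why hardware can fold it into the MFMA instead of dequantizing first.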

Structure mirrors AMDMatmul: TileTensor throughout, RegTileLoader for DRAM-to-register loads, row-major shared memory (SMEM), and a schedule-driven pipeline. Unlike the BF16/FP8 kernels, SMEM uses no blocked-product layout or swizzle, because the FP4 MFMA expects a simple row-major lane-to-data mapping.

MXFP4 data layout:
- A: [M, K//2] uint8, row-major (two MXFP4 nibbles packed per byte)
- B: [N, K//2] uint8, row-major, transposed (each row is one output column)
- scale_a: [M, K//32] float8_e8m0fnu (one scale per 32 MXFP4 elements)
- scale_b: [N, K//32] float8_e8m0fnu
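The layout can be checked against a slow, pure-Python reference of C = (A * scale_a) @ (B * scale_b)^T. This is a sketch, not the kernel's algorithm: it dequantizes every element up front, which the MFMA path avoids. It assumes the low nibble of each byte holds the even-indexed element (k = 2j) and the high nibble holds k = 2j + 1; this page does not state the nibble order. All function names are hypothetical.

```python
import math

MX_BLOCK_SIZE = 32  # MXFP4 elements per E8M0 scale (same value as the module constant)
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for the 3 low bits


def e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    value = _E2M1[nibble & 0x7]
    return -value if nibble & 0x8 else value


def e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte: 2**(byte - 127), 0xFF is NaN."""
    return math.nan if byte == 0xFF else math.ldexp(1.0, byte - 127)


def dequantize_row(packed_row, scale_row):
    """Expand one [K//2] row of packed bytes into K floats using its [K//32] scales.

    ASSUMPTION: low nibble = element 2*j, high nibble = element 2*j + 1.
    """
    out = []
    for j, byte in enumerate(packed_row):
        for k, nibble in ((2 * j, byte & 0xF), (2 * j + 1, byte >> 4)):
            out.append(e2m1(nibble) * e8m0(scale_row[k // MX_BLOCK_SIZE]))
    return out


def reference_mxfp4_matmul(a, scale_a, b, scale_b):
    """C[m][n] = sum over k of A[m][k] * B[n][k], with B given as [N, K//2]."""
    k_total = 2 * len(a[0])
    if k_total % MX_BLOCK_SIZE or any(
        len(s) != k_total // MX_BLOCK_SIZE for s in scale_a + scale_b
    ):
        raise ValueError("expected K divisible by 32 and K//32 scales per row")
    a_rows = [dequantize_row(r, s) for r, s in zip(a, scale_a)]
    b_rows = [dequantize_row(r, s) for r, s in zip(b, scale_b)]
    return [[sum(x * y for x, y in zip(ar, br)) for br in b_rows] for ar in a_rows]


# Example: M = 1, N = 2, K = 64 (32 packed bytes and 2 scales per row).
a = [[0x22] * 32]                   # both nibbles are 0b0010 -> 1.0
scale_a = [[127, 128]]              # first 32 elements x1, next 32 elements x2
b = [[0x22] * 32, [0x02] * 32]      # row 1 alternates 1.0 (low nibble) and 0.0
scale_b = [[127, 127], [127, 127]]
print(reference_mxfp4_matmul(a, scale_a, b, scale_b))  # [[96.0, 48.0]]
```

Because B is stored as [N, K//2], each C[m][n] is a dot product of two stored rows, so both operands are read along K in the same contiguous direction.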

MFMA lane-to-data mapping for 16x16x128 FP4: each lane loads 16 contiguous bytes (32 FP4 elements) from its assigned matrix row, with lane_row = lane_id % MMA_M and lane_chunk = lane_id / MMA_M (integer division). The byte offset is lane_row * row_stride + lane_chunk * 16. The 16 bytes are zero-extended to SIMD[uint8, 32] to form the MFMA operand.
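The lane-to-data mapping can be modeled in Python to confirm that the 64 lanes of a CDNA wavefront cover a 16-row, 64-byte (128-element) operand tile exactly once. The function names are hypothetical, and "zero-extended" is modeled here as padding the upper 16 bytes of the 32-byte operand with zeros, which is an assumption about the kernel's behavior.

```python
MMA_M = 16           # rows in one 16x16x128 operand tile
WAVE_SIZE = 64       # lanes per CDNA wavefront
BYTES_PER_LANE = 16  # 16 bytes = 32 packed FP4 elements
OPERAND_BYTES = 32   # width of SIMD[uint8, 32]


def lane_byte_offset(lane_id: int, row_stride: int) -> int:
    """Byte offset of the 16-byte chunk a lane loads, per the mapping above."""
    lane_row = lane_id % MMA_M
    lane_chunk = lane_id // MMA_M
    return lane_row * row_stride + lane_chunk * BYTES_PER_LANE


def lane_operand(tile: bytes, lane_id: int, row_stride: int) -> list:
    """Load a lane's 16 bytes, then pad with zeros to 32 bytes (assumed zero-extension)."""
    offset = lane_byte_offset(lane_id, row_stride)
    chunk = list(tile[offset:offset + BYTES_PER_LANE])
    return chunk + [0] * (OPERAND_BYTES - BYTES_PER_LANE)


# For a contiguous tile with K = 128, each row is 64 bytes wide, so the 64 lanes
# (16 rows x 4 chunks of 16 bytes) should load every byte exactly once.
row_stride = 128 // 2
loaded = sorted(
    byte
    for lane in range(WAVE_SIZE)
    for byte in range(
        lane_byte_offset(lane, row_stride),
        lane_byte_offset(lane, row_stride) + BYTES_PER_LANE,
    )
)
assert loaded == list(range(MMA_M * row_stride))
```

For example, lane 17 maps to row 1, chunk 1, so it reads bytes 80 through 95 of the tile.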

MFMA scale model (16x16x128): each lane holds 16x128/64 = 32 FP4 elements and one E8M0 scale, which matches the MX format exactly: one scale per 32 elements. The 64 scale values per operand (16 rows x 4 K-groups) are supplied by the 64 lanes, one byte each.

Lane mapping: lane_row = lane % 16 (matrix row), lane_k_group = lane / 16 (which 32-element K-group within the row, 0..3). Each lane loads scale_ptr[row * stride + base_k + lane_k_group].
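A similar model of the scale mapping checks two properties stated above: each lane's 32 elements form exactly one MX block, and the scale a lane loads (K-group lane / 16) covers the same 32 K-elements as the 16-byte data chunk it loads (lane_chunk = lane / 16). The sketch takes row to be the lane's matrix row within the tile and assumes base_k is already measured in scale columns (K offset / 32); both are assumptions, and the names are hypothetical.

```python
MMA_M = 16          # rows per 16x16x128 tile
WAVE_SIZE = 64      # lanes per wavefront
MFMA_K = 128        # FP4 elements along K per MFMA
MX_BLOCK_SIZE = 32  # FP4 elements per E8M0 scale

# Each lane holds 16 * 128 / 64 = 32 elements, i.e. exactly one MX block.
ELEMS_PER_LANE = MMA_M * MFMA_K // WAVE_SIZE
assert ELEMS_PER_LANE == MX_BLOCK_SIZE


def lane_scale_index(lane: int, stride: int, base_k: int) -> int:
    """Flat index into scale_ptr for one lane.

    stride: scale row stride (K // 32 for a contiguous [rows, K//32] tensor).
    base_k: ASSUMED to be in scale columns already (K offset // 32).
    """
    lane_row = lane % MMA_M
    lane_k_group = lane // MMA_M
    return lane_row * stride + base_k + lane_k_group


def lane_k_range(lane: int) -> range:
    """K indices (within one MFMA slice) covered by the lane's 16-byte data chunk."""
    lane_chunk = lane // MMA_M
    start = lane_chunk * ELEMS_PER_LANE
    return range(start, start + ELEMS_PER_LANE)


# The data a lane holds always falls in the single K-group whose scale it loads.
for lane in range(WAVE_SIZE):
    assert {k // MX_BLOCK_SIZE for k in lane_k_range(lane)} == {lane // MMA_M}
```

With a scale stride of 4 (one 128-element K slice) and base_k = 0, the 64 lanes read scale indices 0 through 63, each exactly once.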

The scale byte is placed in byte 0 of an Int32 word passed to the MFMA intrinsic (byte_index=0 / OPSEL=0).
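Placing the scale in byte 0 amounts to the shift below. This plain-Python model treats byte_index as selecting one byte of the 32-bit scale word, as the text describes for OPSEL; it returns the unsigned bit pattern rather than a signed Int32, and both helper names are hypothetical.

```python
def scale_word(scale_byte: int, byte_index: int = 0) -> int:
    """Place one E8M0 scale byte into a 32-bit word at the given byte position."""
    if not 0 <= scale_byte <= 0xFF or not 0 <= byte_index <= 3:
        raise ValueError("scale_byte must fit in 8 bits and byte_index in 0..3")
    return scale_byte << (8 * byte_index)


def selected_scale(word: int, byte_index: int = 0) -> int:
    """Byte that a consumer selecting byte_index reads back from the word."""
    return (word >> (8 * byte_index)) & 0xFF


word = scale_word(0x80)           # byte_index = 0, matching OPSEL = 0
print(hex(word))                  # 0x80
print(selected_scale(word, 0))    # 128
```

With byte_index = 0 the word is numerically equal to the scale byte, so no packing of neighboring scales is needed and the upper three bytes are simply zero.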

Entry point: mxfp4_block_scaled_matmul_amd()

comptime values

MX_BLOCK_SIZE

comptime MX_BLOCK_SIZE = 32

SCHED_MASK_DS_READ

comptime SCHED_MASK_DS_READ = 0

SCHED_MASK_DS_WRITE

comptime SCHED_MASK_DS_WRITE = 1

SCHED_MASK_MFMA

comptime SCHED_MASK_MFMA = 3

SCHED_MASK_VMEM_READ

comptime SCHED_MASK_VMEM_READ = 2

Structs

Functions