
Mojo struct

MXFP4MatmulAMD

struct MXFP4MatmulAMD[BM: Int = 128, BN: Int = 128, BK_ELEMS: Int = 128, WM: Int = 64, WN: Int = 64, MMA_M: Int = 16, MMA_N: Int = 16, MMA_K: Int = 128]

Native MXFP4 block-scaled matmul for AMD CDNA4.

Uses cdna4_block_scaled_mfma with FLOAT4_E2M1 format directly. Single-buffer pipeline with schedule-driven prologue/kernel/epilogue. SMEM is plain row-major (no blocked-product, no swizzle).
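To make the data format concrete, the following is a minimal host-side sketch in Python of how MXFP4 operands can be dequantized. It is not part of this struct's API. It assumes the OCP Microscaling encoding: each FP4 E2M1 code is a sign bit plus a 3-bit magnitude index, and each float8_e8m0fnu scale byte `e` encodes `2^(e - 127)`, with `0xFF` reserved for NaN. The nibble order (low nibble first) and the helper names are assumptions made for illustration; the kernel itself consumes the packed bytes directly.

```python
# Magnitudes representable by FP4 E2M1, indexed by the low 3 bits of the code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 code (bit 3 is the sign) to a Python float."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte):
    """Decode one E8M0 scale byte: a power of two with bias 127, 0xFF is NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequantize_block(packed, scale_byte):
    """Dequantize packed FP4 bytes that share a single E8M0 scale.

    Assumption: the low nibble of each byte holds the earlier element.
    """
    scale = decode_e8m0(scale_byte)
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF) * scale)
        values.append(decode_e2m1(byte >> 4) * scale)
    return values
```

For example, `dequantize_block(bytes([0x21]), 128)` decodes the codes 1 and 2 (0.5 and 1.0) and applies a scale of 2, giving `[1.0, 2.0]`. In a full MXFP4 tensor, one scale byte covers 32 consecutive FP4 elements, that is, 16 packed bytes.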

Parameters

  • BM (Int): Block tile rows (output M per block). Default 128.
  • BN (Int): Block tile cols (output N per block). Default 128.
  • BK_ELEMS (Int): Block tile K in logical FP4 elements. Default 128.
  • WM (Int): Warp tile rows. BM must be divisible by WM. Default 64.
  • WN (Int): Warp tile cols. BN must be divisible by WN. Default 64.
  • MMA_M (Int): MFMA tile rows. WM must be divisible by MMA_M. Default 16.
  • MMA_N (Int): MFMA tile cols. WN must be divisible by MMA_N. Default 16.
  • MMA_K (Int): MFMA K-depth in logical FP4 elements. Default 128.
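The divisibility constraints listed above can be checked before instantiating the struct. The sketch below is a plain Python model, not Mojo and not part of the library; `check_tile_config` is a hypothetical helper name, and it tests only the four constraints this page states.

```python
def check_tile_config(BM=128, BN=128, BK_ELEMS=128, WM=64, WN=64,
                      MMA_M=16, MMA_N=16, MMA_K=128):
    """Return the documented constraints that a tile configuration violates.

    An empty list means all four documented constraints hold.
    BK_ELEMS and MMA_K are accepted for completeness but not checked,
    because the page states no constraint on them.
    """
    problems = []
    if BM % WM != 0:
        problems.append("BM must be divisible by WM")
    if BN % WN != 0:
        problems.append("BN must be divisible by WN")
    if WM % MMA_M != 0:
        problems.append("WM must be divisible by MMA_M")
    if WN % MMA_N != 0:
        problems.append("WN must be divisible by MMA_N")
    return problems
```

With the defaults, `check_tile_config()` returns an empty list. A configuration such as `check_tile_config(BM=96)` reports that BM is not divisible by WM, since 96 is not a multiple of 64.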

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // 2)

c_frag_size

comptime c_frag_size = ((MMA_M * MMA_N) // WARP_SIZE)

k_tile_size

comptime k_tile_size = MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].BK_BYTES

num_k_tiles

comptime num_k_tiles = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].BK_BYTES // MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].packed_k_per_mma)

num_m_mmas

comptime num_m_mmas = (WM // MMA_M)

num_n_mmas

comptime num_n_mmas = (WN // MMA_N)

num_threads

comptime num_threads = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].num_warps * WARP_SIZE)

num_warps

comptime num_warps = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].num_warps_m * MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN, MMA_M, MMA_N, MMA_K].num_warps_n)

num_warps_m

comptime num_warps_m = (BM // WM)

num_warps_n

comptime num_warps_n = (BN // WN)

packed_k_per_mma

comptime packed_k_per_mma = (MMA_K // 2)

scales_per_mma

comptime scales_per_mma = (MMA_K // 32)

simd_width

comptime simd_width = simd_width_of[DType.uint8]()
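As a worked example, the derived members above can be evaluated for a given parameter set. The sketch below mirrors the documented formulas in plain Python; it is a model for checking the arithmetic, not the Mojo implementation. `derived_members` is a hypothetical helper name, and `WARP_SIZE = 64` is an assumption based on the 64-lane wavefront of AMD CDNA GPUs. `simd_width` is omitted because it depends on the compilation target.

```python
WARP_SIZE = 64  # assumption: 64-lane wavefront on AMD CDNA GPUs


def derived_members(BM=128, BN=128, BK_ELEMS=128, WM=64, WN=64,
                    MMA_M=16, MMA_N=16, MMA_K=128, warp_size=WARP_SIZE):
    """Evaluate the documented comptime formulas for one parameter set."""
    BK_BYTES = BK_ELEMS // 2            # two FP4 elements per byte
    packed_k_per_mma = MMA_K // 2       # packed bytes consumed per MFMA
    num_warps_m = BM // WM
    num_warps_n = BN // WN
    num_warps = num_warps_m * num_warps_n
    return {
        "BK_BYTES": BK_BYTES,
        "k_tile_size": BK_BYTES,
        "packed_k_per_mma": packed_k_per_mma,
        "num_k_tiles": BK_BYTES // packed_k_per_mma,
        "scales_per_mma": MMA_K // 32,  # one scale per 32 FP4 elements
        "num_m_mmas": WM // MMA_M,
        "num_n_mmas": WN // MMA_N,
        "num_warps_m": num_warps_m,
        "num_warps_n": num_warps_n,
        "num_warps": num_warps,
        "num_threads": num_warps * warp_size,
        "c_frag_size": (MMA_M * MMA_N) // warp_size,
    }
```

Under these assumptions the default parameters give a 2 x 2 grid of warps (4 warps, 256 threads), 4 x 4 MFMA tiles per warp, 4 accumulator elements per lane for each MFMA tile, 64 packed bytes and 4 scales per MFMA, and a single K tile per block tile.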

Methods

run

static def run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout, num_splits: Int = 1](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b: TileTensor[DType.uint8, b_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin])

MXFP4 block-scaled GEMM kernel with SMEM pipeline.

With num_splits > 1, this kernel is the inter-block split-K body: each block_idx.z slice accumulates one disjoint K-band into its own [M, N] region of a stacked (num_splits * M, N) float32 workspace (out_dtype is float32 in that mode). A separate reduce kernel then sums the num_splits partial results and casts the sum to the final output dtype. With num_splits == 1, the kernel is byte-identical to the no-split path (split_id == 0, full K range, zero offset).
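The split-K data flow can be modeled on the host in a few lines of Python. This is a sketch of the scheme only, not the kernel: it uses plain nested lists, takes `b` as a K x N matrix for readability, ignores FP4 packing and scales, and skips the final cast. The function names and the way K is divided (equal contiguous bands, with the last band possibly shorter or empty) are assumptions; the page states only that the bands are disjoint.

```python
def matmul_band(a, b, k_begin, k_end):
    """Partial product of a (M x K) and b (K x N) over k in [k_begin, k_end)."""
    M, N = len(a), len(b[0])
    return [
        [sum(a[m][k] * b[k][n] for k in range(k_begin, k_end)) for n in range(N)]
        for m in range(M)
    ]


def split_k_matmul(a, b, num_splits=1):
    """Model of split-K: per-split partials in a stacked workspace, then a reduce."""
    M, K, N = len(a), len(b), len(b[0])
    band = (K + num_splits - 1) // num_splits  # assumed: equal contiguous bands

    # Stacked (num_splits * M, N) workspace; split s owns rows [s * M, (s + 1) * M).
    workspace = [[0.0] * N for _ in range(num_splits * M)]
    for split_id in range(num_splits):  # plays the role of block_idx.z
        k_begin = min(split_id * band, K)
        k_end = min(k_begin + band, K)
        partial = matmul_band(a, b, k_begin, k_end)
        for m in range(M):
            workspace[split_id * M + m] = partial[m]

    # Separate reduce step: sum the num_splits partials for each output element.
    return [
        [sum(workspace[s * M + m][n] for s in range(num_splits)) for n in range(N)]
        for m in range(M)
    ]
```

With `num_splits=1` the single band covers the full K range and the workspace offset is zero, so the model reduces to an ordinary matmul, which matches the no-split behavior described above.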