Mojo struct

MXFP4MatmulAMD

struct MXFP4MatmulAMD[BM: Int = 128, BN: Int = 128, BK_ELEMS: Int = 128, WM: Int = 64, WN: Int = 64]

Native MXFP4 block-scaled matmul for AMD CDNA4.

Uses cdna4_block_scaled_mfma directly with the FLOAT4_E2M1 format. The kernel runs a single-buffer pipeline with schedule-driven prologue, kernel, and epilogue phases. Shared memory (SMEM) is plain row-major (no blocked-product layout, no swizzle).
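To make the FLOAT4_E2M1 and block-scale formats concrete, the following is a minimal sketch in plain Python (not Mojo) of how a single packed byte and an E8M0 scale decode to real values. The value table and the bias of 127 follow the OCP Microscaling (MX) format definitions. The nibble order (low nibble first) and the helper names are assumptions made for illustration, not something this page states.

```python
# E2M1 magnitudes indexed by the low 3 bits (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 code; bit 3 is the sign bit."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte):
    """Decode an E8M0 scale: a power of two with bias 127; 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def unpack_byte(byte):
    """Split one uint8 into two FP4 values.

    Assumption: the low nibble holds the earlier K element.
    """
    return decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4)


lo, hi = unpack_byte(0x2F)   # low nibble 0xF, high nibble 0x2
scale = decode_e8m0(128)     # 2 ** (128 - 127)
print(lo * scale, hi * scale)  # -12.0 2.0
```

Because each scale is a pure power of two, applying it changes only the exponent of the decoded value.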

Parameters

  • BM (Int): Block tile rows, i.e. the number of output rows (M) computed per block. Default 128.
  • BN (Int): Block tile columns, i.e. the number of output columns (N) computed per block. Default 128.
  • BK_ELEMS (Int): Block tile depth along K, in logical FP4 elements. Default 128.
  • WM (Int): Warp tile rows. BM must be divisible by WM. Default 64.
  • WN (Int): Warp tile columns. BN must be divisible by WN. Default 64.
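The constraints in the parameter list can be checked with a few lines of arithmetic. The sketch below is plain Python (not Mojo). The helper names are hypothetical, and the block count assumes one block per BM x BN output tile, with partial tiles at the edges rounded up. Whether the kernel itself accepts ragged edges is not stated on this page.

```python
def validate_tiles(BM=128, BN=128, WM=64, WN=64):
    """Check the divisibility constraints stated in the parameter list."""
    if BM % WM != 0:
        raise ValueError(f"BM={BM} must be divisible by WM={WM}")
    if BN % WN != 0:
        raise ValueError(f"BN={BN} must be divisible by WN={WN}")
    return True


def blocks_needed(M, N, BM=128, BN=128):
    """Hypothetical helper: block tiles needed to cover an M x N output."""
    blocks_m = -(-M // BM)  # ceiling division
    blocks_n = -(-N // BN)
    return blocks_m, blocks_n


validate_tiles()
print(blocks_needed(1000, 512))  # (8, 4)
```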

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // 2)

c_frag_size

comptime c_frag_size = (256 // WARP_SIZE)

k_tile_size

comptime k_tile_size = MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].BK_BYTES

MMA_K

comptime MMA_K = MXFP4_MMA_K

MMA_M

comptime MMA_M = MXFP4_MMA_M

MMA_N

comptime MMA_N = MXFP4_MMA_N

num_k_tiles

comptime num_k_tiles = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].BK_BYTES // 64)

num_m_mmas

comptime num_m_mmas = (WM // 16)

num_n_mmas

comptime num_n_mmas = (WN // 16)
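The per-warp MMA counts and the accumulator fragment size follow from the expressions above. The sketch below is plain Python (not Mojo) and rests on labeled assumptions: WARP_SIZE is taken as 64 (the CDNA wavefront width), the literal 16 is read as the MMA tile edge, and 256 is read as 16 x 16 output elements per MMA. accum_per_lane is a derived quantity introduced here for illustration, not a member of the struct.

```python
WARP_SIZE = 64  # assumption: CDNA wavefront width
MMA_EDGE = 16   # assumption: inferred from the WM // 16 and WN // 16 expressions


def warp_mma_counts(WM=64, WN=64):
    """Return (num_m_mmas, num_n_mmas, c_frag_size, accum_per_lane)."""
    num_m_mmas = WM // MMA_EDGE
    num_n_mmas = WN // MMA_EDGE
    # 256 outputs per MMA tile spread evenly over the lanes of one warp.
    c_frag_size = (MMA_EDGE * MMA_EDGE) // WARP_SIZE
    # Hypothetical derived value: accumulators each lane holds for its warp tile.
    accum_per_lane = num_m_mmas * num_n_mmas * c_frag_size
    return num_m_mmas, num_n_mmas, c_frag_size, accum_per_lane


print(warp_mma_counts())  # (4, 4, 4, 64)
```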

num_threads

comptime num_threads = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps * WARP_SIZE)

num_warps

comptime num_warps = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps_m * MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps_n)

num_warps_m

comptime num_warps_m = (BM // WM)

num_warps_n

comptime num_warps_n = (BN // WN)
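The warp and thread counts chain together as num_warps_m and num_warps_n, then num_warps, then num_threads. The following plain-Python sketch (not Mojo) evaluates that chain for the default parameters. WARP_SIZE = 64 is an assumption, and warp_tile_origin is a hypothetical helper that assumes warps are numbered row-major across the block tile, which this page does not confirm.

```python
WARP_SIZE = 64  # assumption: CDNA wavefront width


def block_shape(BM=128, BN=128, WM=64, WN=64):
    """Return (num_warps_m, num_warps_n, num_warps, num_threads)."""
    num_warps_m = BM // WM
    num_warps_n = BN // WN
    num_warps = num_warps_m * num_warps_n
    num_threads = num_warps * WARP_SIZE
    return num_warps_m, num_warps_n, num_warps, num_threads


def warp_tile_origin(warp_id, BN=128, WM=64, WN=64):
    """Hypothetical mapping from warp index to (row, col) of its tile.

    Assumes row-major warp numbering over the block tile.
    """
    num_warps_n = BN // WN
    return (warp_id // num_warps_n) * WM, (warp_id % num_warps_n) * WN


print(block_shape())        # (2, 2, 4, 256)
print(warp_tile_origin(3))  # (64, 64)
```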

packed_k_per_mma

comptime packed_k_per_mma = 64

scales_per_mma

comptime scales_per_mma = 4
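The K-dimension constants relate to each other through the FP4 packing: two FP4 elements share one byte, so BK_BYTES is half of BK_ELEMS. The plain-Python sketch below (not Mojo) works through the numbers. It assumes that the literal 64 in num_k_tiles is packed_k_per_mma. The derived elems_per_scale is introduced here for illustration; its value of 32 matches the MX block size.

```python
PACKED_K_PER_MMA = 64  # bytes of packed FP4 consumed along K by one MMA
SCALES_PER_MMA = 4     # scale factors consumed along K by one MMA


def k_tiling(BK_ELEMS=128):
    """Return (BK_BYTES, num_k_tiles, elems_per_mma, elems_per_scale)."""
    BK_BYTES = BK_ELEMS // 2                   # two FP4 elements per byte
    num_k_tiles = BK_BYTES // PACKED_K_PER_MMA
    elems_per_mma = PACKED_K_PER_MMA * 2       # logical FP4 elements per MMA
    elems_per_scale = elems_per_mma // SCALES_PER_MMA
    return BK_BYTES, num_k_tiles, elems_per_mma, elems_per_scale


print(k_tiling())              # (64, 1, 128, 32)
print(k_tiling(BK_ELEMS=256))  # (128, 2, 128, 32)
```

With the default BK_ELEMS of 128, each block tile therefore needs exactly one MMA step along K.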

simd_width

comptime simd_width = simd_width_of[DType.uint8]()

Methods

run

static run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b: TileTensor[DType.uint8, b_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin])

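MXFP4 block-scaled GEMM kernel with an SMEM pipeline. a and b hold packed FP4 data as uint8 (two elements per byte), sfa and sfb hold the corresponding float8_e8m0fnu block scales, and the result is written to c.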
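For checking results on small inputs, a host-side reference of the block-scaled GEMM can be useful. The sketch below is plain Python (not Mojo) and is not the kernel's implementation. It assumes that a is M rows of K/2 packed bytes, that b is N rows of K/2 packed bytes (K-major for both operands), that each row carries one E8M0 scale byte per 32 K elements, and that the low nibble holds the earlier element. None of these layouts are confirmed by this page, and all function names are hypothetical.

```python
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
SCALE_BLOCK = 32  # K elements sharing one scale (MX block size)


def fp4(code):
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    return (-1.0 if code & 0x8 else 1.0) * E2M1[code & 0x7]


def dequant_row(packed, scales):
    """Expand one packed row to floats and apply its E8M0 scales.

    The NaN scale encoding (0xFF) is not handled in this sketch.
    """
    values = []
    for byte in packed:
        values.append(fp4(byte & 0xF))  # assumed: low nibble first
        values.append(fp4(byte >> 4))
    return [v * 2.0 ** (scales[i // SCALE_BLOCK] - 127)
            for i, v in enumerate(values)]


def reference_matmul(a, sfa, b, sfb):
    """Compute C[m][n] = sum over k of A[m][k] * B[n][k] after dequantization."""
    a_rows = [dequant_row(row, s) for row, s in zip(a, sfa)]
    b_rows = [dequant_row(row, s) for row, s in zip(b, sfb)]
    return [[sum(x * y for x, y in zip(ar, br)) for br in b_rows]
            for ar in a_rows]


# K = 32: 16 packed bytes and one scale per row.
a = [[0x22] * 16]                # every element is 1.0
sfa = [[127]]                    # scale 2**0
b = [[0x44] * 16, [0xAA] * 16]   # elements 2.0 and -1.0
sfb = [[128], [127]]             # scales 2**1 and 2**0
print(reference_matmul(a, sfa, b, sfb))  # [[128.0, -32.0]]
```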