Mojo struct

MXFP4MatmulAMD

struct MXFP4MatmulAMD

Native MXFP4 block-scaled matmul for AMD CDNA4.

Uses cdna4_block_scaled_mfma with the FLOAT4_E2M1 format directly, in a single-buffer pipeline with a schedule-driven prologue/kernel/epilogue.
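As a point of reference for the FLOAT4_E2M1 element format the MFMA consumes, the following pure-Python sketch (not part of this struct's API; the helper name is ours) decodes a 4-bit E2M1 value: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, no infinities or NaNs.

```python
def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit FP4 E2M1 value (1 sign, 2 exp, 1 mantissa bit) to a float."""
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        # Subnormal range: the only magnitudes are 0.0 and 0.5.
        return sign * man * 0.5
    # Normal range: (1.m) * 2^(exp - bias), bias = 1.
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

# The eight non-negative E2M1 magnitudes:
print([decode_e2m1(n) for n in range(8)])
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```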

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

BK_BYTES

comptime BK_BYTES = MXFP4_BK_BYTES

BK_ELEMS

comptime BK_ELEMS = MXFP4_BK_ELEMS

BM

comptime BM = MXFP4_BM

BN

comptime BN = MXFP4_BN

c_frag_size

comptime c_frag_size = (256 // WARP_SIZE)

MMA_K

comptime MMA_K = MXFP4_MMA_K

MMA_M

comptime MMA_M = MXFP4_MMA_M

MMA_N

comptime MMA_N = MXFP4_MMA_N

num_k_tiles

comptime num_k_tiles = 1

num_m_mmas

comptime num_m_mmas = MXFP4_NUM_M_MMAS

num_n_mmas

comptime num_n_mmas = MXFP4_NUM_N_MMAS

num_threads

comptime num_threads = MXFP4_NUM_THREADS

num_warps_m

comptime num_warps_m = MXFP4_NUM_WARPS_M

num_warps_n

comptime num_warps_n = MXFP4_NUM_WARPS_N

packed_k_per_mma

comptime packed_k_per_mma = 64

WM

comptime WM = MXFP4_WM

WN

comptime WN = MXFP4_WN

Methods

run

static run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout](c: TileTensor[DType.float32, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b: TileTensor[DType.uint8, b_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin])

MXFP4 block-scaled GEMM kernel.

All inputs and outputs are TileTensors. A and B are packed uint8 with two FP4 values per byte (K//2 columns); scales are [rows, K//32] float8_e8m0fnu, one scale per 32-element K block.

Args:

  • c (TileTensor): Output [M, N] float32.
  • a (TileTensor): Packed A [M, K//2] uint8.
  • b (TileTensor): Packed B [N, K//2] uint8 (transposed).
  • sfa (TileTensor): A scales [M, K//32] float8_e8m0fnu.
  • sfb (TileTensor): B scales [N, K//32] float8_e8m0fnu.
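The math that run performs on these layouts can be sketched as a pure-Python reference (not the GPU kernel). Two details are assumptions for illustration: that the low nibble of each packed byte holds the even K index, and that an e8m0 scale byte decodes as 2**(byte - 127); the helper names are ours.

```python
def e2m1(n: int) -> float:
    """Decode a 4-bit FP4 E2M1 nibble (1 sign, 2 exp, 1 mantissa bit, bias 1)."""
    s = -1.0 if n & 8 else 1.0
    e, m = (n >> 1) & 3, n & 1
    return s * m * 0.5 if e == 0 else s * (1.0 + 0.5 * m) * 2.0 ** (e - 1)

def mxfp4_matmul_ref(a, b, sfa, sfb, M, N, K):
    """C[M, N] float from packed A[M, K//2], B[N, K//2] bytes and [rows, K//32] scale bytes."""
    c = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                # Assumed packing: even k in the low nibble, odd k in the high nibble.
                va = e2m1((a[i][k // 2] >> (4 * (k % 2))) & 0xF)
                vb = e2m1((b[j][k // 2] >> (4 * (k % 2))) & 0xF)
                # One e8m0 scale per row per 32-element K block: 2^(byte - 127).
                sa = 2.0 ** (sfa[i][k // 32] - 127)
                sb = 2.0 ** (sfb[j][k // 32] - 127)
                acc += (va * sa) * (vb * sb)
            c[i][j] = acc
    return c
```

For example, with M = N = 1 and K = 32, an A row of sixteen 0x32 bytes (values alternating 1.0, 1.5), a B row of sixteen 0x22 bytes (all 1.0), sfa = [[127]] (scale 1) and sfb = [[128]] (scale 2) gives c[0][0] == 80.0.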