Mojo struct
MXFP4MatmulAMD
struct MXFP4MatmulAMD
Native MXFP4 block-scaled matmul for AMD CDNA4.
Uses cdna4_block_scaled_mfma with the FLOAT4_E2M1 format directly. Single-buffer pipeline with a schedule-driven prologue/kernel/epilogue.
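The numerical semantics the kernel implements can be sketched as a scalar reference in Python. This is an illustration of block-scaled matmul, not the kernel itself: each float8_e8m0fnu scale is a bare biased exponent covering a 32-element block along K, and A/B are shown here already dequantized to floats (the real kernel consumes packed FP4 data on the GPU).

```python
BLOCK = 32  # K elements covered by one e8m0 scale


def e8m0_to_float(raw: int) -> float:
    """float8_e8m0fnu is exponent-only: value = 2**(raw - 127)."""
    return 2.0 ** (raw - 127)


def block_scaled_matmul(a, b, sfa, sfb):
    """Scalar reference: a is [M][K], b is [N][K] (B transposed),
    sfa is [M][K//32], sfb is [N][K//32] raw e8m0 bytes.
    C[m][n] = sum_k (a[m][k] * sfa-scale) * (b[n][k] * sfb-scale).
    """
    M, K, N = len(a), len(a[0]), len(b)
    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                acc += (a[m][k] * e8m0_to_float(sfa[m][k // BLOCK])
                        * b[n][k] * e8m0_to_float(sfb[n][k // BLOCK]))
            c[m][n] = acc
    return c
```

For example, with K = 32, all-ones inputs, an A scale of 2**0 (raw 127) and a B scale of 2**1 (raw 128), each output element is 32 * 1 * 2 = 64.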
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
BK_BYTES
comptime BK_BYTES = MXFP4_BK_BYTES
BK_ELEMS
comptime BK_ELEMS = MXFP4_BK_ELEMS
BM
comptime BM = MXFP4_BM
BN
comptime BN = MXFP4_BN
c_frag_size
comptime c_frag_size = (256 // WARP_SIZE)
MMA_K
comptime MMA_K = MXFP4_MMA_K
MMA_M
comptime MMA_M = MXFP4_MMA_M
MMA_N
comptime MMA_N = MXFP4_MMA_N
num_k_tiles
comptime num_k_tiles = 1
num_m_mmas
comptime num_m_mmas = MXFP4_NUM_M_MMAS
num_n_mmas
comptime num_n_mmas = MXFP4_NUM_N_MMAS
num_threads
comptime num_threads = MXFP4_NUM_THREADS
num_warps_m
comptime num_warps_m = MXFP4_NUM_WARPS_M
num_warps_n
comptime num_warps_n = MXFP4_NUM_WARPS_N
packed_k_per_mma
comptime packed_k_per_mma = 64
WM
comptime WM = MXFP4_WM
WN
comptime WN = MXFP4_WN
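The members above follow the usual block/warp/MMA tiling hierarchy. The concrete MXFP4_* values are not given on this page, so the sketch below only checks the assumed invariants with illustrative numbers; the one value fixed by the source is c_frag_size = 256 // WARP_SIZE, which is 4 on CDNA (wavefront size 64, one 16x16 fp32 MFMA result spread across the wave).

```python
WARP_SIZE = 64  # CDNA wavefront size

# 256 fp32 accumulator elements per MFMA tile, split across the wave.
c_frag_size = 256 // WARP_SIZE  # -> 4 per lane

# packed_k_per_mma = 64 bytes suggests 128 FP4 elements per MMA step
# (two E2M1 values per uint8); this is an inference, not stated above.


def check_tiling(BM, BN, WM, WN, MMA_M, MMA_N,
                 num_warps_m, num_warps_n, num_m_mmas, num_n_mmas):
    """Assumed relationships: block tile = warp grid x warp tile,
    warp tile = MMA grid x MMA tile, in each of M and N."""
    assert BM == num_warps_m * WM and WM == num_m_mmas * MMA_M
    assert BN == num_warps_n * WN and WN == num_n_mmas * MMA_N
```

Calling check_tiling with any consistent set of values (e.g. BM=128, WM=64, MMA_M=16, 2 warps and 4 MMAs per dimension) passes; the MXFP4_* constants are expected to satisfy the same relations.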
Methods
run
static run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout](c: TileTensor[DType.float32, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b: TileTensor[DType.uint8, b_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin])
MXFP4 block-scaled GEMM kernel.
All inputs and outputs are TileTensors. A and B are packed uint8 with K//2 columns (two FP4 values per byte). Scales are [rows, K//32] float8_e8m0fnu.
Args:
- c (TileTensor): Output [M, N] float32.
- a (TileTensor): Packed A [M, K//2] uint8.
- b (TileTensor): Packed B [N, K//2] uint8 (transposed).
- sfa (TileTensor): A scales [M, K//32] float8_e8m0fnu.
- sfb (TileTensor): B scales [N, K//32] float8_e8m0fnu.
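A host-side helper makes the packed operand layout concrete. Each uint8 holds two FLOAT4_E2M1 values (hence K//2 columns); the nibble order shown here (low nibble = even K index) is an assumption for illustration, not stated by this page. E2M1 itself is fixed: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, giving magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6.

```python
# Magnitudes for 3-bit E2M1 codes 0..7 (subnormals first, then normals).
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code (bit 3 = sign) to a float."""
    mag = _E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def unpack_row(packed: list) -> list:
    """Expand one packed row of K//2 bytes into K floats.
    Assumed order: low nibble holds the even-indexed element."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF))         # even K index
        out.append(decode_e2m1((byte >> 4) & 0xF))  # odd K index
    return out
```

For instance, the byte 0x21 unpacks to [0.5, 1.0] under this nibble ordering, and code 0xF (sign bit plus magnitude code 7) decodes to -6.0.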