Mojo struct
MXFP4MatmulAMD
struct MXFP4MatmulAMD[BM: Int = 128, BN: Int = 128, BK_ELEMS: Int = 128, WM: Int = 64, WN: Int = 64]
Native MXFP4 block-scaled matmul for AMD CDNA4.
Uses cdna4_block_scaled_mfma directly with the FLOAT4_E2M1 format. The pipeline is single-buffered, with schedule-driven prologue, kernel, and epilogue stages. Shared memory (SMEM) tiles are plain row-major, with no blocked-product layout and no swizzle.
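A FLOAT4_E2M1 element only has meaning together with its E8M0 block scale. The following Python sketch decodes both. It assumes the OCP Microscaling (MX) encoding and 32-element scale blocks. The block size of 32 is consistent with scales_per_mma = 4 spread over the 128 FP4 elements (packed_k_per_mma = 64 bytes) of one MMA, but this page does not state it. The helpers decode_e2m1, decode_e8m0, and block_scaled_dot are hypothetical illustrations, not part of this API.

```python
# Hedged sketch of MXFP4 element decoding, assuming the OCP MX convention:
# E2M1 = 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit;
# E8M0 = pure power-of-two scale 2**(e - 127), with 0xFF reserved for NaN.
import math

def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code (0..15) to a float."""
    if not 0 <= code <= 15:
        raise ValueError("E2M1 codes are 4 bits")
    sign = -1.0 if code & 0x8 else 1.0
    exp = (code >> 1) & 0x3
    man = code & 0x1
    if exp == 0:
        # Subnormal: 0.m * 2**(1 - bias) = m * 0.5
        magnitude = man * 0.5
    else:
        # Normal: 1.m * 2**(exp - bias)
        magnitude = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)
    return sign * magnitude

def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte to its power-of-two value."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)

def block_scaled_dot(a_codes, a_scale, b_codes, b_scale):
    """Dot product of two 32-element MXFP4 blocks, one scale per block."""
    assert len(a_codes) == len(b_codes) == 32
    acc = sum(decode_e2m1(x) * decode_e2m1(y) for x, y in zip(a_codes, b_codes))
    return acc * decode_e8m0(a_scale) * decode_e8m0(b_scale)
```

The eight non-negative E2M1 codes decode to 0, 0.5, 1, 1.5, 2, 3, 4, and 6. Because an E8M0 scale is a power of two, applying it changes only the exponent of each product.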
Parameters
- BM (Int): Block tile rows (output M per block). Default 128.
- BN (Int): Block tile cols (output N per block). Default 128.
- BK_ELEMS (Int): Block tile K in logical FP4 elements. Default 128.
- WM (Int): Warp tile rows. BM must be divisible by WM. Default 64.
- WN (Int): Warp tile cols. BN must be divisible by WN. Default 64.
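The divisibility rules above can be checked before instantiating the struct. The following Python sketch uses the same defaults. It also counts the BM × BN block tiles that cover an M × N output. That count assumes M and N are exact multiples of BM and BN, because this page does not document edge-tile handling. check_params and num_blocks are hypothetical helpers.

```python
# Hedged sketch: enforce the parameter-list divisibility rules and count the
# BM x BN block tiles that cover an M x N output.

def check_params(BM=128, BN=128, WM=64, WN=64):
    """Raise if the warp tile does not evenly divide the block tile."""
    if BM % WM != 0:
        raise ValueError(f"BM={BM} is not divisible by WM={WM}")
    if BN % WN != 0:
        raise ValueError(f"BN={BN} is not divisible by WN={WN}")

def num_blocks(M, N, BM=128, BN=128):
    """Number of block tiles covering M x N.

    Assumes M and N are multiples of BM and BN; edge-tile handling is not
    documented on this page.
    """
    if M % BM or N % BN:
        raise ValueError("M and N are assumed to be multiples of BM and BN")
    return (M // BM) * (N // BN)
```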
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
BK_BYTES
comptime BK_BYTES = (BK_ELEMS // 2)
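BK_BYTES halves BK_ELEMS because two 4-bit FP4 codes share one uint8. The following Python sketch packs and unpacks codes this way. It places the first code in the low nibble. That ordering is an assumption, not the kernel's documented packing. pack_fp4 and unpack_fp4 are hypothetical helpers.

```python
# Hedged sketch: two FP4 codes per byte, so a BK_ELEMS-wide K tile occupies
# BK_ELEMS // 2 bytes. Low-nibble-first ordering is an assumption.

def pack_fp4(codes):
    """Pack an even-length list of 4-bit codes into bytes, low nibble first."""
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    return bytes((codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
                 for i in range(0, len(codes), 2))

def unpack_fp4(data):
    """Inverse of pack_fp4: split each byte into (low, high) nibbles."""
    out = []
    for byte in data:
        out.append(byte & 0xF)
        out.append(byte >> 4)
    return out

BK_ELEMS = 128
BK_BYTES = BK_ELEMS // 2  # mirrors the comptime member
```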
c_frag_size
comptime c_frag_size = (256 // WARP_SIZE)
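c_frag_size splits 256 output values across the lanes of one warp. One reading is that 256 is a single 16 × 16 MMA output tile. This is an inference from num_m_mmas = WM // 16 and num_n_mmas = WN // 16, not something this page states. AMD CDNA GPUs execute 64-lane wavefronts, so the following Python sketch uses WARP_SIZE = 64.

```python
# Hedged sketch of c_frag_size = 256 // WARP_SIZE.
# Assumes 256 = 16 * 16 outputs per MMA and a 64-lane CDNA wavefront.

MMA_TILE_OUTPUTS = 16 * 16  # inferred MMA output tile (assumption)
WARP_SIZE = 64              # CDNA wavefront width

def c_frag_size(warp_size: int) -> int:
    """Accumulator values each lane holds for one MMA output tile."""
    if MMA_TILE_OUTPUTS % warp_size:
        raise ValueError("warp size must divide the MMA output tile")
    return MMA_TILE_OUTPUTS // warp_size
```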
k_tile_size
comptime k_tile_size = MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].BK_BYTES
MMA_K
comptime MMA_K = MXFP4_MMA_K
MMA_M
comptime MMA_M = MXFP4_MMA_M
MMA_N
comptime MMA_N = MXFP4_MMA_N
num_k_tiles
comptime num_k_tiles = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].BK_BYTES // 64)
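num_k_tiles divides BK_BYTES by 64, which is the same value as packed_k_per_mma. One interpretation is that num_k_tiles counts the MFMA steps along K for each block tile, so the default BK_ELEMS = 128 gives one step. The following Python sketch computes the resulting K-loop trip counts. Two of its requirements are assumptions: that BK_BYTES is a multiple of 64, inferred from the integer division, and that K is a multiple of BK_ELEMS. k_loop_counts is a hypothetical helper.

```python
# Hedged sketch of the K-loop trip counts implied by the comptime members.
# packed_k_per_mma = 64 bytes, i.e. 128 FP4 elements per MFMA.

PACKED_K_PER_MMA = 64

def k_loop_counts(K: int, BK_ELEMS: int = 128):
    """Return (block iterations over K, MFMA steps per block tile)."""
    bk_bytes = BK_ELEMS // 2
    if bk_bytes % PACKED_K_PER_MMA:
        raise ValueError("BK_BYTES is assumed to be a multiple of 64")
    if K % BK_ELEMS:
        raise ValueError("K is assumed to be a multiple of BK_ELEMS")
    num_k_tiles = bk_bytes // PACKED_K_PER_MMA
    return K // BK_ELEMS, num_k_tiles
```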
num_m_mmas
comptime num_m_mmas = (WM // 16)
num_n_mmas
comptime num_n_mmas = (WN // 16)
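Together, num_m_mmas and num_n_mmas give each warp a grid of MMA tiles. The following Python sketch estimates per-lane accumulator storage under two assumptions. First, each lane keeps one c_frag_size fragment for every MMA tile in the warp tile for the whole K loop, as register-blocked GEMMs commonly do. Second, the wavefront is 64 lanes wide. warp_mma_grid is a hypothetical helper.

```python
# Hedged sketch: MMA grid per warp tile and per-lane accumulator count.
# Assumes 16 x 16 MMA tiles (from WM // 16, WN // 16) and a 64-lane wavefront.

WARP_SIZE = 64

def warp_mma_grid(WM: int = 64, WN: int = 64):
    num_m_mmas = WM // 16
    num_n_mmas = WN // 16
    c_frag_size = 256 // WARP_SIZE
    # Assumed: one c_frag_size fragment per (m, n) MMA tile in each lane.
    acc_per_lane = num_m_mmas * num_n_mmas * c_frag_size
    return num_m_mmas, num_n_mmas, acc_per_lane
```

With the defaults, this gives a 4 × 4 MMA grid and 64 FP32 accumulators per lane.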
num_threads
comptime num_threads = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps * WARP_SIZE)
num_warps
comptime num_warps = (MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps_m * MXFP4MatmulAMD[BM, BN, BK_ELEMS, WM, WN].num_warps_n)
num_warps_m
comptime num_warps_m = (BM // WM)
num_warps_n
comptime num_warps_n = (BN // WN)
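The warp and thread counts follow directly from the block and warp tile shapes. The following Python sketch mirrors num_warps_m, num_warps_n, num_warps, and num_threads, with WARP_SIZE = 64 assumed for a CDNA wavefront. thread_counts is a hypothetical helper.

```python
# Hedged sketch mirroring the warp/thread comptime members.
# WARP_SIZE = 64 assumes a CDNA wavefront.

WARP_SIZE = 64

def thread_counts(BM=128, BN=128, WM=64, WN=64):
    num_warps_m = BM // WM
    num_warps_n = BN // WN
    num_warps = num_warps_m * num_warps_n
    num_threads = num_warps * WARP_SIZE
    return num_warps_m, num_warps_n, num_warps, num_threads
```

With the default parameters, a 2 × 2 warp grid gives 4 warps and 256 threads per block.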
packed_k_per_mma
comptime packed_k_per_mma = 64
scales_per_mma
comptime scales_per_mma = 4
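packed_k_per_mma and scales_per_mma together fix how many FP4 elements one E8M0 scale covers: 64 bytes hold 128 elements, shared by 4 scales, so each scale covers 32 elements. The following Python sketch computes this and maps a (row, k) position to its scale. The row-major [rows, K // 32] scale layout is an assumption, because this page does not document the sfa and sfb layouts. scale_index is a hypothetical helper.

```python
# Hedged sketch: elements covered per E8M0 scale, and which scale a K index
# uses. The row-major [rows, K // 32] scale layout is an assumption.

PACKED_K_PER_MMA = 64  # bytes of packed FP4 per MMA
SCALES_PER_MMA = 4

ELEMS_PER_MMA = PACKED_K_PER_MMA * 2               # 128 FP4 elements
ELEMS_PER_SCALE = ELEMS_PER_MMA // SCALES_PER_MMA  # 32

def scale_index(row: int, k: int, K: int) -> int:
    """Flat index into an assumed row-major [rows, K // 32] scale array."""
    return row * (K // ELEMS_PER_SCALE) + k // ELEMS_PER_SCALE
```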
simd_width
comptime simd_width = simd_width_of[DType.uint8]()
Methods
run
static run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b: TileTensor[DType.uint8, b_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin])
MXFP4 block-scaled GEMM kernel with a shared-memory (SMEM) pipeline.
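On small shapes, a slow reference implementation can be used to check run's output. The following pure-Python sketch computes the same block-scaled product. It assumes several layouts that this page does not document. A is M × K and B is N × K, each holding unpacked FP4 codes with one code per list entry. sfa is M × K/32 and sfb is N × K/32, each holding E8M0 bytes. mxfp4_gemm_ref and its helpers are hypothetical.

```python
# Hedged pure-Python reference for a block-scaled MXFP4 GEMM.
# Assumed layouts: a[M][K], b[N][K] unpacked FP4 codes; sfa[M][K // 32],
# sfb[N][K // 32] E8M0 scale bytes; c[m][n] = sum_k A[m,k] * B[n,k] * scales.
import math

E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4(code):
    """Decode an E2M1 code via the magnitude table plus sign bit."""
    mag = E2M1[code & 0x7]
    return -mag if code & 0x8 else mag

def e8m0(byte):
    """Decode an E8M0 scale byte (0xFF is NaN)."""
    return math.nan if byte == 0xFF else 2.0 ** (byte - 127)

def mxfp4_gemm_ref(a, b, sfa, sfb, block=32):
    M, N, K = len(a), len(b), len(a[0])
    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                # Both operands' block scales apply to each product.
                s = e8m0(sfa[m][k // block]) * e8m0(sfb[n][k // block])
                acc += fp4(a[m][k]) * fp4(b[n][k]) * s
            c[m][n] = acc
    return c
```

To compare it with run, unpack the uint8 operands first. Expect small differences, because the hardware accumulates in a different order.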