IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MXFP4MatmulAMD_PreB

struct MXFP4MatmulAMD_PreB[BM: Int = 64, BN: Int = 128, BK_ELEMS: Int = 512, WN: Int = 64, B_PREFETCH: Bool = False]

Preshuffled-B variant of MXFP4MatmulAMD, which takes the B operand in a pre-shuffled layout (the b_pre argument of run).

The preshuffled-B path requires num_warps_m == 1: B is not staged through LDS, so there is no cross-warp reuse of B along the M direction. WM is therefore structurally fixed to BM.

When B_PREFETCH=True, the kernel runs a depth-2 software pipeline over the outer K loop: while the current iteration's MFMAs execute, the next iteration's B fragments stream from DRAM into the alternate b_reg slot. This doubles the size of _b_reg (at the cost of extra VGPRs) but hides DRAM latency for B behind the inner MFMA chain. It targets K-heavy shapes (e.g. gate/up projections with K=7168), where serialization of the outer-K iterations dominates.
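The depth-2 schedule can be modeled as an ordering of load and compute events over outer-K iterations. The Python sketch below is a hypothetical host-side model, not the kernel: load_b and mma are stand-in callbacks, and the slot index k % num_b_slots plays the role of the alternate b_reg slot described above.

```python
from typing import Callable, List, Tuple

def pipelined_k_loop(num_k_iters: int, b_prefetch: bool,
                     load_b: Callable[[int, int], None],
                     mma: Callable[[int, int], None]) -> None:
    """Model of the outer-K loop with optional depth-2 B prefetch.

    load_b(k, slot): hypothetical stand-in for streaming the B fragments of
    outer iteration k from DRAM into register slot `slot`.
    mma(k, slot): hypothetical stand-in for iteration k's inner MFMA chain,
    which reads B from `slot`.
    """
    num_b_slots = 2 if b_prefetch else 1  # mirrors comptime num_b_slots
    if num_k_iters == 0:
        return
    if not b_prefetch:
        # Serialized: each iteration loads B, then computes on it.
        for k in range(num_k_iters):
            load_b(k, 0)
            mma(k, 0)
        return
    # Prologue: fill slot 0 before the steady-state loop.
    load_b(0, 0)
    for k in range(num_k_iters):
        cur = k % num_b_slots
        nxt = (k + 1) % num_b_slots
        if k + 1 < num_k_iters:
            # Issue the next iteration's B load into the alternate slot first,
            # so on hardware it overlaps with this iteration's MFMAs.
            load_b(k + 1, nxt)
        mma(k, cur)


events: List[Tuple[str, int, int]] = []
pipelined_k_loop(3, True,
                 load_b=lambda k, s: events.append(("load", k, s)),
                 mma=lambda k, s: events.append(("mma", k, s)))
print(events)
```

The slot being overwritten by each prefetch is always the one the previous iteration's MFMAs just finished reading, which is why two slots are enough for a depth-2 pipeline.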

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // 2)
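Each MXFP4 element is a 4-bit E2M1 code, so two elements share one byte, which is why BK_BYTES halves BK_ELEMS. The Python sketch below shows the arithmetic; the low-nibble-first order in the illustrative pack_fp4 helper is an assumption, not necessarily the byte layout the kernel's a and b_pre buffers use.

```python
from typing import List

def bk_bytes(bk_elems: int) -> int:
    # Two 4-bit FP4 codes fit in one byte.
    assert bk_elems % 2 == 0, "BK_ELEMS must be even to pack FP4 pairs"
    return bk_elems // 2

def pack_fp4(codes: List[int]) -> bytes:
    """Pack 4-bit codes two per byte (assumed: element 2i in the low nibble)."""
    assert len(codes) % 2 == 0
    out = bytearray()
    for lo, hi in zip(codes[0::2], codes[1::2]):
        assert 0 <= lo < 16 and 0 <= hi < 16
        out.append(lo | (hi << 4))
    return bytes(out)

print(bk_bytes(512))  # default BK_ELEMS -> 256 bytes per K tile
print(bk_bytes(128))  # one 128-element MMA_K step -> 64 bytes
print(pack_fp4([0x1, 0xF, 0x7, 0x2]).hex())  # f127
```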

c_frag_size

comptime c_frag_size = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.c_frag_size

MMA_K

comptime MMA_K = 128

MMA_K_BYTES

comptime MMA_K_BYTES = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.MMA_K_BYTES

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

MmaOpType

comptime MmaOpType = BlockScaledMmaOp_PreB[IndexList(16, 16, 128), IndexList(MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].WM, WN, BK_ELEMS), MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_b_slots]

num_b_slots

comptime num_b_slots = 2 if B_PREFETCH else 1

num_k_mmas

comptime num_k_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_k_mmas

num_m_mmas

comptime num_m_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_m_mmas

num_n_mmas

comptime num_n_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_n_mmas
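MmaOpType is parameterized with the MFMA shape (16, 16, 128) and the warp tile (WM, WN, BK_ELEMS). This page does not show BlockScaledMmaOp_PreB's definition, so the Python sketch below assumes each per-warp MMA count is the warp-tile extent divided by the matching MFMA extent; WarpMmaCounts and mma_counts are hypothetical names.

```python
from dataclasses import dataclass

MMA_M, MMA_N, MMA_K = 16, 16, 128  # MFMA shape from MmaOpType's first IndexList

@dataclass(frozen=True)
class WarpMmaCounts:
    num_m_mmas: int
    num_n_mmas: int
    num_k_mmas: int

def mma_counts(wm: int, wn: int, bk_elems: int) -> WarpMmaCounts:
    # Assumption: each count is the warp tile extent / the MFMA extent.
    for extent, step in ((wm, MMA_M), (wn, MMA_N), (bk_elems, MMA_K)):
        assert extent % step == 0, f"{extent} is not a multiple of {step}"
    return WarpMmaCounts(wm // MMA_M, wn // MMA_N, bk_elems // MMA_K)

# Defaults: WM = BM = 64, WN = 64, BK_ELEMS = 512.
c = mma_counts(64, 64, 512)
print(c)  # WarpMmaCounts(num_m_mmas=4, num_n_mmas=4, num_k_mmas=4)
print(c.num_m_mmas * c.num_n_mmas * c.num_k_mmas)  # MFMAs per warp per K tile: 64
```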

num_threads

comptime num_threads = (MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_warps * WARP_SIZE)

num_warps

comptime num_warps = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_warps_n

num_warps_m

comptime num_warps_m = 1

num_warps_n

comptime num_warps_n = (BN // WN)

simd_width

comptime simd_width = simd_width_of[DType.uint8]()

WM

comptime WM = BM
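Because num_warps_m is fixed at 1, warps tile the block only along N: each warp covers all WM = BM rows and a WN-wide column strip. The Python sketch below derives the warp and thread counts from BM, BN, and WN. It assumes WARP_SIZE = 64 (the AMD CDNA wavefront width), and the strip-per-warp-index mapping in the hypothetical block_layout helper is illustrative only.

```python
WARP_SIZE = 64  # assumed AMD CDNA wavefront size

def block_layout(bm: int, bn: int, wn: int) -> dict:
    assert bn % wn == 0, "BN must be a multiple of WN"
    num_warps_m = 1          # fixed: B is not staged through LDS
    num_warps_n = bn // wn
    num_warps = num_warps_n  # since num_warps_m == 1
    wm = bm                  # WM is structurally BM
    # Hypothetical mapping: warp w covers columns [w * WN, (w + 1) * WN).
    strips = [(w * wn, (w + 1) * wn) for w in range(num_warps_n)]
    return {
        "WM": wm,
        "num_warps_m": num_warps_m,
        "num_warps_n": num_warps_n,
        "num_threads": num_warps * WARP_SIZE,
        "n_strips": strips,
    }

print(block_layout(64, 128, 64))
# {'WM': 64, 'num_warps_m': 1, 'num_warps_n': 2, 'num_threads': 128, 'n_strips': [(0, 64), (64, 128)]}
```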

Methods

run

static def run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_pre_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout, N: Int, K_BYTES: Int](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b_pre: TileTensor[DType.uint8, b_pre_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin], n_tile_idx: Int, m_tile_idx: Int)
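When testing run, a slow reference for the block-scaled product can help. The Python sketch below assumes OCP MX conventions: FP4 codes are E2M1, and each 32-element K block shares one E8M0 scale with value 2^(e - 127). It works on already-unpacked codes in nested lists and treats B as N x K with N-major scales; those layouts are assumptions, and the kernel's byte packing, the preshuffled b_pre layout, and sfa_layout/sfb_layout are not modeled.

```python
import math

# E2M1 magnitudes indexed by the 3 low bits; bit 3 is the sign.
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK = 32  # OCP MX block size: one shared scale per 32 elements

def fp4_e2m1(code: int) -> float:
    mag = _E2M1[code & 0x7]
    return -mag if code & 0x8 else mag

def e8m0(e: int) -> float:
    # E8M0 encodes 2**(e - 127); 0xFF is NaN.
    return math.nan if e == 0xFF else 2.0 ** (e - 127)

def mxfp4_matmul_ref(a, sfa, b, sfb):
    """C[m][n] = sum over K blocks of (A_blk . B_blk) * scale_a * scale_b.

    a:   M x K FP4 codes; sfa: M x (K // 32) E8M0 exponents
    b:   N x K FP4 codes; sfb: N x (K // 32) E8M0 exponents (assumed layout)
    """
    M, N, K = len(a), len(b), len(a[0])
    assert K % MX_BLOCK == 0
    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for kb in range(K // MX_BLOCK):
                blk = sum(fp4_e2m1(a[m][k]) * fp4_e2m1(b[n][k])
                          for k in range(kb * MX_BLOCK, (kb + 1) * MX_BLOCK))
                acc += blk * e8m0(sfa[m][kb]) * e8m0(sfb[n][kb])
            c[m][n] = acc
    return c

# 1 x 1 output, K = 32: A is all 1.0 (code 0x2), B is all -0.5 (code 0x9),
# A's block scale is 2**1 (e = 128), B's is 2**0 (e = 127).
print(mxfp4_matmul_ref([[0x2] * 32], [[128]], [[0x9] * 32], [[127]]))  # [[-32.0]]
```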