Mojo struct
MXFP4MatmulAMD_PreB
struct MXFP4MatmulAMD_PreB[BM: Int = 64, BN: Int = 128, BK_ELEMS: Int = 512, WN: Int = 64, B_PREFETCH: Bool = False]
Preshuffled-B variant of MXFP4MatmulAMD.
The preshuffled-B (preb) path requires num_warps_m == 1: B is not staged
through LDS, so warps cannot reuse B fragments across the M direction.
WM is therefore structurally fixed to BM.
When B_PREFETCH=True, the kernel runs a depth-2 software pipeline over the
outer K loop: while the current iteration's MFMAs execute, the next
iteration's B fragments stream from DRAM into the alternate b_reg slot.
This doubles the size of _b_reg (costing extra VGPRs) but hides the DRAM
latency of B behind the inner MFMA chain. It targets K-heavy shapes
(e.g. gate/up, K=7168) where serialization across outer iterations
dominates.
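The slot rotation described above can be modeled in a few lines. The following is a minimal sketch in plain Python, not the kernel's Mojo code: the helper `pipeline_schedule` is hypothetical, and it assumes the outer K loop advances by BK_ELEMS elements per iteration, which this page does not state. It only shows which b_reg slot feeds the MFMAs and which slot receives the prefetch in each iteration.

```python
def pipeline_schedule(num_k_iters, b_prefetch):
    """Model the b_reg slot used per outer-K iteration (hypothetical helper).

    Returns a list of (iteration, compute_slot, prefetch_slot) tuples.
    prefetch_slot is None when no B load overlaps that iteration's MFMAs.
    """
    # Mirrors the documented member: num_b_slots = 2 if B_PREFETCH else 1
    num_b_slots = 2 if b_prefetch else 1
    schedule = []
    for k in range(num_k_iters):
        compute_slot = k % num_b_slots
        has_next = k + 1 < num_k_iters
        # With prefetch on, the next iteration's B fragments go to the
        # alternate slot while the current slot is being consumed.
        if b_prefetch and has_next:
            prefetch_slot = (k + 1) % num_b_slots
        else:
            prefetch_slot = None
        schedule.append((k, compute_slot, prefetch_slot))
    return schedule


# K = 7168 elements with the default BK_ELEMS = 512.
# Assumption: one outer iteration per BK_ELEMS-wide slice of K.
K_ELEMS, BK_ELEMS = 7168, 512
num_k_iters = K_ELEMS // BK_ELEMS
sched = pipeline_schedule(num_k_iters, b_prefetch=True)

print(num_k_iters)   # 14
print(sched[:3])     # [(0, 0, 1), (1, 1, 0), (2, 0, 1)]
print(sched[-1])     # (13, 1, None)
```

Under these assumptions the compute slot and the prefetch slot never coincide, which is why the prefetch variant needs two slots. The final iteration issues no prefetch.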
Implemented traits
comptime members
BK_BYTES
comptime BK_BYTES = (BK_ELEMS // 2)
c_frag_size
comptime c_frag_size = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.c_frag_size
MMA_K
comptime MMA_K = 128
MMA_K_BYTES
comptime MMA_K_BYTES = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.MMA_K_BYTES
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
MmaOpType
comptime MmaOpType = BlockScaledMmaOp_PreB[IndexList(16, 16, 128, __list_literal__=NoneType(None)), IndexList(MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].WM, WN, BK_ELEMS, __list_literal__=NoneType(None)), MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_b_slots]
num_b_slots
comptime num_b_slots = 2 if B_PREFETCH else 1
num_k_mmas
comptime num_k_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_k_mmas
num_m_mmas
comptime num_m_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_m_mmas
num_n_mmas
comptime num_n_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].MmaOpType.num_n_mmas
num_threads
comptime num_threads = (MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_warps * WARP_SIZE)
num_warps
comptime num_warps = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, B_PREFETCH].num_warps_n
num_warps_m
comptime num_warps_m = 1
num_warps_n
comptime num_warps_n = (BN // WN)
simd_width
comptime simd_width = simd_width_of[DType.uint8]()
WM
comptime WM = BM
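To see how the derived members above resolve for a given parameter set, the following plain-Python sketch recomputes only the members whose formulas appear on this page. The class `PreBConfig` is a hypothetical stand-in, not part of the Mojo API, and `WARP_SIZE = 64` is an assumption (the 64-lane wavefront of AMD CDNA GPUs). Members defined through MmaOpType, such as num_k_mmas, are left out because their formulas are not shown here.

```python
from dataclasses import dataclass

# Assumption: 64-lane wavefronts, as on AMD CDNA GPUs.
WARP_SIZE = 64


@dataclass(frozen=True)
class PreBConfig:
    """Hypothetical mirror of the struct parameters and their defaults."""
    BM: int = 64
    BN: int = 128
    BK_ELEMS: int = 512
    WN: int = 64
    B_PREFETCH: bool = False

    @property
    def BK_BYTES(self):
        # Two 4-bit elements are packed into each byte.
        return self.BK_ELEMS // 2

    @property
    def WM(self):
        # Fixed to BM because num_warps_m == 1.
        return self.BM

    @property
    def num_warps_n(self):
        return self.BN // self.WN

    @property
    def num_warps(self):
        # num_warps_m is 1, so the warp count equals num_warps_n.
        return self.num_warps_n

    @property
    def num_threads(self):
        return self.num_warps * WARP_SIZE

    @property
    def num_b_slots(self):
        return 2 if self.B_PREFETCH else 1


default_cfg = PreBConfig()
print(default_cfg.BK_BYTES, default_cfg.num_warps, default_cfg.num_threads)
# 256 2 128

wide_cfg = PreBConfig(BN=256, B_PREFETCH=True)
print(wide_cfg.num_warps, wide_cfg.num_threads, wide_cfg.num_b_slots)
# 4 256 2
```

With the default parameters this gives a 2-warp, 128-thread block under the stated wavefront assumption. Widening BN adds warps along N only, since the M direction is always covered by a single warp.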
Methods
run
static def run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_pre_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout, N: Int, K_BYTES: Int](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b_pre: TileTensor[DType.uint8, b_pre_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin], n_tile_idx: Int, m_tile_idx: Int)