IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MXFP4MatmulAMD_PreB

struct MXFP4MatmulAMD_PreB[BM: Int = Int(64), BN: Int = Int(128), BK_ELEMS: Int = Int(512), WN: Int = Int(64), b_prefetch: Bool = False, b_cache_policy: CacheOperation = CacheOperation.ALWAYS, dram_to_lds: Bool = False, cluster_drain_sched: Bool = False, mfma_cluster: Int = Int(4), deep_prime: Bool = False]

Preshuffled-B variant of MXFP4MatmulAMD.

The PreB path requires num_warps_m == 1: B is not staged through LDS, so warps cannot share B fragments along the M direction. WM is therefore structurally fixed to BM.

When b_prefetch=True, the kernel runs a depth-2 software pipeline over the outer K loop: while the current iteration's MFMAs execute, the next iteration's B fragments stream from DRAM into the alternate b_reg slot. This doubles the size of _b_reg (costing extra VGPRs) but hides the DRAM latency of B behind the inner MFMA chain. It targets K-heavy shapes (e.g. gate/up with K=7168), where serialization of the outer iterations dominates.
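The slot rotation described above can be traced with a small Python model. This is a hypothetical sketch, not the kernel: `run_outer_k` is an illustrative name, each slot just records which K tile it holds, and real asynchrony and VGPR usage are not modeled. It only shows that, with two slots, the load for tile k+1 can be issued before the MFMAs of tile k without overwriting the tile being consumed.

```python
def run_outer_k(num_k_tiles, b_prefetch):
    """Toy trace of the outer-K loop; slots stand in for b_reg (hypothetical model)."""
    num_slots = 2 if b_prefetch else 1  # mirrors num_b_slots
    b_reg = [None] * num_slots          # K-tile index currently held by each slot
    trace = []

    if b_prefetch and num_k_tiles > 0:
        # Prologue: fetch tile 0 before entering the steady-state loop.
        b_reg[0] = 0
        trace.append("load B[0] -> slot 0")

    for k in range(num_k_tiles):
        slot = k % num_slots
        if b_prefetch:
            nxt = k + 1
            if nxt < num_k_tiles:
                # Issue the next tile into the alternate slot before this
                # iteration's MFMAs, so its DRAM latency overlaps them.
                b_reg[nxt % 2] = nxt
                trace.append(f"load B[{nxt}] -> slot {nxt % 2} (async)")
        else:
            # No prefetch: the MFMAs must wait for this load to land.
            b_reg[slot] = k
            trace.append(f"load B[{k}] -> slot {slot} (MFMAs wait)")
        # The slot being consumed must hold this iteration's tile.
        assert b_reg[slot] == k
        trace.append(f"MFMA on B[{k}] from slot {slot}")
    return trace


for line in run_outer_k(3, b_prefetch=True):
    print(line)
```

The cost shows up as the second slot (`num_slots = 2`), which corresponds to the doubled _b_reg footprint.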

cluster_drain_sched (b_prefetch only) enables a stage-1 inner-loop interleave. Each group of mfma_cluster MFMAs gets its own s_setprio bracket (instead of one coarse bracket around the whole chain), and a partial-vmcnt staircase keeps the prefetched B loads in flight from cluster to cluster instead of draining them all at once. Off by default; existing callers are bit-identical.
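A partial-vmcnt staircase can be illustrated with a pseudo-instruction generator. This is a hypothetical sketch under stated assumptions: `cluster_schedule`, `loads_per_cluster` and `prefetch_loads` are illustrative names; it assumes the loads each cluster depends on were issued before the prefetched B loads and that vmcnt retires loads in issue order. The real kernel's wait counts and priority levels may differ. The point is that each wait uses a nonzero count, so the prefetched loads are never drained.

```python
def cluster_schedule(num_mfmas, mfma_cluster, loads_per_cluster, prefetch_loads):
    """Emit pseudo-instructions for one inner MFMA chain (hypothetical model).

    Assumes num_clusters * loads_per_cluster loads that the chain depends on
    were issued first, followed by prefetch_loads loads for the next
    iteration, and that vmcnt counts loads down in issue order.
    """
    if num_mfmas % mfma_cluster:
        raise ValueError("num_mfmas must be a multiple of mfma_cluster")
    num_clusters = num_mfmas // mfma_cluster
    outstanding = num_clusters * loads_per_cluster + prefetch_loads
    out = []
    for c in range(num_clusters):
        needed = (c + 1) * loads_per_cluster  # loads clusters 0..c depend on
        # Partial wait: every later-issued load, including all prefetched
        # B loads, may stay in flight. A full drain would be vmcnt(0).
        out.append(f"s_waitcnt vmcnt({outstanding - needed})")
        out.append("s_setprio 1")           # raise priority for this cluster only
        out.extend(["mfma"] * mfma_cluster)
        out.append("s_setprio 0")           # drop priority between clusters
    return out


for instr in cluster_schedule(8, 4, loads_per_cluster=2, prefetch_loads=4):
    print(instr)
```

With these inputs the waits step down from vmcnt(6) to vmcnt(4) and never reach 0, which is the "staircase" shape.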

deep_prime (b_prefetch only, num_tiles >= 2) deepens the A pipeline to two tiles ahead: the prologue stages both tile 0 into slot 0 and tile 1 into slot 1 of LDS, so each steady-state iteration reads an A tile that has had one extra full iteration of MFMA work to hide its load. Iteration i reads slot[i % 2] and then issues the A DMA for tile i + 2 into that same, just-freed slot. It reuses the existing num_a_slots = 2 LDS buffers, so it needs no extra LDS or VGPRs. It composes with cluster_drain_sched and mfma_cluster because the MFMA chain is unchanged. When num_tiles < 2 it falls back to the 1-deep path. Off by default; existing callers are bit-identical.
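The slot bookkeeping for deep_prime can be checked with a short Python model. This is a hypothetical sketch: `a_pipeline` is an illustrative name, LDS slots are plain list entries, and the 1-deep fallback is assumed to issue tile i + 1 into the other slot. It verifies that every iteration reads the tile it expects and that, in 2-deep mode, the refill goes into the slot that was just read.

```python
def a_pipeline(num_tiles, deep_prime):
    """Return (depth, log); log rows are (iter, slot_read, tile_read, tile_issued)."""
    slots = [None, None]  # num_a_slots = 2 when b_prefetch is on
    depth = 2 if (deep_prime and num_tiles >= 2) else 1  # fallback when < 2 tiles
    # Prologue: stage the first `depth` tiles (tile t goes to slot t % 2).
    for t in range(min(depth, num_tiles)):
        slots[t % 2] = t
    log = []
    for i in range(num_tiles):
        s = i % 2
        tile = slots[s]
        assert tile == i, "iteration i must see tile i"
        nxt = i + depth
        issued = None
        if nxt < num_tiles:
            # depth 2: (i + 2) % 2 == s, i.e. the just-freed slot.
            # depth 1 (assumed): (i + 1) % 2, i.e. the other slot.
            slots[nxt % 2] = nxt
            issued = nxt
        log.append((i, s, tile, issued))
    return depth, log


depth, log = a_pipeline(4, deep_prime=True)
print(depth)
for row in log:
    print(row)
```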

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // Int(2))
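BK_BYTES halves BK_ELEMS because MXFP4 stores two 4-bit elements per byte. The Python sketch below shows that packing for the default BK_ELEMS = 512. The nibble order (even element in the low nibble) is an assumption made for illustration; `pack_fp4` and `unpack_fp4` are hypothetical helpers, not part of this API.

```python
BK_ELEMS = 512  # default from the struct signature


def pack_fp4(codes):
    """Pack 4-bit codes two per byte; even index in the low nibble (assumed order)."""
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    out = bytearray()
    for lo, hi in zip(codes[0::2], codes[1::2]):
        if not (0 <= lo < 16 and 0 <= hi < 16):
            raise ValueError("FP4 codes must fit in 4 bits")
        out.append(lo | (hi << 4))
    return bytes(out)


def unpack_fp4(data):
    """Inverse of pack_fp4: split each byte back into two 4-bit codes."""
    codes = []
    for b in data:
        codes.append(b & 0xF)
        codes.append(b >> 4)
    return codes


BK_BYTES = BK_ELEMS // 2
codes = [i % 16 for i in range(BK_ELEMS)]
packed = pack_fp4(codes)
print(len(packed), BK_BYTES)
```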

c_frag_size

comptime c_frag_size = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.c_frag_size

MMA_K

comptime MMA_K = 128

MMA_K_BYTES

comptime MMA_K_BYTES = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.MMA_K_BYTES

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16
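MmaOpType is parameterized by the MFMA shape (MMA_M, MMA_N, MMA_K) = (16, 16, 128) and the warp tile (WM, WN, BK_ELEMS). A natural reading is that num_m_mmas, num_n_mmas and num_k_mmas are the ratios of those extents, but this page does not show their definitions, so the Python sketch below is an assumption. `mma_counts` is a hypothetical helper.

```python
MMA_M, MMA_N, MMA_K = 16, 16, 128  # MFMA shape from the comptime members


def mma_counts(WM, WN, BK_ELEMS):
    """Hypothetical: MFMAs per warp tile along M, N and K for one BK step."""
    for extent, mma in ((WM, MMA_M), (WN, MMA_N), (BK_ELEMS, MMA_K)):
        if extent % mma:
            raise ValueError(f"{extent} is not a multiple of {mma}")
    return WM // MMA_M, WN // MMA_N, BK_ELEMS // MMA_K


# Defaults: WM = BM = 64, WN = 64, BK_ELEMS = 512.
m, n, k = mma_counts(64, 64, 512)
print(m, n, k, m * n * k)
```

Under this assumption the default configuration issues 64 MFMAs per warp per K tile, which is the chain that mfma_cluster divides into groups.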

MmaOpType

comptime MmaOpType = BlockScaledMmaOp_PreB[IndexList(Int(16), Int(16), Int(128), __list_literal__=NoneType(None)), IndexList(MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].WM, WN, BK_ELEMS, __list_literal__=NoneType(None)), Int(2) if b_prefetch else Int(1)]

num_a_slots

comptime num_a_slots = Int(2) if b_prefetch else Int(1)

num_b_slots

comptime num_b_slots = Int(2) if b_prefetch else Int(1)

num_k_mmas

comptime num_k_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_k_mmas

num_m_mmas

comptime num_m_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_m_mmas

num_n_mmas

comptime num_n_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_n_mmas

num_threads

comptime num_threads = (MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].num_warps * _resolve_warp_size())

num_warps

comptime num_warps = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].num_warps_n

num_warps_m

comptime num_warps_m = 1

num_warps_n

comptime num_warps_n = (BN // WN)

simd_width

comptime simd_width = simd_width_of[DType.uint8]()

WM

comptime WM = BM
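The launch-geometry members above chain together: num_warps_n = BN // WN, num_warps = num_warps_n (since num_warps_m is fixed at 1), and num_threads = num_warps times the warp size. The Python sketch below evaluates that chain. It assumes a 64-lane wavefront, as on AMD CDNA GPUs; in the kernel the value comes from _resolve_warp_size(). `preb_geometry` is a hypothetical helper, and its divisibility check is a choice made for the sketch.

```python
def preb_geometry(BM=64, BN=128, WN=64, warp_size=64):
    """Evaluate the PreB warp/thread counts (warp_size=64 is an assumption)."""
    if BN % WN:
        raise ValueError("this sketch requires BN to be a multiple of WN")
    num_warps_m = 1          # fixed: B is not staged through LDS
    num_warps_n = BN // WN
    num_warps = num_warps_n  # num_warps_m * num_warps_n with num_warps_m == 1
    return {
        "WM": BM,            # one warp spans the block tile's full M extent
        "num_warps_m": num_warps_m,
        "num_warps_n": num_warps_n,
        "num_warps": num_warps,
        "num_threads": num_warps * warp_size,
    }


print(preb_geometry())
```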

Methods

run

static def run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_pre_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout, N: Int, K_BYTES: Int](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b_pre: TileTensor[DType.uint8, b_pre_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin], n_tile_idx: Int, m_tile_idx: Int)
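For checking results, a plain-Python reference of block-scaled MXFP4 matmul semantics can help. It follows the OCP Microscaling conventions: E2M1 FP4 values, one E8M0 scale (2^(e - 127)) per 32 elements along K. Everything about layout here is an assumption, not this kernel's contract. A is taken as M packed rows, B in its logical N x K form before preshuffling (b_pre's preshuffled layout is not modeled), and nibbles are low-first. m_tile_idx and n_tile_idx are assumed to select a BM x BN output tile. The helper names are hypothetical, and accumulation uses Python floats rather than out_dtype.

```python
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7
MX_BLOCK = 32  # elements sharing one E8M0 scale (OCP MX block size)


def fp4_value(code):
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    mag = FP4_E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def e8m0_value(e):
    """Decode an E8M0 scale byte: 2**(e - 127), with 0xFF meaning NaN."""
    return float("nan") if e == 0xFF else 2.0 ** (e - 127)


def dequant_row(packed_row, scale_row):
    """Expand one packed K row (low nibble first, assumed) and apply block scales."""
    vals = []
    for b in packed_row:
        vals.append(fp4_value(b & 0xF))
        vals.append(fp4_value(b >> 4))
    return [v * e8m0_value(scale_row[k // MX_BLOCK]) for k, v in enumerate(vals)]


def reference_tile(a, sfa, b, sfb, m_tile_idx, n_tile_idx, BM, BN):
    """C[m][n] = sum_k A[m, k] * B[n, k] for one BM x BN output tile."""
    rows = range(m_tile_idx * BM, min((m_tile_idx + 1) * BM, len(a)))
    cols = range(n_tile_idx * BN, min((n_tile_idx + 1) * BN, len(b)))
    b_deq = {n: dequant_row(b[n], sfb[n]) for n in cols}
    tile = []
    for m in rows:
        a_row = dequant_row(a[m], sfa[m])
        tile.append([sum(x * y for x, y in zip(a_row, b_deq[n])) for n in cols])
    return tile


# K = 64 elements = 32 bytes per row, so two scale blocks per row.
a = [bytes([0x22]) * 32 for _ in range(2)]  # every element is 1.0
sfa = [[127, 127] for _ in range(2)]        # scales 1.0, 1.0
b = [bytes([0x44]) * 32 for _ in range(2)]  # every element is 2.0
sfb = [[127, 128] for _ in range(2)]        # scales 1.0, 2.0
print(reference_tile(a, sfa, b, sfb, m_tile_idx=1, n_tile_idx=0, BM=1, BN=1))
```

In the example, the first K block contributes 32 x 1.0 x 2.0 and the second 32 x 1.0 x 4.0, so the single output element is 192.0.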