Mojo struct
MXFP4MatmulAMD_PreB
struct MXFP4MatmulAMD_PreB[BM: Int = Int(64), BN: Int = Int(128), BK_ELEMS: Int = Int(512), WN: Int = Int(64), b_prefetch: Bool = False, b_cache_policy: CacheOperation = CacheOperation.ALWAYS, dram_to_lds: Bool = False, cluster_drain_sched: Bool = False, mfma_cluster: Int = Int(4), deep_prime: Bool = False]
Preshuffled-B variant of MXFP4MatmulAMD.
The preb path requires num_warps_m == 1: B is not staged through LDS, so
warps cannot reuse B fragments across the M direction. WM is therefore structurally fixed to BM.
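A minimal plain-Python sketch of the resulting warp layout. It assumes only the relationships stated on this page (WM = BM, num_warps_n = BN // WN). The helper name warp_tiles is hypothetical, and this is not the kernel's code. Every warp spans all BM rows and owns a distinct WN-wide column strip, so no two warps consume the same B columns.

```python
def warp_tiles(BM, BN, WN):
    """Return (row_range, col_range) for each warp of one BM x BN block tile.

    Hypothetical model of the preb layout: num_warps_m == 1, so WM == BM and
    warps are laid out only along N.
    """
    if BN % WN != 0:
        raise ValueError("BN must be a multiple of WN")
    num_warps_m = 1
    WM = BM  # structurally fixed: no M-direction split across warps
    num_warps_n = BN // WN
    tiles = []
    for wm in range(num_warps_m):
        for wn in range(num_warps_n):
            rows = range(wm * WM, (wm + 1) * WM)
            cols = range(wn * WN, (wn + 1) * WN)
            tiles.append((rows, cols))
    return tiles


# Default parameters BM=64, BN=128, WN=64: two warps, each 64 x 64.
for rows, cols in warp_tiles(64, 128, 64):
    print(rows, cols)
```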
When b_prefetch=True, the kernel runs a depth-2 outer-K software pipeline: while
the current iteration's MFMAs execute, the next iteration's B fragments stream
from DRAM into the alternate b_reg slot. This doubles the size of _b_reg (extra
VGPRs) but hides the DRAM latency of B behind the inner MFMA chain. It targets
K-heavy shapes (e.g. gate/up, K=7168), where serialization across outer
iterations dominates.
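The ordering can be pictured with a simplified plain-Python model of the depth-2 pipeline. This is not the kernel's implementation. It records load_b/mfma events in program order and ignores A traffic and the inner MFMA chain. The event names and both helpers are hypothetical.

```python
def b_prefetch_schedule(num_k_iters, b_prefetch=True):
    """Ordered (op, k_iter, slot) events for the B operand across outer-K iterations."""
    num_b_slots = 2 if b_prefetch else 1
    if num_k_iters <= 0:
        return []
    events = [("load_b", 0, 0)]  # prologue: fetch B for the first iteration
    for k in range(num_k_iters):
        slot = k % num_b_slots
        if b_prefetch and k + 1 < num_k_iters:
            # Issue the next iteration's B load into the alternate slot before
            # this iteration's MFMAs, so its DRAM latency overlaps them.
            events.append(("load_b", k + 1, (k + 1) % num_b_slots))
        events.append(("mfma", k, slot))
        if not b_prefetch and k + 1 < num_k_iters:
            # Single slot: the next load can only start once the MFMAs are done.
            events.append(("load_b", k + 1, 0))
    return events


def has_slot_hazard(events):
    """True if a slot is overwritten before the MFMAs have consumed its contents."""
    held = {}  # slot -> k_iter whose B fragments currently occupy it
    for op, k, slot in events:
        if op == "load_b":
            if slot in held:
                return True
            held[slot] = k
        elif held.pop(slot, None) != k:  # op == "mfma"
            return True
    return False


for event in b_prefetch_schedule(3):
    print(event)
```

In the prefetch schedule, the load for iteration k+1 always targets the slot that iteration k is not using. That is why the pipeline needs two b_reg slots.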
cluster_drain_sched (b_prefetch only) enables a stage1 inner-loop
interleave: each run of mfma_cluster MFMAs gets its own s_setprio bracket
(rather than one coarse bracket around the whole chain), and a partial-vmcnt
staircase keeps the prefetched B loads in flight per cluster instead of
draining them all at once.
Off by default, so existing callers are bit-identical.
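The difference can be pictured as a pseudo-instruction trace. The sketch below is hypothetical. The HI/LO priority placeholders, the v_mfma placeholder, and the assumption that each cluster waits on an equal, in-issue-order share of the outstanding loads are illustrative choices, not taken from the kernel. It does rely on the documented meaning of s_waitcnt vmcnt(n), which waits until at most n vector memory operations are outstanding.

```python
def cluster_drain_trace(num_mfmas, mfma_cluster, num_loads):
    """Pseudo-instruction trace for one pass of the inner MFMA chain.

    Hypothetical model: clusters are assumed to depend on equal shares of the
    num_loads outstanding vector loads, in issue order.
    """
    if num_mfmas <= 0 or mfma_cluster <= 0 or num_mfmas % mfma_cluster:
        raise ValueError("num_mfmas must be a positive multiple of mfma_cluster")
    num_clusters = num_mfmas // mfma_cluster
    per_cluster = -(-num_loads // num_clusters)  # ceiling division
    trace = []
    for c in range(num_clusters):
        # vmcnt(n): wait until at most n vector memory ops are outstanding.
        # Loads are assumed to retire in issue order, so only the oldest
        # loads must have landed; the newer ones stay in flight.
        landed = min(num_loads, per_cluster * (c + 1))
        trace.append(f"s_waitcnt vmcnt({num_loads - landed})")
        trace.append("s_setprio HI")  # raise priority for this cluster only
        trace.extend(["v_mfma"] * mfma_cluster)
        trace.append("s_setprio LO")
    return trace


def coarse_trace(num_mfmas):
    """Baseline: one full drain and one priority bracket around the whole chain."""
    return ["s_waitcnt vmcnt(0)", "s_setprio HI", *(["v_mfma"] * num_mfmas), "s_setprio LO"]


for line in cluster_drain_trace(num_mfmas=8, mfma_cluster=4, num_loads=4):
    print(line)
```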
deep_prime (b_prefetch only, num_tiles >= 2) deepens the A pipeline to
two tiles ahead: the prologue stages both tile0 -> slot0 and tile1 -> slot1
into LDS, so each steady-state iteration reads an A tile that has had one full
extra iteration of MFMA shadow in which to land. Iteration i reads slot[i%2] and
issues the A DMA for tile i+2 into that same, just-freed slot. It reuses the
existing num_a_slots=2 LDS buffers, so it needs no extra LDS or VGPRs. It composes
with cluster_drain_sched/mfma_cluster (the MFMA chain is unchanged) and falls back
to the 1-deep path when num_tiles < 2.
Off by default, so existing callers are bit-identical.
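The slot rotation can be checked with a small plain-Python model. The helpers are hypothetical and this is not the kernel's code. The model assumes that the 1-deep path issues the DMA for tile i+1 into the other slot after reading tile i. It counts how many MFMA phases separate each tile's DMA from its LDS read.

```python
def a_pipeline_schedule(num_tiles, deep_prime=True):
    """Ordered (op, tile, slot) events for A tiles staged in the two LDS slots."""
    num_a_slots = 2  # b_prefetch implies num_a_slots == 2
    if num_tiles <= 0:
        return []
    depth = 2 if (deep_prime and num_tiles >= 2) else 1  # fallback when < 2 tiles
    # Prologue: stage the first `depth` tiles, tile t -> slot t.
    events = [("dma_a", t, t % num_a_slots) for t in range(depth)]
    for i in range(num_tiles):
        slot = i % num_a_slots
        events.append(("read_a", i, slot))  # LDS -> registers frees the slot
        nxt = i + depth
        if nxt < num_tiles:
            # With deep_prime, nxt % 2 == i % 2: refill the slot just freed.
            events.append(("dma_a", nxt, nxt % num_a_slots))
        events.append(("mfma", i, slot))
    return events


def mfma_shadow(events):
    """For each tile, count MFMA phases between its DMA issue and its LDS read."""
    issued_at, shadow, mfmas = {}, {}, 0
    for op, tile, _slot in events:
        if op == "dma_a":
            issued_at[tile] = mfmas
        elif op == "read_a":
            shadow[tile] = mfmas - issued_at[tile]
        else:  # "mfma"
            mfmas += 1
    return shadow


print(mfma_shadow(a_pipeline_schedule(4, deep_prime=True)))   # {0: 0, 1: 1, 2: 2, 3: 2}
print(mfma_shadow(a_pipeline_schedule(4, deep_prime=False)))  # {0: 0, 1: 1, 2: 1, 3: 1}
```

In this model, steady-state tiles get two MFMA phases of shadow with deep_prime and one without it, matching the "full extra iteration" described above.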
Implemented traits
comptime members
BK_BYTES
comptime BK_BYTES = (BK_ELEMS // Int(2))
c_frag_size
comptime c_frag_size = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.c_frag_size
MMA_K
comptime MMA_K = 128
MMA_K_BYTES
comptime MMA_K_BYTES = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.MMA_K_BYTES
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
MmaOpType
comptime MmaOpType = BlockScaledMmaOp_PreB[IndexList(Int(16), Int(16), Int(128), __list_literal__=NoneType(None)), IndexList(MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].WM, WN, BK_ELEMS, __list_literal__=NoneType(None)), Int(2) if b_prefetch else Int(1)]
num_a_slots
comptime num_a_slots = Int(2) if b_prefetch else Int(1)
num_b_slots
comptime num_b_slots = Int(2) if b_prefetch else Int(1)
num_k_mmas
comptime num_k_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_k_mmas
num_m_mmas
comptime num_m_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_m_mmas
num_n_mmas
comptime num_n_mmas = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].MmaOpType.num_n_mmas
num_threads
comptime num_threads = (MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].num_warps * _resolve_warp_size())
num_warps
comptime num_warps = MXFP4MatmulAMD_PreB[BM, BN, BK_ELEMS, WN, b_prefetch, b_cache_policy, dram_to_lds, cluster_drain_sched, mfma_cluster, deep_prime].num_warps_n
num_warps_m
comptime num_warps_m = 1
num_warps_n
comptime num_warps_n = (BN // WN)
simd_width
comptime simd_width = simd_width_of[DType.uint8]()
WM
comptime WM = BM
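The derived members above reduce to simple arithmetic on the struct parameters. The plain-Python model below reproduces that arithmetic under two labeled assumptions. First, _resolve_warp_size() is taken to return 64, the AMD CDNA wavefront width. Second, num_m_mmas, num_n_mmas and num_k_mmas come from BlockScaledMmaOp_PreB, which this page does not define, so they are assumed to be the warp tile divided by the 16x16x128 MMA shape. PreBParams and derived_members are hypothetical names.

```python
from dataclasses import dataclass

MMA_M, MMA_N, MMA_K = 16, 16, 128  # comptime MMA_M / MMA_N / MMA_K


@dataclass(frozen=True)
class PreBParams:
    # Defaults mirror the struct's parameter defaults.
    BM: int = 64
    BN: int = 128
    BK_ELEMS: int = 512
    WN: int = 64
    b_prefetch: bool = False


def derived_members(p, warp_size=64):
    """Model of the comptime members (warp_size and *_mmas formulas are assumptions)."""
    if p.BN % p.WN:
        raise ValueError("BN must be a multiple of WN")
    WM = p.BM  # preb path: num_warps_m == 1
    num_warps_n = p.BN // p.WN
    num_warps = num_warps_n  # num_warps_m is always 1
    return {
        "BK_BYTES": p.BK_ELEMS // 2,  # two FP4 values packed per uint8
        "WM": WM,
        "num_warps_m": 1,
        "num_warps_n": num_warps_n,
        "num_warps": num_warps,
        "num_threads": num_warps * warp_size,
        "num_a_slots": 2 if p.b_prefetch else 1,
        "num_b_slots": 2 if p.b_prefetch else 1,
        # Assumed: warp tile WM x WN x BK_ELEMS divided by the MMA shape.
        "num_m_mmas": WM // MMA_M,
        "num_n_mmas": p.WN // MMA_N,
        "num_k_mmas": p.BK_ELEMS // MMA_K,
    }


print(derived_members(PreBParams()))
print(derived_members(PreBParams(BN=256, b_prefetch=True)))
```

With the defaults this gives BK_BYTES = 256, two warps (128 threads) and, under the assumed formulas, 4 x 4 x 4 MFMAs per warp per outer-K iteration.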
Methods
run
static def run[out_dtype: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_pre_layout: TensorLayout, sfa_layout: TensorLayout, sfb_layout: TensorLayout, N: Int, K_BYTES: Int](c: TileTensor[out_dtype, c_layout, MutAnyOrigin], a: TileTensor[DType.uint8, a_layout, ImmutAnyOrigin], b_pre: TileTensor[DType.uint8, b_pre_layout, ImmutAnyOrigin], sfa: TileTensor[DType.float8_e8m0fnu, sfa_layout, ImmutAnyOrigin], sfb: TileTensor[DType.float8_e8m0fnu, sfb_layout, ImmutAnyOrigin], n_tile_idx: Int, m_tile_idx: Int)
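The run signature passes the FP4 operands as DType.uint8 and their scales as DType.float8_e8m0fnu. As a reference point for checking results, the sketch below decodes one MXFP4 block in plain Python. E2M1 values and E8M0 scales follow the OCP Microscaling (MX) format, in which each scale covers a block of 32 elements. Several parts are assumptions, not taken from this page: the low-nibble-first packing order, the helper names, and the tile_bounds mapping of m_tile_idx/n_tile_idx to BM x BN output tiles (clipped at M and N). The preshuffled layout of b_pre is not modeled.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK = 32  # MXFP4: 32 FP4 elements share one E8M0 scale


def decode_e2m1(code):
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def decode_e8m0(bits):
    """Decode an E8M0 scale byte: 2 ** (bits - 127); 0xFF encodes NaN."""
    return math.nan if bits == 0xFF else math.ldexp(1.0, bits - 127)


def unpack_fp4(packed):
    """Split uint8 bytes into FP4 codes, low nibble first (assumed order)."""
    return [nib for byte in packed for nib in (byte & 0xF, byte >> 4)]


def mx_block_dot(a_packed, sfa_bits, b_packed, sfb_bits):
    """Reference dot product of one scaled 32-element block of A and of B."""
    a, b = unpack_fp4(a_packed), unpack_fp4(b_packed)
    if len(a) != MX_BLOCK or len(b) != MX_BLOCK:
        raise ValueError("expected 16 packed bytes (32 FP4 values) per operand")
    acc = sum(decode_e2m1(x) * decode_e2m1(y) for x, y in zip(a, b))
    return acc * decode_e8m0(sfa_bits) * decode_e8m0(sfb_bits)


def tile_bounds(m_tile_idx, n_tile_idx, M, N, BM=64, BN=128):
    """Hypothetical: the region of c covered by one run() call, clipped at M and N."""
    rows = range(m_tile_idx * BM, min((m_tile_idx + 1) * BM, M))
    cols = range(n_tile_idx * BN, min((n_tile_idx + 1) * BN, N))
    return rows, cols


# Every A code is 0b0010 (1.0), every B code is 0b0100 (2.0); scales 1.0 and 0.5.
print(mx_block_dot(bytes([0x22] * 16), 127, bytes([0x44] * 16), 126))
```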