
Mojo module

dispatch

Dispatch logic for grouped 1D-1D block-scaled SM100 matmul.

Selects a kernel configuration based on the (N, K) shape and the workload size, using parameters tuned via ablation on B200. NVFP4 gets shape-tuned, three-regime dispatch; MXFP4 and MXFP8 use default configurations.

When override=True, the dispatcher uses the caller's AB_swapped, mma_bn, cta_group, and num_pipeline_stages directly (for ablation studies and benchmarking). When override=False (the default), it ignores those parameters and selects a configuration from the tuning table based on (N, K) and estimated_total_m.
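The override rule can be pictured as a simple either/or choice. The following is a minimal Python sketch, not the Mojo implementation: `KernelConfig` and `resolve_config` are hypothetical names, and `None` stands in for "auto" stages as an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class KernelConfig:
    # Hypothetical container for the four caller-tunable knobs named above.
    ab_swapped: bool
    mma_bn: int
    cta_group: int
    num_pipeline_stages: Optional[int]  # None stands in for "auto" (assumption)


def resolve_config(override: bool, caller: KernelConfig, tuned: KernelConfig) -> KernelConfig:
    """Return the caller's knobs verbatim when override=True, else the tuned entry."""
    if override:
        # Ablation / benchmarking path: use exactly what the caller passed.
        return caller
    # Default path: the caller's knobs are ignored in favor of the tuning table.
    return tuned
```

With override=False, the caller's values have no effect, so benchmarking scripts must set override=True to test a specific configuration.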

NVFP4 routing (B200-tuned via ablation):

(N=512, K=7168), Kimi K2.5 TP=8 up-proj: a shape-specific two-branch override. Phase-2 ablation showed that the regime classifier picks a suboptimal (mma_bn, cta_group) pair for this shape; all regimes converge on cta_group=2 and stages=6, with mma_bn=64 for decode (avg_m <= 8) and mma_bn=128 otherwise.
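The two-branch override for this shape reduces to a single threshold on avg_m. This Python sketch is illustrative only: `kimi_up_proj_config` is a hypothetical name, and AB_swapped is omitted because the description above does not state its value for this branch.

```python
def kimi_up_proj_config(avg_m: float) -> dict:
    """Sketch of the (N=512, K=7168) two-branch override."""
    # Both branches share cta_group=2 and 6 pipeline stages; only mma_bn differs.
    mma_bn = 64 if avg_m <= 8 else 128
    return {"mma_bn": mma_bn, "cta_group": 2, "num_pipeline_stages": 6}
```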

Other shapes: three-regime classifier keyed on avg_m = estimated_total_m / num_active_experts.

Decode (avg_m <= 8):           AB_swapped=True, mma_bn=8,  cta_group=1
Small prefill (avg_m <= 64):   AB_swapped=True, mma_bn=64, cta_group=2
Large prefill (avg_m > 64):    AB_swapped=True, mma_bn=128, cta_group=2
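The classifier can be sketched in Python as follows. The threshold names mirror the module's DECODE_AVG_M and SMALL_PREFILL_AVG_M constants. `classify_regime` and `REGIME_CONFIGS` are hypothetical names, and the sketch assumes true division for avg_m. The Mojo code may use integer division, which would shift results at the boundaries.

```python
DECODE_AVG_M = 8  # mirrors the module's DECODE_AVG_M
SMALL_PREFILL_AVG_M = 64  # mirrors the module's SMALL_PREFILL_AVG_M


def classify_regime(estimated_total_m: int, num_active_experts: int) -> str:
    # Assumption: at least one expert is active; guard against division by zero.
    if num_active_experts <= 0:
        raise ValueError("num_active_experts must be positive")
    avg_m = estimated_total_m / num_active_experts  # true division (assumed)
    if avg_m <= DECODE_AVG_M:
        return "decode"
    if avg_m <= SMALL_PREFILL_AVG_M:
        return "small_prefill"
    return "large_prefill"


# Per-regime settings from the table above (stages are chosen per (N, K) elsewhere).
REGIME_CONFIGS = {
    "decode": {"AB_swapped": True, "mma_bn": 8, "cta_group": 1},
    "small_prefill": {"AB_swapped": True, "mma_bn": 64, "cta_group": 2},
    "large_prefill": {"AB_swapped": True, "mma_bn": 128, "cta_group": 2},
}
```

For example, 512 tokens spread over 8 active experts gives avg_m = 64, which still falls in the small-prefill regime because both thresholds are inclusive.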

Tuned stage counts for each (N, K) are set at the three _dispatch_regime call sites: (N=4096, K=7168) DeepSeek-V3 up-proj; (N=7168, K=2048) DeepSeek-V3 down-proj; and (N=7168, K=256) Kimi K2.5 TP=8 down-proj (large prefill only). Unknown shapes fall through to stages=auto.
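The fall-through rule can be sketched as a lookup that reports whether a shape uses tuned stages or "auto". The actual stage counts are not listed here, so this Python sketch only returns "tuned" or "auto". `stage_source` is a hypothetical name. It also assumes, without confirmation, that the Kimi down-proj shape uses auto stages outside the large-prefill regime.

```python
# Shapes whose stage counts are tuned at the _dispatch_regime call sites.
TUNED_STAGE_SHAPES = {
    (4096, 7168): "DeepSeek-V3 up-proj",
    (7168, 2048): "DeepSeek-V3 down-proj",
    (7168, 256): "Kimi K2.5 TP=8 down-proj",
}
# Assumption: this shape is tuned only for the large-prefill regime.
LARGE_PREFILL_ONLY = {(7168, 256)}


def stage_source(n: int, k: int, regime: str) -> str:
    """Return "tuned" if (n, k) has tuned stages for this regime, else "auto"."""
    if (n, k) not in TUNED_STAGE_SHAPES:
        return "auto"  # unknown shapes fall through to stages=auto
    if (n, k) in LARGE_PREFILL_ONLY and regime != "large_prefill":
        return "auto"  # assumed behavior outside large prefill
    return "tuned"
```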

comptime values

DECODE_AVG_M

comptime DECODE_AVG_M = 8

SMALL_PREFILL_AVG_M

comptime SMALL_PREFILL_AVG_M = 64

Functions