Mojo module

dispatch

Dispatch logic for grouped 1D-1D block-scaled SM100 matmul.

Selects a kernel configuration based on the (N, K) shape and workload size, with parameters tuned via ablation on B200. NVFP4 gets a shape-tuned three-regime dispatch; MXFP4 and MXFP8 use default configs.

When override=True, the dispatcher uses the caller's AB_swapped, mma_bn, cta_group, and num_pipeline_stages directly (for ablation studies and benchmarking). When override=False (the default), it ignores those parameters and selects from the tuning table based on (N, K) and estimated_total_m.
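The override switch can be pictured with a small Python sketch. This is not the Mojo implementation: `KernelConfig`, `resolve_config`, and the `tuned_lookup` callback are hypothetical names introduced for illustration, and representing "auto" stages as `None` is an assumption. Only the four parameter names and the override semantics come from the description above.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class KernelConfig:
    # Field names mirror the parameters named in the module description.
    AB_swapped: bool
    mma_bn: int
    cta_group: int
    num_pipeline_stages: Optional[int]  # None stands in for "auto" (assumption)


def resolve_config(
    caller: KernelConfig,
    tuned_lookup: Callable[[int, int, int], KernelConfig],
    n: int,
    k: int,
    estimated_total_m: int,
    override: bool = False,
) -> KernelConfig:
    """Return the caller's config when override is set, else the tuned one."""
    if override:
        # Ablation / benchmarking path: use the caller's values as given.
        return caller
    # Default path: the caller's values are ignored entirely.
    return tuned_lookup(n, k, estimated_total_m)
```

Passing the lookup as a callback keeps the sketch independent of how the tuning table is actually stored.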

NVFP4 routing (B200-tuned via ablation):

(N=512, K=7168) Kimi K2.5 TP=8 up-proj: shape-specific 2-branch override. Phase-2 ablation showed that the regime classifier picks a suboptimal (mma_bn, cta_group) here; all regimes converge on cta_group=2, stages=6, with mma_bn=64 for decode (avg_m <= 8) and mma_bn=128 otherwise.
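As a rough Python illustration of this two-branch override (not the Mojo implementation), the choice reduces to a single comparison on avg_m. The function name `kimi_up_proj_params` and the dictionary layout are hypothetical; AB_swapped is left out because the description above does not state it for this shape.

```python
KIMI_UP_PROJ_SHAPE = (512, 7168)  # (N, K) named in the description above
DECODE_AVG_M = 8  # decode threshold documented in this module


def kimi_up_proj_params(avg_m: float) -> dict:
    """Two-branch choice for (N=512, K=7168): only mma_bn varies."""
    mma_bn = 64 if avg_m <= DECODE_AVG_M else 128
    # cta_group and stages are the same in both branches.
    return {"mma_bn": mma_bn, "cta_group": 2, "stages": 6}
```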

Other shapes: three-regime classifier keyed on avg_m = estimated_total_m / num_active_experts.

Decode (avg_m <= 8):               AB_swapped=True, mma_bn=8,   cta_group=1
Small prefill (8 < avg_m <= 64):   AB_swapped=True, mma_bn=64,  cta_group=2
Large prefill (avg_m > 64):        AB_swapped=True, mma_bn=128, cta_group=2
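The three regimes can be sketched in Python as follows. This is an illustration under stated assumptions, not the Mojo implementation: `classify_regime` is a hypothetical name, avg_m is computed with true (floating-point) division, and the zero-expert guard is added only for safety in the sketch; the source does not say how the real dispatcher handles either detail.

```python
DECODE_AVG_M = 8
SMALL_PREFILL_AVG_M = 64


def classify_regime(estimated_total_m: int, num_active_experts: int) -> dict:
    """Map average rows per active expert to a regime and its parameters."""
    if num_active_experts <= 0:
        # Guard added for the sketch; behavior in the real kernel is unknown.
        raise ValueError("num_active_experts must be positive")

    avg_m = estimated_total_m / num_active_experts  # true division (assumption)

    if avg_m <= DECODE_AVG_M:
        return {"regime": "decode", "AB_swapped": True, "mma_bn": 8, "cta_group": 1}
    if avg_m <= SMALL_PREFILL_AVG_M:
        return {"regime": "small_prefill", "AB_swapped": True, "mma_bn": 64, "cta_group": 2}
    return {"regime": "large_prefill", "AB_swapped": True, "mma_bn": 128, "cta_group": 2}
```

For example, estimated_total_m=512 spread over 8 active experts gives avg_m=64, which still falls in small prefill because the upper bound is inclusive.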

Tuned stages per (N, K) live at the three _dispatch_regime call sites: (N=4096, K=7168) DeepSeek-V3 up-proj, (N=7168, K=2048) DeepSeek-V3 down-proj, (N=7168, K=256) Kimi K2.5 TP=8 down-proj (large prefill only). Unknown shapes fall through to stages=auto.
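One way to read the paragraph above is as a lookup from (N, K) to the regimes that carry a tuned stage count, with everything else using stages=auto. The Python sketch below encodes only that reading. The names `TUNED_STAGE_SHAPES` and `uses_tuned_stages` are hypothetical, the regime labels are illustrative, and treating the two DeepSeek-V3 shapes as tuned in all three regimes is an inference from the "(large prefill only)" qualifier, not something the source states outright. The actual stage counts are not given in this page, so none appear here.

```python
ALL_REGIMES = frozenset({"decode", "small_prefill", "large_prefill"})

# (N, K) -> regimes assumed to have a tuned stage count.
TUNED_STAGE_SHAPES = {
    (4096, 7168): ALL_REGIMES,                  # DeepSeek-V3 up-proj
    (7168, 2048): ALL_REGIMES,                  # DeepSeek-V3 down-proj
    (7168, 256): frozenset({"large_prefill"}),  # Kimi K2.5 TP=8 down-proj
}


def uses_tuned_stages(n: int, k: int, regime: str) -> bool:
    """True if this (shape, regime) pair has tuned stages; False means auto."""
    return regime in TUNED_STAGE_SHAPES.get((n, k), frozenset())
```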

comptime values

DECODE_AVG_M

comptime DECODE_AVG_M = 8

SMALL_PREFILL_AVG_M

comptime SMALL_PREFILL_AVG_M = 64

Functions