
Mojo module

dispatch

Dispatch logic for grouped 1D-1D block-scaled SM100 matmul.

Selects the optimal kernel configuration based on the (N, K) shape and workload size, using parameters tuned via ablation on B200. NVFP4 gets shape-tuned, three-regime dispatch; MXFP4 and MXFP8 use default configurations.

When override=True, uses the caller's AB_swapped/mma_bn/cta_group/num_pipeline_stages directly (for ablation studies and benchmarking). When override=False (the default), ignores those parameters and selects a configuration from the tuning table based on (N, K) and estimated_total_m.
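The override flag amounts to choosing between two configuration sources. Below is a minimal Python sketch of that choice. It assumes a hypothetical `KernelConfig` record and a caller-supplied `lookup_tuned` function standing in for the tuning table. Neither name is the module's actual Mojo API.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class KernelConfig:
    """Hypothetical record of the four tunable dispatch parameters."""
    ab_swapped: bool
    mma_bn: int
    cta_group: int
    num_pipeline_stages: Optional[int]  # None stands in for "auto"


def resolve_config(
    override: bool,
    caller_config: KernelConfig,
    lookup_tuned: Callable[[int, int, int], KernelConfig],
    n: int,
    k: int,
    estimated_total_m: int,
) -> KernelConfig:
    """Hypothetical helper mirroring the documented override semantics."""
    if override:
        # Ablation/benchmarking path: use the caller's parameters verbatim.
        return caller_config
    # Default path: ignore the caller's parameters and consult the tuning table.
    return lookup_tuned(n, k, estimated_total_m)
```

Passing the lookup as a callable keeps the sketch independent of the table contents. The NVFP4 regime table below is one example of what such a lookup encodes.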

NVFP4 uses three regimes, keyed on avg_m = estimated_total_m / num_active_experts:

Decode (avg_m <= 8):

  • N=4096, K=7168: AB_swapped=True, mma_bn=8, cta_group=1, stages=6
  • N=7168, K=2048: AB_swapped=True, mma_bn=8, cta_group=1, stages=4
  • Default: AB_swapped=True, mma_bn=8, cta_group=1, stages=auto

Small prefill (8 < avg_m <= 64):

  • N=4096, K=7168: AB_swapped=True, mma_bn=64, cta_group=2, stages=6
  • N=7168, K=2048: AB_swapped=True, mma_bn=64, cta_group=2, stages=6
  • Default: AB_swapped=True, mma_bn=64, cta_group=2, stages=auto

Large prefill (avg_m > 64):

  • N=4096, K=7168: AB_swapped=True, mma_bn=128, cta_group=2, stages=7
  • N=7168, K=2048: AB_swapped=True, mma_bn=128, cta_group=2, stages=6
  • Default: AB_swapped=True, mma_bn=128, cta_group=2, stages=auto
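The NVFP4 table can be read as a two-step lookup. First, avg_m selects a regime, which fixes mma_bn and cta_group. Second, num_pipeline_stages is looked up by (N, K), falling back to auto for unlisted shapes. AB_swapped is True in every row.

The following is a minimal Python sketch of that reading, not the module's Mojo implementation. It assumes floating-point division for avg_m and uses None for stages=auto. All names are hypothetical. It covers NVFP4 only, because the MXFP4/MXFP8 default configurations are not listed here.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class KernelConfig:
    """Hypothetical record of the four tunable dispatch parameters."""
    ab_swapped: bool
    mma_bn: int
    cta_group: int
    num_pipeline_stages: Optional[int]  # None stands in for "auto"


# Regime -> (mma_bn, cta_group), copied from the table above.
_REGIME_TILE: Dict[str, Tuple[int, int]] = {
    "decode": (8, 1),
    "small_prefill": (64, 2),
    "large_prefill": (128, 2),
}

# Regime -> {(N, K): stages} for the shape-tuned rows; other shapes use "auto".
_SHAPE_STAGES: Dict[str, Dict[Tuple[int, int], int]] = {
    "decode": {(4096, 7168): 6, (7168, 2048): 4},
    "small_prefill": {(4096, 7168): 6, (7168, 2048): 6},
    "large_prefill": {(4096, 7168): 7, (7168, 2048): 6},
}


def nvfp4_regime(estimated_total_m: int, num_active_experts: int) -> str:
    """Classify the workload by average rows per active expert (assumed > 0)."""
    if num_active_experts <= 0:
        raise ValueError("num_active_experts must be positive")
    avg_m = estimated_total_m / num_active_experts  # assumption: true division
    if avg_m <= 8:
        return "decode"
    if avg_m <= 64:
        return "small_prefill"
    return "large_prefill"


def select_nvfp4_config(
    n: int, k: int, estimated_total_m: int, num_active_experts: int
) -> KernelConfig:
    """Hypothetical NVFP4 table lookup: regime first, then per-shape stages."""
    regime = nvfp4_regime(estimated_total_m, num_active_experts)
    mma_bn, cta_group = _REGIME_TILE[regime]
    stages = _SHAPE_STAGES[regime].get((n, k))  # None (auto) if shape not listed
    return KernelConfig(True, mma_bn, cta_group, stages)
```

For example, 256 estimated rows over 8 active experts gives avg_m = 32. That falls in the small-prefill regime, so an unlisted shape gets mma_bn=64, cta_group=2, and automatic stage selection.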

Functions
