
Mojo function

single_group_router

def single_group_router[scores_type: DType, bias_type: DType, //, n_routed_experts: Int, n_experts_per_tok: Int, norm_weights: Bool, target: StringSlice[StaticConstantOrigin], scores_input_fn: OptionalReg[def[width: Int](IndexList[2]) capturing -> SIMD[scores_type, width]] = None](expert_indices: TileTensor[DType.int32, address_space=expert_indices.address_space, linear_idx_type=expert_indices.linear_idx_type, element_size=expert_indices.element_size], expert_weight: TileTensor[scores_type, address_space=expert_weight.address_space, linear_idx_type=expert_weight.linear_idx_type, element_size=expert_weight.element_size], expert_scores: TileTensor[scores_type, address_space=expert_scores.address_space, linear_idx_type=expert_scores.linear_idx_type, element_size=expert_scores.element_size], expert_bias: TileTensor[bias_type, address_space=expert_bias.address_space, linear_idx_type=expert_bias.linear_idx_type, element_size=expert_bias.element_size], routed_scaling_factor: Float32, context: DeviceContext)

Launch the single-group MoE router on GPU.

The kernel launches one thread block per token and one thread per expert. It selects the top n_experts_per_tok experts with a warp-level bitonic sort that uses 2 or 3 reduction phases, depending on the hardware warp size (on AMD, phase 2 is skipped at compile time).
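To illustrate the kind of sorting network mentioned above, here is a minimal sequential sketch in Python of a generic bitonic sort in descending order. It is not the kernel's code: the function name `bitonic_sort_desc` is hypothetical, and the sketch runs the compare-exchange steps in a loop, whereas a warp-level version would run them in parallel across threads. Bitonic networks are usually defined for power-of-two input lengths, and the sketch enforces that.

```python
def bitonic_sort_desc(values):
    """Sort a power-of-two-length list in descending order (illustrative only)."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    a = list(values)
    k = 2  # size of the bitonic sequences being merged in this stage
    while k <= n:
        j = k // 2  # distance between compared elements in this step
        while j >= 1:
            # Every iteration of this loop is independent of the others,
            # so parallel threads could each handle one value of i.
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    if (i & k) == 0:
                        out_of_order = a[i] < a[partner]  # descending half
                    else:
                        out_of_order = a[i] > a[partner]  # ascending half
                    if out_of_order:
                        a[i], a[partner] = a[partner], a[i]
            j //= 2
        k *= 2
    return a


print(bitonic_sort_desc([3, 1, 4, 2]))  # prints [4, 3, 2, 1]
```

Each pass of the inner `for` loop is one compare-exchange step in which element `i` is paired with element `i ^ j`. Taking the first n_experts_per_tok entries of the sorted result gives a top-k selection.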

Inputs:

  • expert_indices: Output expert indices. Shape: [num_tokens, n_experts_per_tok].
  • expert_weight: Output expert weights. Shape: [num_tokens, n_experts_per_tok].
  • expert_scores: Input routing scores. Shape: [num_tokens, n_routed_experts].
  • expert_bias: Per-expert correction bias, used for selection only.
  • routed_scaling_factor: Scalar multiplied into every output weight.
  • context: The device context.

Parameters:

  • scores_type (DType): DType of routing scores and output weights.
  • bias_type (DType): DType of the expert correction bias.
  • n_routed_experts (Int): Total number of experts (e.g. 384 for Kimi K2.5).
  • n_experts_per_tok (Int): Number of experts selected per token. Must be a power of 2 (e.g. 8 for Kimi K2.5).
  • norm_weights (Bool): If True, normalize selected weights to sum to 1 before applying routed_scaling_factor.
  • target (StringSlice[StaticConstantOrigin]): The target device to run the kernel on.
  • scores_input_fn (OptionalReg[def[width: Int](IndexList[2]) capturing -> SIMD[scores_type, width]]): Optional fused input lambda to load scores. If None, scores are loaded directly from expert_scores.
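The routing semantics described by the inputs and parameters can be sketched as a small CPU reference in pure Python. This is a hedged illustration, not the kernel: the function name `route_single_group` is hypothetical, and it assumes that the bias is added to each score for ranking, that the output weights are the unbiased scores of the selected experts, and that ties go to the lower expert index. None of these details are confirmed by the description above beyond "used for selection only".

```python
def route_single_group(scores, bias, n_experts_per_tok,
                       norm_weights, routed_scaling_factor):
    """Reference routing for scores[num_tokens][n_routed_experts] (sketch)."""
    expert_indices = []
    expert_weight = []
    for row in scores:
        # Assumption: experts are ranked by score + bias.
        biased = [s + b for s, b in zip(row, bias)]
        # Stable sort, so ties keep the lower expert index (assumption).
        order = sorted(range(len(row)), key=lambda e: -biased[e])
        top = order[:n_experts_per_tok]
        # The bias affects selection only; weights use the raw scores.
        weights = [row[e] for e in top]
        if norm_weights:
            total = sum(weights)
            weights = [w / total for w in weights]
        # The scaling factor is applied after optional normalization.
        weights = [w * routed_scaling_factor for w in weights]
        expert_indices.append(top)
        expert_weight.append(weights)
    return expert_indices, expert_weight


# One token, four experts; the bias promotes expert 0 into the top 2.
indices, weights = route_single_group(
    scores=[[0.1, 0.4, 0.3, 0.2]],
    bias=[0.5, 0.0, 0.0, 0.0],
    n_experts_per_tok=2,
    norm_weights=True,
    routed_scaling_factor=2.0,
)
print(indices)  # prints [[0, 1]]
print(weights)
```

In this example expert 0 has the lowest raw score but is selected because of its bias, while its output weight is still derived from the raw score 0.1. With norm_weights set, the two weights sum to routed_scaling_factor rather than to 1.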