
Mojo function

single_group_router

single_group_router[scores_type: DType, bias_type: DType, //, n_routed_experts: Int, n_experts_per_tok: Int, norm_weights: Bool, target: StringSlice[StaticConstantOrigin], scores_input_fn: OptionalReg[def[width: Int](IndexList[2]) capturing -> SIMD[scores_type, width]] = None](expert_indices: TileTensor[DType.int32, expert_indices.LayoutType, expert_indices.origin, address_space=expert_indices.address_space, linear_idx_type=expert_indices.linear_idx_type, element_size=expert_indices.element_size], expert_weight: TileTensor[scores_type, expert_weight.LayoutType, expert_weight.origin, address_space=expert_weight.address_space, linear_idx_type=expert_weight.linear_idx_type, element_size=expert_weight.element_size], expert_scores: TileTensor[scores_type, expert_scores.LayoutType, expert_scores.origin, address_space=expert_scores.address_space, linear_idx_type=expert_scores.linear_idx_type, element_size=expert_scores.element_size], expert_bias: TileTensor[bias_type, expert_bias.LayoutType, expert_bias.origin, address_space=expert_bias.address_space, linear_idx_type=expert_bias.linear_idx_type, element_size=expert_bias.element_size], routed_scaling_factor: Float32, context: DeviceContextPtr)

Launch the single-group MoE router on GPU.

Launches one thread block per token, with one thread per expert. Each block selects the top n_experts_per_tok experts using a warp-level bitonic sort that runs in 2 or 3 reduction phases depending on the hardware warp size; on AMD GPUs, phase 2 is skipped at compile time.
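A bitonic sort is a fixed compare-and-swap network: at every step, position i is paired with position i ^ j, which is why it maps well onto XOR lane shuffles inside a warp. The plain-Python sketch below models that network on a single list to show how a top-k selection falls out of it. It does not reproduce the kernel's 2- or 3-phase split across warps, and padding the expert count up to a power of 2 with -inf is an assumption made for illustration; `bitonic_sort_desc` and `top_k_experts` are hypothetical helpers, not part of the Mojo API.

```python
def bitonic_sort_desc(pairs):
    """Sort (score, expert_id) pairs by descending score using a bitonic network.

    len(pairs) must be a power of 2. At every (k, j) step, position i is
    compared with partner position i ^ j, the same pairing a warp-level
    version would obtain with XOR lane shuffles.
    """
    n = len(pairs)
    assert n > 0 and n & (n - 1) == 0, "bitonic sort needs a power-of-2 length"
    a = list(pairs)
    k = 2
    while k <= n:  # size of the bitonic sequences being merged
        j = k // 2
        while j > 0:  # compare distance within the current merge
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    # Sub-sequences alternate direction; the final merge (k == n) is descending.
                    descending = (i & k) == 0
                    if descending and a[i][0] < a[partner][0]:
                        a[i], a[partner] = a[partner], a[i]
                    elif not descending and a[i][0] > a[partner][0]:
                        a[i], a[partner] = a[partner], a[i]
            j //= 2
        k *= 2
    return a


def top_k_experts(scores, k):
    """Return the indices of the k highest scores (hypothetical helper)."""
    n = 1
    while n < len(scores):
        n *= 2
    # Assumption: pad unused slots with -inf so they never win a comparison.
    pairs = [(s, e) for e, s in enumerate(scores)]
    pairs += [(float("-inf"), -1)] * (n - len(scores))
    return [e for _, e in bitonic_sort_desc(pairs)[:k]]


print(top_k_experts([0.1, 0.9, 0.3, 0.7, 0.5], 2))  # [1, 3]
```

The power-of-2 requirement on n_experts_per_tok is consistent with this kind of network, which operates on power-of-2 sized groups; a non-power-of-2 expert count such as 384 would need padding or a multi-phase reduction like the one the kernel uses.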

Args:

  • expert_indices (TileTensor): Output expert indices (int32). Shape: [num_tokens, n_experts_per_tok].
  • expert_weight (TileTensor): Output expert weights. Shape: [num_tokens, n_experts_per_tok].
  • expert_scores (TileTensor): Input routing scores. Shape: [num_tokens, n_routed_experts].
  • expert_bias (TileTensor): Per-expert correction bias, used only when selecting experts; it does not contribute to the output weights.
  • routed_scaling_factor (Float32): Scalar multiplied into every output weight.
  • context (DeviceContextPtr): Device context used to launch the kernel.

Parameters:

  • scores_type (DType): DType of routing scores and output weights.
  • bias_type (DType): DType of the expert correction bias.
  • n_routed_experts (Int): Total number of experts (e.g. 384 for Kimi K2.5).
  • n_experts_per_tok (Int): Number of experts selected per token; must be a power of 2 (e.g. 8 for Kimi K2.5).
  • norm_weights (Bool): If True, normalize selected weights to sum to 1 before applying routed_scaling_factor.
  • target (StringSlice): The target device to run the kernel on.
  • scores_input_fn (OptionalReg): Optional fused input lambda to load scores. If None, scores are loaded directly from expert_scores.
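To make the argument and parameter semantics concrete, here is a plain-Python reference model of what the description says happens for each token: experts are ranked by score plus bias, the output weights are taken from the unbiased scores of the selected experts, optionally normalized to sum to 1, and then multiplied by routed_scaling_factor. The function names are hypothetical, the output ordering (highest biased score first) and tie-breaking are assumptions, and any activation applied to the scores before routing is out of scope. Treat it as a checking aid, not as the kernel's implementation.

```python
def route_token(scores, bias, n_experts_per_tok, norm_weights, routed_scaling_factor):
    """Reference routing for one token (hypothetical helper, not the Mojo API)."""
    assert len(scores) == len(bias)
    # Selection uses the biased score; stable sort keeps the lower index on ties.
    ranked = sorted(range(len(scores)), key=lambda e: scores[e] + bias[e], reverse=True)
    indices = ranked[:n_experts_per_tok]
    # Output weights come from the unbiased scores of the selected experts.
    weights = [scores[e] for e in indices]
    if norm_weights:
        total = sum(weights)
        weights = [w / total for w in weights]
    # The scaling factor is applied after the optional normalization.
    weights = [w * routed_scaling_factor for w in weights]
    return indices, weights


def route_batch(expert_scores, expert_bias, n_experts_per_tok, norm_weights, routed_scaling_factor):
    """Map [num_tokens, n_routed_experts] scores to two [num_tokens, n_experts_per_tok] outputs."""
    expert_indices, expert_weight = [], []
    for row in expert_scores:  # the GPU kernel handles each token in its own block
        idx, w = route_token(row, expert_bias, n_experts_per_tok, norm_weights, routed_scaling_factor)
        expert_indices.append(idx)
        expert_weight.append(w)
    return expert_indices, expert_weight


indices, weights = route_batch([[0.1, 0.4, 0.2, 0.3]], [0.5, 0.0, 0.0, 0.0], 2, True, 2.0)
print(indices, weights)
```

In this example expert 0 is selected only because of its bias, yet its weight is computed from its raw score of 0.1, so the normalized and scaled weights come out at approximately [0.4, 1.6].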
