Mojo function

router_group_limited

router_group_limited[scores_type: DType, bias_type: DType, //, n_routed_experts: Int, n_experts_per_tok: Int, n_groups: Int, topk_group: Int, norm_weights: Bool, target: StringSlice[StaticConstantOrigin], scores_input_fn: OptionalReg[fn[width: Int](IndexList[2]) capturing -> SIMD[scores_type, width]] = None](expert_indices: TileTensor[DType.int32, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], expert_weights: TileTensor[scores_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], expert_scores: TileTensor[scores_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], expert_bias: TileTensor[bias_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], routed_scaling_factor: Float32, context: DeviceContextPtr)

A manually fused MoE router with the group-limited strategy.

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/9b4e9788e4a3a731f7567338ed15d3ec549ce03b/inference/model.py#L566.

Inputs:

  • expert_indices: The indices of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
  • expert_weights: The weights of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
  • expert_scores: The scores for each expert for each token. Shape: [num_tokens, n_routed_experts].
  • expert_bias: The bias for each expert. Shape: [n_routed_experts].
  • routed_scaling_factor: The scaling factor for the routed expert weights.
  • context: The device context pointer (DeviceContextPtr).

Parameters:

  • scores_type (DType): The data type of the scores and the output weights.
  • bias_type (DType): The data type of the expert bias.
  • n_routed_experts (Int): The total number of routed experts to choose from (the size of the last dimension of expert_scores).
  • n_experts_per_tok (Int): The number of experts to be selected per token.
  • n_groups (Int): The number of expert groups.
  • topk_group (Int): The number of expert groups to be selected per token.
  • norm_weights (Bool): Whether to normalize the selected weights.
  • target (StringSlice): The target device to run the kernel on.
  • scores_input_fn (OptionalReg): Input lambda function to load the scores.
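
To make the group-limited strategy concrete, the following is a minimal pure-Python sketch modeled on the DeepSeek-V3 reference linked above, not the Mojo kernel itself. It rests on several assumptions that this page does not confirm: that expert_scores are already activated (for example by a sigmoid), that each group is scored by the sum of its two largest biased scores, that the bias affects selection only while the returned weights come from the unbiased scores, and that ties resolve to the lower expert index. The function name router_group_limited_ref is hypothetical.

```python
def router_group_limited_ref(scores, bias, n_experts_per_tok, n_groups,
                             topk_group, norm_weights, routed_scaling_factor):
    """Reference sketch of group-limited routing (assumed semantics, see lead-in).

    scores: list of per-token lists, shape [num_tokens, n_routed_experts].
    bias:   list of length n_routed_experts.
    Returns (expert_indices, expert_weights), each [num_tokens, n_experts_per_tok].
    """
    n_routed_experts = len(bias)
    group_size = n_routed_experts // n_groups
    expert_indices, expert_weights = [], []

    for token_scores in scores:
        # Assumption: the bias is added for selection only.
        biased = [s + b for s, b in zip(token_scores, bias)]

        # Assumption: a group's score is the sum of its two largest biased scores.
        group_scores = []
        for g in range(n_groups):
            members = sorted(biased[g * group_size:(g + 1) * group_size],
                             reverse=True)
            group_scores.append(sum(members[:2]))

        # Keep only the topk_group best groups; experts elsewhere are masked out.
        kept = set(sorted(range(n_groups),
                          key=lambda g: -group_scores[g])[:topk_group])
        candidates = [e for e in range(n_routed_experts)
                      if e // group_size in kept]

        # Pick the top n_experts_per_tok experts among the surviving candidates.
        top = sorted(candidates, key=lambda e: -biased[e])[:n_experts_per_tok]

        # Weights are gathered from the unbiased scores, then optionally
        # normalized to sum to 1 and finally scaled.
        weights = [token_scores[e] for e in top]
        if norm_weights:
            total = sum(weights)
            weights = [w / total for w in weights]
        weights = [w * routed_scaling_factor for w in weights]

        expert_indices.append(top)
        expert_weights.append(weights)

    return expert_indices, expert_weights


# One token, 8 experts in 4 groups of 2; keep 2 groups, then pick 2 experts.
scores = [[0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.85, 0.0]]
indices, weights = router_group_limited_ref(
    scores, [0.0] * 8, n_experts_per_tok=2, n_groups=4,
    topk_group=2, norm_weights=True, routed_scaling_factor=1.0)
print(indices)  # [[0, 4]]
```

In this example, expert 6 has the second-highest individual score (0.85), but its group ranks third and is dropped, so expert 4 is selected instead. That is the effect of limiting the top-k search to the best topk_group groups.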
