Mojo function

router_group_limited

router_group_limited[scores_type: DType, bias_type: DType, //, n_routed_experts: Int, n_experts_per_tok: Int, n_groups: Int, topk_group: Int, norm_weights: Bool, target: StringSlice[StaticConstantOrigin]](expert_indices: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_weights: LayoutTensor[scores_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_scores: LayoutTensor[scores_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_bias: LayoutTensor[bias_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], routed_scaling_factor: Float32, context: DeviceContextPtr)

A manually fused MoE router with the group-limited strategy.

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/9b4e9788e4a3a731f7567338ed15d3ec549ce03b/inference/model.py#L566.

Inputs:

  • expert_indices: The indices of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
  • expert_weights: The weights of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
  • expert_scores: The scores for each expert for each token. Shape: [num_tokens, n_routed_experts].
  • expert_bias: The bias for each expert. Shape: [n_routed_experts].
  • routed_scaling_factor: The scaling factor for the routed expert weights.
  • context: The device context pointer.

Parameters:

  • scores_type (DType): The data type of the scores and the output weights.
  • bias_type (DType): The data type of the expert bias.
  • n_routed_experts (Int): The number of experts to route to.
  • n_experts_per_tok (Int): The number of experts to be selected per token.
  • n_groups (Int): The number of expert groups.
  • topk_group (Int): The number of expert groups to be selected per token.
  • norm_weights (Bool): Whether to normalize the selected weights.
  • target (StringSlice): The target device to run the kernel on.
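To make the routing strategy concrete, here is a minimal NumPy sketch of group-limited top-k selection as described in the DeepSeek-V3 reference above: experts are partitioned into `n_groups` groups, each group is scored by the sum of its top-2 biased expert scores, only the best `topk_group` groups survive, and the final `n_experts_per_tok` experts are chosen from the survivors. This is an illustrative reference implementation, not the fused Mojo kernel; the function name and bias-handling details (bias affects selection only, while returned weights use the raw scores) follow the referenced model code.

```python
import numpy as np

def router_group_limited(scores, bias, n_groups, topk_group,
                         n_experts_per_tok, norm_weights,
                         routed_scaling_factor):
    """Sketch of group-limited MoE routing (not the fused kernel)."""
    num_tokens, n_routed_experts = scores.shape
    experts_per_group = n_routed_experts // n_groups

    # Bias influences which experts are selected; the returned
    # weights are taken from the unbiased scores.
    biased = scores + bias  # [num_tokens, n_routed_experts]

    grouped = biased.reshape(num_tokens, n_groups, experts_per_group)
    # Group score: sum of the top-2 biased expert scores in the group.
    top2 = np.sort(grouped, axis=-1)[..., -2:]
    group_scores = top2.sum(axis=-1)  # [num_tokens, n_groups]

    # Keep only the topk_group best groups; mask the rest to -inf.
    top_groups = np.argsort(group_scores, axis=-1)[:, -topk_group:]
    mask = np.full_like(group_scores, -np.inf)
    np.put_along_axis(mask, top_groups, 0.0, axis=-1)
    masked = (grouped + mask[:, :, None]).reshape(num_tokens,
                                                  n_routed_experts)

    # Top n_experts_per_tok experts per token among surviving groups.
    indices = np.argsort(masked, axis=-1)[:, -n_experts_per_tok:][:, ::-1]
    weights = np.take_along_axis(scores, indices, axis=-1)

    if norm_weights:
        weights = weights / weights.sum(axis=-1, keepdims=True)
    return indices.astype(np.int32), weights * routed_scaling_factor
```

For example, with 4 experts in 2 groups, `topk_group=1`, and `n_experts_per_tok=2`, both selected experts come from the single winning group, and with `norm_weights=True` their weights sum to `routed_scaling_factor`.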