Mojo function
router_group_limited
router_group_limited[scores_type: DType, bias_type: DType, //, n_routed_experts: Int, n_experts_per_tok: Int, n_groups: Int, topk_group: Int, norm_weights: Bool, target: StringSlice[StaticConstantOrigin]](expert_indices: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_weights: LayoutTensor[scores_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_scores: LayoutTensor[scores_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_bias: LayoutTensor[bias_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], routed_scaling_factor: Float32, context: DeviceContextPtr)
A manually fused MoE router with the group-limited strategy.
Inputs:
- expert_indices: The indices of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
- expert_weights: The weights of the routed experts for each token. Shape: [num_tokens, n_experts_per_tok].
- expert_scores: The scores for each expert for each token. Shape: [num_tokens, n_routed_experts].
- expert_bias: The bias for each expert. Shape: [n_routed_experts].
- routed_scaling_factor: The scaling factor for the routed expert weights.
- context: The device context (DeviceContextPtr).
Parameters:
- scores_type (DType): The data type of the scores and the output weights.
- bias_type (DType): The data type of the expert bias.
- n_routed_experts (Int): The number of experts to route to.
- n_experts_per_tok (Int): The number of experts to be selected per token.
- n_groups (Int): The number of expert groups.
- topk_group (Int): The number of expert groups to be selected per token.
- norm_weights (Bool): Whether to normalize the selected weights.
- target (StringSlice): The target device to run the kernel on.
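The parameters above can be read as the knobs of a group-limited top-k selection. The following is a minimal host-side reference sketch in plain Python, not the kernel itself, and its semantics are assumptions this page does not confirm: that the bias affects only which experts are selected, that each group is scored by the sum of its two largest biased scores (as in DeepSeek-V3-style routers), that ties go to the lower index, and that the output weights come from the unbiased scores. The function name `route_group_limited_reference` is hypothetical. It handles one token, so its inputs correspond to one row of expert_scores and its outputs to one row of expert_indices and expert_weights.

```python
def route_group_limited_reference(
    scores,                 # one row of expert_scores, length n_routed_experts
    bias,                   # expert_bias, length n_routed_experts
    n_experts_per_tok,
    n_groups,
    topk_group,
    norm_weights,
    routed_scaling_factor,
):
    """Hypothetical single-token reference; all semantics here are assumptions."""
    n_routed_experts = len(scores)
    assert len(bias) == n_routed_experts
    assert n_routed_experts % n_groups == 0
    group_size = n_routed_experts // n_groups

    # ASSUMPTION: the bias only steers selection, not the returned weights.
    biased = [s + b for s, b in zip(scores, bias)]

    # ASSUMPTION: a group's score is the sum of its two largest biased scores.
    group_scores = []
    for g in range(n_groups):
        members = sorted(biased[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(sum(members[:2]))

    # Keep the topk_group best groups (ties broken by lower group index).
    ranked_groups = sorted(range(n_groups), key=lambda g: (-group_scores[g], g))
    kept_groups = set(ranked_groups[:topk_group])

    # Only experts inside the kept groups remain candidates.
    candidates = [e for e in range(n_routed_experts) if e // group_size in kept_groups]
    chosen = sorted(candidates, key=lambda e: (-biased[e], e))[:n_experts_per_tok]

    # Weights are taken from the unbiased scores, then optionally normalized.
    weights = [scores[e] for e in chosen]
    if norm_weights:
        total = sum(weights)
        weights = [w / total for w in weights]
    weights = [w * routed_scaling_factor for w in weights]
    return chosen, weights


# Example: 8 experts in 4 groups of 2; keep 2 groups, pick 2 experts.
example_scores = [0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.05, 0.6]
example_indices, example_weights = route_group_limited_reference(
    example_scores, [0.0] * 8,
    n_experts_per_tok=2, n_groups=4, topk_group=2,
    norm_weights=True, routed_scaling_factor=2.5,
)
print(example_indices, example_weights)
```

In this example, experts 0 and 4 are chosen because their groups rank highest. Expert 7 scores higher than several candidates but is excluded because its group was not kept, which is the point of limiting selection to groups.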