Python function
forward_moe_sharded_layers
forward_moe_sharded_layers()
max.nn.forward_moe_sharded_layers(shards, xs)
Runs a forward pass through DP-sharded layers (EP MoE or replicated MLP/MoE).
For EP-enabled MoE shards, this runs the full expert-parallel (EP)
communication path (dispatch -> local compute -> combine).
For everything else (replicated MLP, non-EP MoE), it falls back to
forward_sharded_layers().
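The branching described above can be illustrated with a toy, pure-Python sketch. This is not the MAX implementation and uses no `max` APIs: plain lists of `(expert_id, value)` pairs stand in for tensors, and `ToyExpertShard`, its `ep_enabled` flag, `ep_forward_sketch`, `forward_sharded_sketch`, and `forward_moe_sketch` are all hypothetical names introduced only for illustration. It shows the dispatch -> local compute -> combine flow for EP shards and the per-shard fallback for everything else.

```python
from typing import Callable, Sequence

# Stand-in for a tensor: a list of (expert_id, value) tokens.
Tokens = list[tuple[int, float]]


class ToyExpertShard:
    """Hypothetical EP shard: owns a subset of the experts on one device."""

    ep_enabled = True  # assumed marker for this sketch, not a MAX attribute

    def __init__(self, experts: dict[int, Callable[[float], float]]):
        self.experts = experts


def forward_sharded_sketch(shards, xs):
    """Fallback path: each shard processes only its own input."""
    return [shard(x) for shard, x in zip(shards, xs)]


def ep_forward_sketch(shards: Sequence[ToyExpertShard], xs: Sequence[Tokens]):
    # Dispatch: route every token to the device that owns its expert,
    # remembering the source device and position for the combine step.
    inbox = [[] for _ in shards]
    for src, tokens in enumerate(xs):
        for pos, (expert_id, value) in enumerate(tokens):
            owner = next(i for i, s in enumerate(shards) if expert_id in s.experts)
            inbox[owner].append((src, pos, expert_id, value))

    # Local compute, then combine: each device applies its own experts and
    # the results are written back to the originating device's slot.
    outs = [[0.0] * len(tokens) for tokens in xs]
    for shard, received in zip(shards, inbox):
        for src, pos, expert_id, value in received:
            outs[src][pos] = shard.experts[expert_id](value)
    return outs


def forward_moe_sketch(shards, xs):
    if len(shards) != len(xs):
        raise ValueError("expected one input per shard")
    # EP-enabled shards take the communication path; anything else falls back.
    if all(getattr(s, "ep_enabled", False) for s in shards):
        return ep_forward_sketch(shards, xs)
    return forward_sharded_sketch(shards, xs)


ep_shards = [
    ToyExpertShard({0: lambda v: v * 2}),   # device 0 owns expert 0
    ToyExpertShard({1: lambda v: v + 10}),  # device 1 owns expert 1
]
ep_out = forward_moe_sketch(ep_shards, [[(0, 1.0), (1, 2.0)], [(1, 3.0)]])

# Plain callables have no ep_enabled flag, so they take the fallback path.
plain_out = forward_moe_sketch([lambda x: x + 1, lambda x: x * 3], [1, 2])
```

In both paths the result has one entry per shard, in the same order as the inputs. In the EP path a token may be computed on a different device than the one it started on, which is why the sketch tracks each token's source device and position.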
Parameters:
- shards (Sequence[Callable[[TensorValue], TensorValue]]) – Per-device shard callables (MoE, MLP, etc.).
- xs (list[TensorValue]) – Input tensors, one per shard.
Returns:
Output tensors, one per shard.
Return type: