
Python function

forward_moe_sharded_layers

forward_moe_sharded_layers()

max.nn.forward_moe_sharded_layers(shards, xs)


Forward pass through DP-sharded layers (EP MoE or replicated MLP/MoE).

For EP-enabled MoE shards, this runs the full expert-parallel communication path (dispatch -> local compute -> combine). In all other cases (replicated MLP, non-EP MoE), it falls back to forward_sharded_layers().

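Parameters:

shards – The DP-sharded layers (EP MoE, or replicated MLP/MoE).

xs – The input tensors passed to the shards.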

Returns:

Output tensors, one per shard.

Return type:

list[TensorValue]
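
The control flow described above can be illustrated with a minimal pure-Python sketch. It does not use the MAX API: ToyShard, toy_forward_ep, toy_forward_sharded, and toy_forward_moe_sharded are hypothetical names, tokens are plain integers, and both the routing rule (token modulo shard count) and the "all shards are EP-enabled" check are assumptions made only to keep the example small. It shows the shape of dispatch -> local compute -> combine and the fallback, not how the library implements them.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyShard:
    """Hypothetical stand-in for one shard; not a MAX type."""

    expert: Callable[[int], int]  # the "layer" this shard owns
    ep_enabled: bool = False


def toy_forward_sharded(shards, xs):
    # Fallback path: each shard processes only its own input.
    return [[s.expert(t) for t in x] for s, x in zip(shards, xs)]


def toy_forward_ep(shards, xs):
    n = len(shards)
    # Dispatch: send each token to the shard that owns its expert.
    # Toy routing rule (an assumption): expert index = token % n.
    inbox = [[] for _ in range(n)]
    for src, x in enumerate(xs):
        for pos, tok in enumerate(x):
            inbox[tok % n].append((src, pos, tok))
    # Local compute: every shard runs its expert on what it received.
    outbox = [
        [(src, pos, shards[e].expert(tok)) for src, pos, tok in inbox[e]]
        for e in range(n)
    ]
    # Combine: return each result to the shard and position it came from.
    out = [[0] * len(x) for x in xs]
    for results in outbox:
        for src, pos, val in results:
            out[src][pos] = val
    return out


def toy_forward_moe_sharded(shards, xs):
    if len(shards) != len(xs):
        raise ValueError("expected one input per shard")
    # Assumption: take the EP path only when every shard is EP-enabled.
    if all(s.ep_enabled for s in shards):
        return toy_forward_ep(shards, xs)
    return toy_forward_sharded(shards, xs)


ep = [ToyShard(lambda t: t * 10, True), ToyShard(lambda t: t + 1, True)]
plain = [ToyShard(lambda t: t * 10), ToyShard(lambda t: t + 1)]
ep_out = toy_forward_moe_sharded(ep, [[0, 1], [2, 3]])
plain_out = toy_forward_moe_sharded(plain, [[0, 1], [2, 3]])
```

In both cases the result has one output list per shard, matching the documented return shape. On the toy EP path a token may be computed by a different shard than the one that supplied it, whereas on the fallback path each shard only ever touches its own input.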