MoE

class max.nn.MoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=MoEGate, mlp_cls=MLP, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=DType.bfloat16, apply_router_weight_first=False, swiglu_limit=0.0, ep_batch_manager=None, quant_config=None, is_sharding=False)

Bases: Module, Shardable

Implementation of a Mixture of Experts (MoE) layer.

Parameters:

  • devices (list[DeviceRef]) – The list of devices to use for the MoE.
  • hidden_dim (int) – The dimension of the hidden state.
  • num_experts (int) – The number of experts.
  • num_experts_per_token (int) – The number of experts per token.
  • moe_dim (int) – The intermediate dimension of each expert.
  • gate_cls (Callable[..., MoEGate]) – The model-specific gate implementation. Defaults to MoEGate.
  • mlp_cls (Callable[..., MLP]) – The MLP class to use for experts. Defaults to MLP.
  • has_shared_experts (bool) – Whether to use shared experts. Defaults to False.
  • shared_experts_dim (int) – The dimension of the shared experts. Defaults to 0.
  • ep_size (int) – The expert parallelism size. Defaults to 1.
  • dtype (DType) – The data type of the MoE. Defaults to DType.bfloat16.
  • apply_router_weight_first (bool) – Whether to apply the router weight first. Defaults to False.
  • ep_batch_manager (EPBatchManager | None) – The expert parallel batch manager. Defaults to None.
  • quant_config (QuantConfig | None) – The scaled quantization configuration. Defaults to None.
  • is_sharding (bool) – Whether the constructor is being called during sharding. Defaults to False.
  • swiglu_limit (float) – Defaults to 0.0.
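
A minimal construction sketch follows. The import paths and dimension values below are assumptions for illustration; hidden_dim, num_experts, num_experts_per_token, and moe_dim are model-specific choices, not defaults of this API.

    from max.dtype import DType
    from max.graph import DeviceRef
    from max.nn import MoE

    # Hypothetical configuration: 8 experts, routing each token to 2 of them.
    moe = MoE(
        devices=[DeviceRef.GPU(0)],
        hidden_dim=4096,
        num_experts=8,
        num_experts_per_token=2,
        moe_dim=14336,
        dtype=DType.bfloat16,
    )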

down_proj

property down_proj: TensorValue

ep_batch_manager

property ep_batch_manager: EPBatchManager

Get the expert parallel batch manager.

experts

experts: LayerList

The list of experts.
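
For example (a sketch reusing the moe instance from the construction example above, and assuming LayerList supports standard list-style len() and indexing):

    assert len(moe.experts) == 8   # one expert module per configured expert
    first_expert = moe.experts[0]  # an MLP instance when mlp_cls is the default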

gate_up_proj

property gate_up_proj: TensorValue

shard()

shard(devices)

Create sharded views of this MoE module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoE instances, one for each device.

Return type:

list[Self]
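
A hypothetical usage sketch, reusing the moe layer from the construction example and assuming a suitable sharding strategy has already been assigned to the layer:

    gpu0, gpu1 = DeviceRef.GPU(0), DeviceRef.GPU(1)
    shards = moe.shard([gpu0, gpu1])  # one sharded MoE view per device
    assert len(shards) == 2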

shard_devices

shard_devices: list[DeviceRef] = []

The list of devices the MoE layer was sharded to.

shard_index

shard_index: int = 0

The index of the current shard (if the MoE layer was sharded).

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy for the module.
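
For example (a sketch; the property can be None if no strategy has been assigned):

    # Check whether this layer has a sharding strategy configured.
    if moe.sharding_strategy is None:
        print("MoE layer has no sharding strategy set.")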