
Python class

MoEGate


class max.nn.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=<class 'max.nn.linear.Linear'>)


Bases: Module

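Gate (router) module for Mixture of Experts (MoE) layers. For each token, it selects num_experts_per_token of the num_experts available experts.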

Parameters:

devices – List of devices to use for the MoEGate.
hidden_dim – The dimension of the hidden state.
num_experts – The total number of experts.
num_experts_per_token – The number of experts selected for each token.
dtype – The data type of the MoEGate.
is_sharding – Defaults to False.
linear_cls – Defaults to max.nn.linear.Linear.
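To make the roles of hidden_dim, num_experts, and num_experts_per_token concrete, here is a minimal pure-Python sketch of top-k gating for a single token. It is not the MAX implementation: it assumes the gate is one linear projection from hidden_dim to num_experts, followed by a softmax and a top-k selection. The helper names softmax and route_token are hypothetical, and MAX may order or normalize these steps differently.

```python
import math
from typing import List, Tuple


def softmax(xs: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_token(
    hidden: List[float],              # length hidden_dim
    gate_weight: List[List[float]],   # shape [num_experts][hidden_dim]
    num_experts_per_token: int,
) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: return (expert indices, routing weights) for one token."""
    # One logit per expert: dot product of the hidden state with that expert's gate row.
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in gate_weight]
    probs = softmax(logits)
    # Keep the num_experts_per_token most probable experts, best first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:num_experts_per_token]
    return top, [probs[i] for i in top]


# hidden_dim = 2, num_experts = 4, num_experts_per_token = 2
gate_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
experts, weights = route_token([0.5, 2.0], gate_weight, num_experts_per_token=2)
print(experts)  # [2, 1]
```

Some MoE designs also renormalize the selected weights so they sum to 1. The description above does not say whether MoEGate does this, so the sketch returns the raw softmax probabilities.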

shard()

shard(devices)


Create sharded views of this MoEGate module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoEGate instances, one for each device.

Return type:

Sequence[MoEGate]

sharding_strategy

property sharding_strategy: ShardingStrategy | None


Get the sharding strategy for the module, or None if no strategy has been set.
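The relationship between sharding_strategy and shard() can be sketched with a hypothetical stand-in class. ToyGate below is not part of MAX, and it uses plain strings instead of DeviceRef and ShardingStrategy. It makes two assumptions that the reference above does not confirm. First, shard() requires a strategy to be set beforehand. Second, the gate is replicated, so every device receives a view that shares the same full weight. The property returns None while no strategy is set, which matches the ShardingStrategy | None return type.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Sequence


@dataclass
class ToyGate:
    # Hypothetical stand-in for a gate module; not the MAX MoEGate class.
    hidden_dim: int
    num_experts: int
    device: str = "cpu:0"
    weight: List[List[float]] = field(default_factory=list)
    _strategy: Optional[str] = None

    @property
    def sharding_strategy(self) -> Optional[str]:
        # None means no sharding strategy has been chosen yet.
        return self._strategy

    @sharding_strategy.setter
    def sharding_strategy(self, strategy: str) -> None:
        self._strategy = strategy

    def shard(self, devices: Iterable[str]) -> Sequence["ToyGate"]:
        # Assumption: sharding is only allowed once a strategy is set.
        if self._strategy is None:
            raise ValueError("set sharding_strategy before calling shard()")
        if self._strategy != "replicate":
            raise NotImplementedError(self._strategy)
        # Replicated views: one instance per device, all sharing the same weight object.
        return [
            ToyGate(self.hidden_dim, self.num_experts, device=d,
                    weight=self.weight, _strategy=self._strategy)
            for d in devices
        ]


gate = ToyGate(hidden_dim=2, num_experts=4, weight=[[0.0, 0.0] for _ in range(4)])
print(gate.sharding_strategy)  # None
gate.sharding_strategy = "replicate"
shards = gate.shard(["gpu:0", "gpu:1"])
print([s.device for s in shards])  # ['gpu:0', 'gpu:1']
```

Replication is a common choice for a gate because its weight (num_experts × hidden_dim) is small. Whether MoEGate replicates or partitions it depends on the strategy actually assigned.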