IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

MoEGate


class max.nn.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=<class 'max.nn.linear.Linear'>)

Bases: Module

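Gate (router) module for mixture-of-experts (MoE) layers. For each token, it selects num_experts_per_token of the num_experts experts.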

Parameters:

devices – List of devices to use for the MoEGate.
hidden_dim – The dimension of the hidden state.
num_experts – The number of experts.
num_experts_per_token – The number of experts selected per token.
dtype – The data type of the MoEGate.
is_sharding – Defaults to False.
linear_cls – Defaults to max.nn.linear.Linear.
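As a rough mental model of what a gate with these parameters computes, the sketch below implements a generic softmax-then-top-k router in plain Python. It is an assumption, not MAX's implementation: `top_k_gate` and `gate_weight` are hypothetical names, one token is routed at a time, and the selected weights are not renormalized (some MoE models do renormalize them).

```python
import math
from typing import List, Tuple


def top_k_gate(
    hidden: List[float],
    gate_weight: List[List[float]],
    num_experts_per_token: int,
) -> Tuple[List[int], List[float]]:
    """Hypothetical router: score every expert for one token, keep the top-k.

    hidden: one token's hidden state, length hidden_dim.
    gate_weight: num_experts rows, each of length hidden_dim.
    """
    # Router logits: one score per expert (a linear projection, no bias).
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in gate_weight]
    # Numerically stable softmax over all experts.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the num_experts_per_token highest-probability experts.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = order[:num_experts_per_token]
    return top, [probs[i] for i in top]


# hidden_dim = 2, num_experts = 4, num_experts_per_token = 2
indices, weights = top_k_gate(
    hidden=[1.0, 0.0],
    gate_weight=[[0.1, 0.0], [2.0, 0.0], [0.5, 0.0], [1.5, 0.0]],
    num_experts_per_token=2,
)
print(indices)  # [1, 3]
```

Each token is then dispatched only to the experts in `indices`, and their outputs are combined using `weights`.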

shard()

shard(devices)

Create sharded views of this MoEGate module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoEGate instances, one for each device.

Return type:

Sequence[MoEGate]
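The shard contract above (one instance per device) can be illustrated with a hypothetical `ToyGate`, which is not the MAX class. It assumes a replicate-style strategy, in which every device receives the full router weight, and assumes that sharding without a strategy raises `ValueError`. The strategy is a plain string here rather than a `ShardingStrategy` object.

```python
from dataclasses import dataclass, replace
from typing import Iterable, Optional, Sequence


@dataclass(frozen=True)
class ToyGate:
    """Hypothetical stand-in for a gate module; not the MAX MoEGate."""

    device: str
    weight: tuple  # num_experts rows of length hidden_dim, as nested tuples
    sharding_strategy: Optional[str] = None

    def shard(self, devices: Iterable[str]) -> Sequence["ToyGate"]:
        # Assumed behavior: refuse to shard until a strategy is chosen.
        if self.sharding_strategy is None:
            raise ValueError("no sharding strategy set")
        if self.sharding_strategy != "replicate":
            raise NotImplementedError(self.sharding_strategy)
        # Replicated gate: one copy per device, each holding the full weight.
        return [replace(self, device=d) for d in devices]


gate = ToyGate(device="cpu", weight=((1.0, 0.0), (0.0, 1.0)), sharding_strategy="replicate")
shards = gate.shard(["gpu:0", "gpu:1"])
print([s.device for s in shards])  # ['gpu:0', 'gpu:1']
```

Replication is a natural fit for a router because its weight is small (num_experts × hidden_dim) and every device needs the full set of expert scores to make routing decisions.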

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy for the module, or None if no strategy has been set.
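The `ShardingStrategy | None` return type means callers should expect `None` until a strategy is assigned. The hypothetical `ToyModule` below shows that pattern with a standard Python property; the setter and the string-valued strategy are assumptions for illustration, since this page documents only the getter.

```python
from typing import Optional


class ToyModule:
    """Hypothetical module illustrating an Optional sharding_strategy property."""

    def __init__(self) -> None:
        # No strategy until the caller assigns one.
        self._sharding_strategy: Optional[str] = None

    @property
    def sharding_strategy(self) -> Optional[str]:
        """Return the current strategy, or None if unset."""
        return self._sharding_strategy

    @sharding_strategy.setter
    def sharding_strategy(self, strategy: str) -> None:
        # Assumed setter; it simply stores the new strategy.
        self._sharding_strategy = strategy


m = ToyModule()
print(m.sharding_strategy)  # None
m.sharding_strategy = "replicate"
print(m.sharding_strategy)  # replicate
```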