MoEGate
class max.nn.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=<class 'max.nn.linear.Linear'>)
Bases: Module
Gate (router) module for mixture-of-experts (MoE) layers, which selects the experts that process each token.
Args:
    devices: List of devices to use for the MoEGate.
    hidden_dim: The dimension of the hidden state.
    num_experts: The total number of experts.
    num_experts_per_token: The number of experts each token is routed to.
    dtype: The data type of the MoEGate.
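The role of the gate can be illustrated with a framework-free sketch of common top-k routing for a single token. This is not the MAX implementation: the weight layout (one row of `hidden_dim` values per expert), the softmax-then-top-k order, and the renormalization of the selected weights are assumptions, and `moe_gate` is a hypothetical helper. It only shows how `hidden_dim`, `num_experts`, and `num_experts_per_token` relate.

```python
import math


def moe_gate(hidden, gate_weight, num_experts_per_token):
    """Route one token; return (selected expert indices, their weights).

    hidden: list of hidden_dim floats for a single token.
    gate_weight: num_experts rows of hidden_dim floats (hypothetical layout).
    """
    # Router logits: one score per expert from a bias-free linear projection.
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in gate_weight]
    # Softmax over all experts, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the num_experts_per_token highest-probability experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = top[:num_experts_per_token]
    # Assumed variant: renormalize the selected weights so they sum to 1.
    top_sum = sum(probs[i] for i in top)
    return top, [probs[i] / top_sum for i in top]


# Four experts, hidden_dim = 2, two experts per token.
experts, weights = moe_gate(
    [1.0, 0.5],
    [[0.2, 0.1], [1.0, 0.0], [0.0, 3.0], [-1.0, 0.3]],
    num_experts_per_token=2,
)
# experts == [2, 1]
```

Some MoE models skip the renormalization step or apply top-k before the softmax; the sketch picks one variant only for clarity.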
shard()
shard(devices)
Create sharded views of this MoEGate module across multiple devices.
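What "sharding" a gate across devices can mean is sketched below in plain Python. This page does not say which strategies MAX supports for MoEGate or what `shard()` returns, so both layouts here are assumptions, and `replicate`, `shard_by_expert`, and `logits` are hypothetical helpers. List copies stand in for per-device views. The sketch contrasts replicating the full gate weight on every device with splitting its expert rows across devices, and checks that the split version reproduces the unsharded router scores.

```python
def logits(weight, hidden):
    # One router score per expert row: dot(row, hidden).
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def replicate(weight, num_devices):
    # Every device holds a full copy of the gate weight and scores all experts.
    return [[list(row) for row in weight] for _ in range(num_devices)]


def shard_by_expert(weight, num_devices):
    # Split expert rows into contiguous blocks; earlier devices take any remainder.
    base, extra = divmod(len(weight), num_devices)
    shards, start = [], 0
    for d in range(num_devices):
        size = base + (1 if d < extra else 0)
        shards.append(weight[start:start + size])
        start += size
    return shards


weight = [[0.2, 0.1], [1.0, 0.0], [0.0, 3.0], [-1.0, 0.3]]
hidden = [1.0, 0.5]
full = logits(weight, hidden)
# Concatenating per-device scores reproduces the unsharded result.
sharded = [s for shard in shard_by_expert(weight, 2) for s in logits(shard, hidden)]
assert sharded == full
```

Replication keeps routing local to each device at the cost of duplicated weights, while splitting by expert divides the weight but requires gathering scores before the top-k selection.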
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy for the module.