Python module
moe
Mixture of Experts (MoE) module.
Fp8Strategy
class max.nn.legacy.moe.Fp8Strategy(config, dtype)
FP8 quantization strategy for MoE.
-
Parameters:
-
- config (Float8Config)
- dtype (DType)
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies fused SiLU gate and returns quantized activations.
-
Parameters:
-
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped FP8 matmul for the routed experts.
-
Parameters:
-
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Passes FP8 weight scales through without reformatting.
-
Parameters:
-
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
-
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations to FP8 and returns (quantized, scales).
-
Parameters:
-
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
-
Return type:
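For intuition about the (quantized, scales) contract of quantize(), here is a minimal NumPy sketch of per-group scaling into the FP8 E4M3 range. It is a conceptual illustration only, not the graph op this method emits, and the scale layout shown is an assumption.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3

def quantize_groups(tensor: np.ndarray, group_size: int):
    """Return (quantized, scales): one scale per contiguous group of values."""
    rows, cols = tensor.shape
    groups = tensor.reshape(rows, cols // group_size, group_size)
    # Pick each group's scale so its largest magnitude maps to FP8_E4M3_MAX.
    scales = np.abs(groups).max(axis=-1, keepdims=True) / FP8_E4M3_MAX
    scales = np.maximum(scales, 1e-12)  # guard against all-zero groups
    quantized = (groups / scales).reshape(rows, cols)  # would be cast to FP8 here
    return quantized, scales.squeeze(-1)  # scales: [rows, cols // group_size]

activations = np.random.randn(4, 128).astype(np.float32)
q, s = quantize_groups(activations, group_size=32)
```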
GateUpFormat
class max.nn.legacy.moe.GateUpFormat(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies the format of the combined gate/up projection weights.
CONCATENATED
CONCATENATED = 'concatenated'
Gate and up projections concatenated as [gate | up].
Stored as [num_experts, hidden_dim, 2 * moe_dim].
Split at moe_dim: gate = output[:, :moe_dim],
up = output[:, moe_dim:]. Used by Llama4 and Qwen3VL.
INTERLEAVED
INTERLEAVED = 'interleaved'
Gate and up projections interleaved as [g0, u0, g1, u1, ...].
Stored as [num_experts, hidden_dim, 2 * moe_dim].
Split with stride: gate = output[:, 0::2],
up = output[:, 1::2]. Used by GptOss.
MoE
class max.nn.legacy.moe.MoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=<class 'max.nn.legacy.moe.moe.MoEGate'>, mlp_cls=<class 'max.nn.legacy.linear.MLP'>, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)
Implementation of Mixture of Experts (MoE).
-
Parameters:
-
- devices (list[DeviceRef])
- hidden_dim (int)
- num_experts (int)
- num_experts_per_token (int)
- moe_dim (int)
- gate_cls (Callable[..., MoEGate])
- mlp_cls (Callable[..., MLP])
- has_shared_experts (bool)
- shared_experts_dim (int)
- ep_size (int)
- dtype (DType)
- apply_router_weight_first (bool)
- ep_batch_manager (EPBatchManager | None)
- float8_config (Float8Config | None)
- is_sharding (bool)
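A minimal construction sketch based on the signature above; the device, dtype, and dimensions are illustrative placeholders, and gate_cls and mlp_cls keep their defaults:

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.moe import MoE

# Placeholder device and dimensions; adjust for your model.
moe = MoE(
    devices=[DeviceRef.GPU()],
    hidden_dim=4096,
    num_experts=64,
    num_experts_per_token=8,
    moe_dim=2048,
    dtype=DType.bfloat16,
)
```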
down_proj
property down_proj: TensorValue
ep_batch_manager
property ep_batch_manager: EPBatchManager
Get the expert parallel batch manager.
experts
experts: LayerList
The list of experts.
gate_up_proj
property gate_up_proj: TensorValue
shard()
shard(devices)
Create sharded views of this MoE module across multiple devices.
shard_devices
The list of devices the MoE layer was sharded to.
shard_index
shard_index: int = 0
The index of the current shard (if the MoE layer was sharded).
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy for the module.
MoEGate
class max.nn.legacy.moe.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=<class 'max.nn.legacy.linear.Linear'>)
Gate module for MoE.
-
Parameters:
-
- devices (list[DeviceRef])
- hidden_dim (int)
- num_experts (int)
- num_experts_per_token (int)
- dtype (DType)
- is_sharding (bool)
- linear_cls (Callable[..., Linear])
shard()
shard(devices)
Create sharded views of this MoEGate module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy for the module.
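For intuition, here is a NumPy sketch of the top-k routing a gate like this typically performs: score each token against every expert, keep the num_experts_per_token best, and renormalize their weights. This illustrates the concept only and is not the MoEGate implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

num_tokens, hidden_dim, num_experts, top_k = 6, 16, 8, 2
rng = np.random.default_rng(0)
hidden = rng.standard_normal((num_tokens, hidden_dim)).astype(np.float32)
router_weight = rng.standard_normal((hidden_dim, num_experts)).astype(np.float32)

# One routing probability per (token, expert).
probs = softmax(hidden @ router_weight)

# Keep the top-k experts per token and renormalize their weights.
topk_idx = np.argsort(-probs, axis=-1)[:, :top_k]      # [num_tokens, top_k]
topk_w = np.take_along_axis(probs, topk_idx, axis=-1)  # [num_tokens, top_k]
topk_w = topk_w / topk_w.sum(axis=-1, keepdims=True)
```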
MoEQuantized
class max.nn.legacy.moe.MoEQuantized(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=<class 'max.nn.legacy.moe.moe.MoEGate'>, mlp_cls=<class 'max.nn.legacy.linear.MLP'>, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)
Mixture of Experts with FP8 or NVFP4 quantization.
-
Parameters:
-
- devices (list[DeviceRef])
- hidden_dim (int)
- num_experts (int)
- num_experts_per_token (int)
- moe_dim (int)
- gate_cls (Callable[..., MoEGate])
- mlp_cls (Callable[..., MLP])
- has_shared_experts (bool)
- shared_experts_dim (int)
- ep_size (int)
- dtype (DType)
- apply_router_weight_first (bool)
- ep_batch_manager (EPBatchManager | None)
- float8_config (Float8Config | None)
- is_sharding (bool)
down_proj_scales
property down_proj_scales: TensorValue
Returns stacked down-projection weight scales.
gate_up_proj_scales
property gate_up_proj_scales: TensorValue
Returns stacked gate/up weight scales for grouped matmul.
Nvfp4Scales
class max.nn.legacy.moe.Nvfp4Scales(gate_up_input, down_input, gate_up_expert, down_expert)
Bundled scales for NVFP4 quantization.
-
Parameters:
-
- gate_up_input (TensorValue)
- down_input (TensorValue)
- gate_up_expert (TensorValue)
- down_expert (TensorValue)
down_expert
down_expert: TensorValue
down_input
down_input: TensorValue
gate_up_expert
gate_up_expert: TensorValue
gate_up_input
gate_up_input: TensorValue
Nvfp4Strategy
class max.nn.legacy.moe.Nvfp4Strategy(config, dtype)
NVFP4 quantization strategy for MoE.
-
Parameters:
-
- config (Float8Config)
- dtype (DType)
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies SiLU gate then NVFP4 quantizes the result.
-
Parameters:
-
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped NVFP4 matmul with per-expert scales.
-
Parameters:
-
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Interleaves NVFP4 block scales for kernel layout.
-
Parameters:
-
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
-
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations to NVFP4 and returns (quantized, scales).
-
Parameters:
-
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
-
Return type:
QuantStrategy
class max.nn.legacy.moe.QuantStrategy(*args, **kwargs)
Quantization strategy for MoE layers.
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies gating and quantizes activations for the down proj.
-
Parameters:
-
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped matmul for routed experts.
-
Parameters:
-
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
-
Return type:
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Prepares weight scales for kernel consumption.
-
Parameters:
-
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
-
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations and returns (quantized, scales).
-
Parameters:
-
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
-
Return type:
StackedMoE
class max.nn.legacy.moe.StackedMoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls, dtype=bfloat16, gate_up_format=GateUpFormat.CONCATENATED, activation_fn=<function silu_activation>, has_bias=False, has_shared_experts=False, shared_experts_dim=0, float8_config=None, apply_router_weight_first=False, is_sharding=False)
Stacked Mixture of Experts layer with configurable components.
This class consolidates MoE implementations from multiple architectures (Llama4, Qwen3VL, GptOss) into a single base layer. All expert weights are stored in stacked format rather than as individual MLP experts.
Weight tensor shapes:
- gate_up_proj: [num_experts, hidden_dim, 2 * moe_dim]
- down_proj: [num_experts, moe_dim, hidden_dim]
- Optional FP8 scales: [num_experts, scaled_rows, scaled_cols]
Supported configurations:
- Gate/up formats: concatenated or interleaved.
- Activation functions: configurable (default: SiLU).
- Optional bias support for projections.
- Optional FP8 quantization with block scaling.
- Optional shared experts.
For example:
```python
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=MyGate,
)

# With FP8 quantization
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=MyGate,
    float8_config=float8_config,
)

# GptOss style with interleaved format and bias
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=GptOssMoEGate,
    gate_up_format=GateUpFormat.INTERLEAVED,
    activation_fn=my_custom_activation,
    has_bias=True,
)
```
-
Parameters:
-
- devices (list[DeviceRef]) – A list of devices to use for the MoE.
- hidden_dim (int) – The dimension of the hidden state.
- num_experts (int) – The total number of experts.
- num_experts_per_token (int) – The number of experts per token (top-k).
- moe_dim (int) – The intermediate dimension of each expert.
- gate_cls (Callable[..., MoEGate]) – The model-specific gate implementation class.
- dtype (DType) – The data type of the MoE weights. Defaults to DType.bfloat16.
- gate_up_format (GateUpFormat) – The format of the combined gate/up weights. Defaults to GateUpFormat.CONCATENATED.
- activation_fn (Callable[[TensorValue, TensorValue], TensorValue]) – The activation function taking (gate, up) and returning the activated output. Defaults to silu_activation().
- has_bias (bool) – Whether to include bias for projections. Defaults to False.
- has_shared_experts (bool) – Whether to use shared experts. Defaults to False.
- shared_experts_dim (int) – The dimension of the shared experts. Defaults to 0.
- float8_config (Float8Config | None) – The configuration for FP8 quantization. Defaults to None.
- apply_router_weight_first (bool) – Whether to apply router weights before expert computation. Defaults to False.
- is_sharding (bool) – Whether this instance is being created for sharding. Set by shard() to skip weight initialization for sharded instances. Defaults to False.
down_proj_transposed
property down_proj_transposed: TensorValue
The down weights transposed to [num_experts, out_features, in_features] layout.
down_scale_transposed
property down_scale_transposed: TensorValue
The down scales transposed for FP8 matmul.
gate_up_proj_transposed
property gate_up_proj_transposed: TensorValue
The gate/up weights transposed to [num_experts, out_features, in_features] layout.
gate_up_scale_transposed
property gate_up_scale_transposed: TensorValue
The gate/up scales transposed for FP8 matmul.
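As a NumPy illustration of the transposed layout (not the internal implementation), swapping the last two axes turns the stacked [num_experts, in_features, out_features] weights into [num_experts, out_features, in_features]:

```python
import numpy as np

num_experts, hidden_dim, moe_dim = 4, 16, 32
# Stacked gate/up weights as stored: [num_experts, hidden_dim, 2 * moe_dim].
gate_up = np.random.randn(num_experts, hidden_dim, 2 * moe_dim).astype(np.float32)

# Swap the last two axes to get [num_experts, out_features, in_features].
gate_up_t = np.transpose(gate_up, (0, 2, 1))
assert gate_up_t.shape == (num_experts, 2 * moe_dim, hidden_dim)
```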
shard()
shard(devices)
Creates sharded views of this MoE module across multiple devices.
-
Parameters:
-
devices (Iterable[DeviceRef]) – The devices to place the shards on.
-
Returns:
-
A sequence of sharded instances, one for each device.
-
Raises:
-
ValueError – If no sharding strategy has been set.
-
Return type:
sharding_strategy
property sharding_strategy: ShardingStrategy | None
The sharding strategy for this module.
silu_activation()
max.nn.legacy.moe.silu_activation(gate, up)
Computes a SiLU-gated activation as up * silu(gate).
This is the default activation used by most MoE implementations including Llama4 and Qwen3VL.
-
Parameters:
-
- gate (TensorValue) – The gate projection tensor.
- up (TensorValue) – The up projection tensor.
-
Returns:
-
The element-wise product of up and silu(gate).
-
Return type:
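A minimal NumPy sketch of the computation (illustrative only):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # equivalent to x * sigmoid(x)

gate = np.random.randn(2, 8).astype(np.float32)
up = np.random.randn(2, 8).astype(np.float32)

activated = up * silu(gate)  # what silu_activation(gate, up) computes
```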
silu_gate()
max.nn.legacy.moe.silu_gate(gate_up_projs, moe_dim)
Applies SiLU-gated activation: silu(gate) * up.
-
Parameters:
-
- gate_up_projs (TensorValue)
- moe_dim (int)
-
Return type: