Python module max.nn.legacy.moe
Mixture of Experts (MoE) module.
Fp8Strategy
class max.nn.legacy.moe.Fp8Strategy(config, dtype)
FP8 quantization strategy for MoE layers.
Parameters:
- config (Float8Config)
- dtype (DType)
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies fused SiLU gate and returns quantized activations.
Parameters:
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped FP8 matmul for the routed experts.
Parameters:
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
Return type:
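The grouped matmul above can be pictured as follows. This is an illustrative sketch in plain NumPy, not the MAX API: tokens are assumed to be laid out contiguously per expert, and each expert's slice is multiplied by that expert's weight matrix. The helper name `grouped_matmul_ref` and the shapes are assumptions for illustration.

```python
import numpy as np

# Conceptual reference for a grouped matmul over routed experts.
# tokens:       [T, H]  tokens sorted so each expert's group is contiguous
# weights:      [E, H, D]  one weight matrix per expert
# group_sizes:  group_sizes[e] = tokens routed to expert e (sums to T)
def grouped_matmul_ref(tokens, weights, group_sizes):
    outputs = []
    start = 0
    for e, size in enumerate(group_sizes):
        group = tokens[start:start + size]   # tokens routed to expert e
        outputs.append(group @ weights[e])   # [size, D]
        start += size
    return np.concatenate(outputs, axis=0)   # [T, D]

tokens = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 tokens, H=3
weights = np.stack([np.eye(3, dtype=np.float32),
                    2 * np.eye(3, dtype=np.float32)])    # E=2 experts
out = grouped_matmul_ref(tokens, weights, [1, 3])
```

The quantized variants additionally fold in the weight and expert scales, but the grouping structure is the same.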
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Passes FP8 weight scales through without reformatting.
Parameters:
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations to FP8 and returns (quantized, scales).
Parameters:
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
Return type:
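Group-wise FP8 quantization can be sketched in plain NumPy (this is a conceptual reference, not the MAX kernel): each group of `group_size` contiguous values shares one scale, chosen so the group's maximum magnitude maps to the FP8 E4M3 maximum of 448.0. The function name `quantize_ref` is a placeholder for illustration.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in E4M3

def quantize_ref(tensor, group_size):
    # Split into groups of `group_size` and compute one scale per group.
    groups = tensor.reshape(-1, group_size)
    amax = np.abs(groups).max(axis=1, keepdims=True)
    scales = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    # Scale so each group fits the FP8 dynamic range (rounding omitted).
    quantized = np.clip(groups / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return quantized.reshape(tensor.shape), scales.squeeze(1)

x = np.array([1.0, -2.0, 0.5, 4.0], dtype=np.float32)
q, s = quantize_ref(x, group_size=2)
# multiplying each quantized group by its scale recovers the input
```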
MoE
class max.nn.legacy.moe.MoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=MoEGate, mlp_cls=MLP, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)
Implementation of Mixture of Experts (MoE).
Parameters:
- devices (list[DeviceRef])
- hidden_dim (int)
- num_experts (int)
- num_experts_per_token (int)
- moe_dim (int)
- gate_cls (Callable[..., MoEGate])
- mlp_cls (Callable[..., MLP])
- has_shared_experts (bool)
- shared_experts_dim (int)
- ep_size (int)
- dtype (DType)
- apply_router_weight_first (bool)
- ep_batch_manager (EPBatchManager | None)
- float8_config (Float8Config | None)
- is_sharding (bool)
down_proj
property down_proj: TensorValue
ep_batch_manager
property ep_batch_manager: EPBatchManager
Get the expert parallel batch manager.
experts
experts: LayerList
The list of experts.
gate_up_proj
property gate_up_proj: TensorValue
shard()
shard(devices)
Create sharded views of this MoE module across multiple devices.
shard_devices
The list of devices the MoE layer was sharded to.
shard_index
shard_index: int = 0
The index of the current shard (if the MoE layer was sharded).
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy for the module.
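The overall MoE computation can be sketched in plain NumPy (a conceptual reference, not the MAX API): route each token to its top `num_experts_per_token` experts, run the routed experts, and combine their outputs weighted by the normalized router probabilities. Each "expert" here is a single linear layer for brevity, standing in for the MLP experts above.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(tokens, router_w, expert_w, top_k):
    logits = tokens @ router_w                    # [T, num_experts]
    probs = softmax(logits)
    top = np.argsort(-probs, axis=1)[:, :top_k]   # chosen experts per token
    out = np.zeros_like(tokens)
    for t in range(tokens.shape[0]):
        weight_sum = probs[t, top[t]].sum()       # renormalize over top-k
        for e in top[t]:
            out[t] += (probs[t, e] / weight_sum) * (tokens[t] @ expert_w[e])
    return out

tokens = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
router_w = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
expert_w = np.stack([np.eye(2, dtype=np.float32)] * 3)  # 3 identity experts
out = moe_forward(tokens, router_w, expert_w, top_k=2)
```

With identity experts the combined output equals the input, since the routing weights are normalized to sum to one per token.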
MoEGate
class max.nn.legacy.moe.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=Linear)
Gate module for MoE.
Parameters:
shard()
shard(devices)
Create sharded views of this MoEGate module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy for the module.
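What the gate computes can be sketched in plain NumPy (again a conceptual reference, not the MAX API): a linear projection of the hidden states followed by top-k selection, yielding per-token expert indices and routing weights. The helper name `gate_ref` is a placeholder.

```python
import numpy as np

def gate_ref(hidden, gate_w, top_k):
    logits = hidden @ gate_w                       # [T, num_experts]
    idx = np.argsort(-logits, axis=1)[:, :top_k]   # top-k expert ids
    picked = np.take_along_axis(logits, idx, axis=1)
    # Softmax over only the selected logits gives normalized routing weights.
    e = np.exp(picked - picked.max(axis=1, keepdims=True))
    weights = e / e.sum(axis=1, keepdims=True)
    return idx, weights

hidden = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
gate_w = np.array([[2.0, 0.0, -1.0], [0.0, 2.0, -1.0]], dtype=np.float32)
idx, w = gate_ref(hidden, gate_w, top_k=2)
```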
MoEQuantized
class max.nn.legacy.moe.MoEQuantized(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=MoEGate, mlp_cls=MLP, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)
Mixture of Experts with FP8 or NVFP4 quantization.
Parameters:
- devices (list[DeviceRef])
- hidden_dim (int)
- num_experts (int)
- num_experts_per_token (int)
- moe_dim (int)
- gate_cls (Callable[..., MoEGate])
- mlp_cls (Callable[..., MLP])
- has_shared_experts (bool)
- shared_experts_dim (int)
- ep_size (int)
- dtype (DType)
- apply_router_weight_first (bool)
- ep_batch_manager (EPBatchManager | None)
- float8_config (Float8Config | None)
- is_sharding (bool)
down_proj_scales
property down_proj_scales: TensorValue
Returns stacked down-projection weight scales.
gate_up_proj_scales
property gate_up_proj_scales: TensorValue
Returns stacked gate/up weight scales for grouped matmul.
Nvfp4Scales
class max.nn.legacy.moe.Nvfp4Scales(gate_up_input, down_input, gate_up_expert, down_expert)
Bundled scales for NVFP4 quantization.
Parameters:
- gate_up_input (TensorValue)
- down_input (TensorValue)
- gate_up_expert (TensorValue)
- down_expert (TensorValue)
down_expert
down_expert: TensorValue
down_input
down_input: TensorValue
gate_up_expert
gate_up_expert: TensorValue
gate_up_input
gate_up_input: TensorValue
Nvfp4Strategy
class max.nn.legacy.moe.Nvfp4Strategy(config, dtype)
NVFP4 quantization for MoE.
Parameters:
- config (Float8Config)
- dtype (DType)
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies SiLU gate then NVFP4 quantizes the result.
Parameters:
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped NVFP4 matmul with per-expert scales.
Parameters:
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
Return type:
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Interleaves NVFP4 block scales for kernel layout.
Parameters:
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations to NVFP4 and returns (quantized, scales).
Parameters:
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
Return type:
QuantStrategy
class max.nn.legacy.moe.QuantStrategy(*args, **kwargs)
Quantization strategy for MoE layers.
fused_silu_quantize()
fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())
Applies gating and quantizes activations for the down proj.
Parameters:
- gate_up_projs (TensorValue)
- input_scales (TensorValue | None)
- expert_inputs (tuple[TensorValue, ...])
Return type:
grouped_matmul()
grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())
Runs grouped matmul for routed experts.
Parameters:
- weight (TensorValue)
- weight_scales (TensorValue)
- expert_scales (TensorValue | None)
- tokens_padded_per_expert (bool)
- expert_inputs (tuple[TensorValue, ...])
Return type:
prepare_weight_scales()
prepare_weight_scales(gate_up, down, device)
Prepares weight scales for kernel consumption.
Parameters:
- gate_up (TensorValue)
- down (TensorValue)
- device (DeviceRef)
Return type:
quantize()
quantize(tensor, group_size, input_scale=None)
Quantizes activations and returns (quantized, scales).
Parameters:
- tensor (TensorValue)
- group_size (int)
- input_scale (TensorValue | None)
Return type:
silu_gate()
max.nn.legacy.moe.silu_gate(gate_up_projs, moe_dim)
Applies SiLU-gated activation: silu(gate) * up.
Parameters:
- gate_up_projs (TensorValue)
- moe_dim (int)
Return type:
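The `silu(gate) * up` computation can be written out in plain NumPy (a reference sketch, not the MAX kernel). It assumes the fused gate/up tensor stores the gate values in the first `moe_dim` columns and the up-projection values in the rest; that split point is implied by the `moe_dim` parameter, though the exact layout is an assumption here.

```python
import numpy as np

def silu_gate_ref(gate_up_projs, moe_dim):
    gate = gate_up_projs[..., :moe_dim]   # first moe_dim columns: gate
    up = gate_up_projs[..., moe_dim:]     # remaining columns: up proj
    silu = gate / (1.0 + np.exp(-gate))   # silu(x) = x * sigmoid(x)
    return silu * up

x = np.array([[0.0, 1.0, 3.0, 5.0]], dtype=np.float32)  # moe_dim = 2
y = silu_gate_ref(x, moe_dim=2)
```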