Python module

moe

Mixture of Experts (MoE) module.

Fp8Strategy

class max.nn.legacy.moe.Fp8Strategy(config, dtype)

FP8 quantization for MoE.

Parameters:

fused_silu_quantize()

fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())

Applies fused SiLU gate and returns quantized activations.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

grouped_matmul()

grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())

Runs grouped FP8 matmul for the routed experts.

Parameters:

Return type:

TensorValue

prepare_weight_scales()

prepare_weight_scales(gate_up, down, device)

Passes FP8 weight scales through without reformatting.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

quantize()

quantize(tensor, group_size, input_scale=None)

Quantizes activations to FP8 and returns (quantized, scales).

Parameters:

Return type:

tuple[TensorValue, TensorValue]
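
For example, a hedged sketch of how these methods compose inside an MoE forward pass. Here float8_config, device, and the weight/activation tensors are hypothetical placeholders, and threading the activations through expert_inputs is an assumption rather than documented behavior:

from max.dtype import DType
from max.nn.legacy.moe import Fp8Strategy

strategy = Fp8Strategy(float8_config, DType.bfloat16)

# Pass the per-expert weight scales through for kernel consumption.
gate_up_scales, down_scales = strategy.prepare_weight_scales(
    gate_up_weight_scales, down_weight_scales, device
)

# Quantize routed activations to FP8 (group_size of 128 is an assumed value).
x_q, x_scales = strategy.quantize(hidden_states, group_size=128)

# Grouped FP8 matmul for the gate/up projection, then the fused SiLU gate and
# re-quantization ahead of the down projection.
gate_up_out = strategy.grouped_matmul(
    gate_up_weight, gate_up_scales, expert_inputs=(x_q, x_scales)
)
h_q, h_scales = strategy.fused_silu_quantize(gate_up_out)
down_out = strategy.grouped_matmul(
    down_weight, down_scales, expert_inputs=(h_q, h_scales)
)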

GateUpFormat

class max.nn.legacy.moe.GateUpFormat(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the format of the combined gate/up projection weights.

CONCATENATED

CONCATENATED = 'concatenated'

Gate and up projections concatenated as [gate | up].

Stored as [num_experts, hidden_dim, 2 * moe_dim]. Split at moe_dim: gate = output[:, :moe_dim], up = output[:, moe_dim:]. Used by Llama4 and Qwen3VL.

INTERLEAVED

INTERLEAVED = 'interleaved'

Gate and up projections interleaved as [g0, u0, g1, u1, ...].

Stored as [num_experts, hidden_dim, 2 * moe_dim]. Split with stride: gate = output[:, 0::2], up = output[:, 1::2]. Used by GptOss.
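
An illustrative numpy sketch of the two slicing rules above, using a single 1-D row in place of the full [num_experts, hidden_dim, 2 * moe_dim] weight tensor:

import numpy as np

moe_dim = 4
output = np.arange(2 * moe_dim)  # stand-in for one row of a gate/up matmul output

# CONCATENATED: [gate | up], split at moe_dim.
gate_cat, up_cat = output[:moe_dim], output[moe_dim:]

# INTERLEAVED: [g0, u0, g1, u1, ...], split with stride 2.
gate_int, up_int = output[0::2], output[1::2]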

MoE

class max.nn.legacy.moe.MoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=<class 'max.nn.legacy.moe.moe.MoEGate'>, mlp_cls=<class 'max.nn.legacy.linear.MLP'>, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)

Implementation of Mixture of Experts (MoE).

Parameters:

down_proj

property down_proj: TensorValue

ep_batch_manager

property ep_batch_manager: EPBatchManager

Get the expert parallel batch manager.

experts

experts: LayerList

The list of experts.

gate_up_proj

property gate_up_proj: TensorValue

shard()

shard(devices)

Creates sharded views of this MoE module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoE instances, one for each device.

Return type:

list[Self]
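
For example, a hedged construction-and-sharding sketch; devices is a hypothetical list of DeviceRef values, and a sharding strategy is assumed to have been configured before calling shard():

from max.nn.legacy.moe import MoE

moe = MoE(
    devices=devices,
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
)
shards = moe.shard(devices)  # one sharded MoE instance per device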

shard_devices

shard_devices: list[DeviceRef] = []

The list of devices the MoE layer was sharded to.

shard_index

shard_index: int = 0

The index of the current shard (if the MoE layer was sharded).

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy for the module.

MoEGate

class max.nn.legacy.moe.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype, is_sharding=False, linear_cls=<class 'max.nn.legacy.linear.Linear'>)

Gate module for MoE.

Parameters:

shard()

shard(devices)

Creates sharded views of this MoEGate module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoEGate instances, one for each device.

Return type:

Sequence[MoEGate]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy for the module.

MoEQuantized

class max.nn.legacy.moe.MoEQuantized(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=<class 'max.nn.legacy.moe.moe.MoEGate'>, mlp_cls=<class 'max.nn.legacy.linear.MLP'>, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=bfloat16, apply_router_weight_first=False, ep_batch_manager=None, float8_config=None, is_sharding=False)

Mixture of Experts with FP8 or NVFP4 quantization.

Parameters:

down_proj_scales

property down_proj_scales: TensorValue

Returns stacked down-projection weight scales.

gate_up_proj_scales

property gate_up_proj_scales: TensorValue

Returns stacked gate/up weight scales for grouped matmul.

Nvfp4Scales

class max.nn.legacy.moe.Nvfp4Scales(gate_up_input, down_input, gate_up_expert, down_expert)

Bundled scales for NVFP4 quantization.

Parameters:

down_expert

down_expert: TensorValue

down_input

down_input: TensorValue

gate_up_expert

gate_up_expert: TensorValue

gate_up_input

gate_up_input: TensorValue
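
A minimal sketch of bundling the four scale tensors; the *_scales names below are hypothetical TensorValues produced elsewhere:

from max.nn.legacy.moe import Nvfp4Scales

scales = Nvfp4Scales(
    gate_up_input=gate_up_input_scales,
    down_input=down_input_scales,
    gate_up_expert=gate_up_expert_scales,
    down_expert=down_expert_scales,
)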

Nvfp4Strategy

class max.nn.legacy.moe.Nvfp4Strategy(config, dtype)

NVFP4 quantization for MoE.

Parameters:

fused_silu_quantize()

fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())

Applies SiLU gate then NVFP4 quantizes the result.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

grouped_matmul()

grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())

Runs grouped NVFP4 matmul with per-expert scales.

Parameters:

Return type:

TensorValue

prepare_weight_scales()

prepare_weight_scales(gate_up, down, device)

Interleaves NVFP4 block scales for kernel layout.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

quantize()

quantize(tensor, group_size, input_scale=None)

Quantizes activations to NVFP4 and returns (quantized, scales).

Parameters:

Return type:

tuple[TensorValue, TensorValue]

QuantStrategy

class max.nn.legacy.moe.QuantStrategy(*args, **kwargs)

Quantization strategy for MoE layers.

fused_silu_quantize()

fused_silu_quantize(gate_up_projs, input_scales=None, expert_inputs=())

Applies gating and quantizes activations for the down proj.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

grouped_matmul()

grouped_matmul(weight, weight_scales, expert_scales=None, tokens_padded_per_expert=False, expert_inputs=())

Runs grouped matmul for routed experts.

Parameters:

Return type:

TensorValue

prepare_weight_scales()

prepare_weight_scales(gate_up, down, device)

Prepares weight scales for kernel consumption.

Parameters:

Return type:

tuple[TensorValue, TensorValue]

quantize()

quantize(tensor, group_size, input_scale=None)

Quantizes activations and returns (quantized, scales).

Parameters:

Return type:

tuple[TensorValue, TensorValue]
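
Because Fp8Strategy and Nvfp4Strategy expose this same set of methods, callers can be written against the shared interface. A minimal sketch, assuming both concrete strategies can be treated as QuantStrategy instances and that float8_config and nvfp4_config are hypothetical configuration objects:

from max.dtype import DType
from max.nn.legacy.moe import Fp8Strategy, Nvfp4Strategy, QuantStrategy


def make_strategy(use_nvfp4: bool) -> QuantStrategy:
    # Pick the quantization format while keeping downstream MoE code agnostic.
    if use_nvfp4:
        return Nvfp4Strategy(nvfp4_config, DType.bfloat16)
    return Fp8Strategy(float8_config, DType.bfloat16)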

StackedMoE

class max.nn.legacy.moe.StackedMoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls, dtype=bfloat16, gate_up_format=GateUpFormat.CONCATENATED, activation_fn=<function silu_activation>, has_bias=False, has_shared_experts=False, shared_experts_dim=0, float8_config=None, apply_router_weight_first=False, is_sharding=False)

Stacked Mixture of Experts layer with configurable components.

This class consolidates MoE implementations from multiple architectures (Llama4, Qwen3VL, GptOss) into a single base layer. All expert weights are stored in stacked format rather than as individual MLP experts.

Weight tensor shapes:

  • gate_up_proj: [num_experts, hidden_dim, 2 * moe_dim]
  • down_proj: [num_experts, moe_dim, hidden_dim]
  • Optional FP8 scales: [num_experts, scaled_rows, scaled_cols]

Supported configurations:

  • Gate/up formats: concatenated or interleaved.
  • Activation functions: configurable (default: SiLU).
  • Optional bias support for projections.
  • Optional FP8 quantization with block scaling.
  • Optional shared experts.

For example:

from max.nn.legacy.moe import GateUpFormat, StackedMoE

# Basic bfloat16 MoE (device and MyGate are placeholders)
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=MyGate,
)

# With FP8 quantization
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=MyGate,
    float8_config=float8_config,
)

# GptOss style with interleaved format and bias
moe = StackedMoE(
    devices=[device],
    hidden_dim=4096,
    num_experts=8,
    num_experts_per_token=2,
    moe_dim=14336,
    gate_cls=GptOssMoEGate,
    gate_up_format=GateUpFormat.INTERLEAVED,
    activation_fn=my_custom_activation,
    has_bias=True,
)

Parameters:

  • devices (list[DeviceRef]) – A list of devices to use for the MoE.
  • hidden_dim (int) – The dimension of the hidden state.
  • num_experts (int) – The total number of experts.
  • num_experts_per_token (int) – The number of experts per token (top-k).
  • moe_dim (int) – The intermediate dimension of each expert.
  • gate_cls (Callable[..., MoEGate]) – The model-specific gate implementation class.
  • dtype (DType) – The data type of the MoE weights. Defaults to DType.bfloat16.
  • gate_up_format (GateUpFormat) – The format of the combined gate/up weights. Defaults to GateUpFormat.CONCATENATED.
  • activation_fn (Callable[[TensorValue, TensorValue], TensorValue]) – The activation function taking (gate, up) and returning the activated output. Defaults to silu_activation().
  • has_bias (bool) – Whether to include bias for projections. Defaults to False.
  • has_shared_experts (bool) – Whether to use shared experts. Defaults to False.
  • shared_experts_dim (int) – The dimension of the shared experts. Defaults to 0.
  • float8_config (Float8Config | None) – The configuration for FP8 quantization. Defaults to None.
  • apply_router_weight_first (bool) – Whether to apply router weights before expert computation. Defaults to False.
  • is_sharding (bool) – Whether this instance is being created for sharding. Set by shard() to skip weight initialization for sharded instances. Defaults to False.

down_proj_transposed

property down_proj_transposed: TensorValue

The down weights transposed to [num_experts, out_features, in_features] layout.

down_scale_transposed

property down_scale_transposed: TensorValue

The down scales transposed for FP8 matmul.

gate_up_proj_transposed

property gate_up_proj_transposed: TensorValue

The gate/up weights transposed to [num_experts, out_features, in_features] layout.

gate_up_scale_transposed

property gate_up_scale_transposed: TensorValue

The gate/up scales transposed for FP8 matmul.

shard()

shard(devices)

Creates sharded views of this MoE module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – The devices to place the shards on.

Returns:

A sequence of sharded instances, one for each device.

Raises:

ValueError – If no sharding strategy has been set.

Return type:

Sequence[Self]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

The sharding strategy for this module.

silu_activation()

max.nn.legacy.moe.silu_activation(gate, up)

Computes a SiLU gated activation as up * silu(gate).

This is the default activation used by most MoE implementations including Llama4 and Qwen3VL.

Parameters:

Returns:

The element-wise product of up and silu(gate).

Return type:

TensorValue
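
A runnable numpy reference for the formula above, using the standard definition silu(x) = x * sigmoid(x):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # x * sigmoid(x)

gate = np.array([0.5, -1.0, 2.0])
up = np.array([1.0, 2.0, 3.0])
result = up * silu(gate)  # element-wise product of up and silu(gate)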

silu_gate()

max.nn.legacy.moe.silu_gate(gate_up_projs, moe_dim)

Applies SiLU-gated activation: silu(gate) * up.

Parameters:

Return type:

TensorValue
