MLP

class max.nn.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', quant_config=None, is_sharding=False)

Bases: Module, Shardable

Simple multi-layer perceptron composed of three Linear layers.

When called, MLP accepts a TensorValueLike of shape (..., hidden_dim) and returns a TensorValue of the same shape (..., hidden_dim).

Initializes the MLP layer.

Parameters:

  • dtype (DType) – DType to use for the layer weights, which should match the input dtype.

  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.

  • hidden_dim (int) – The last dimension of the layer input.

  • feed_forward_length (int) – Size of the intermediate dimension that the inputs are projected to.

  • devices (Sequence[DeviceRef]) – Devices on which to run the MLP layer.

  • linear_cls (Callable[..., Linear]) – Linear class used to create the projection layers.

  • has_bias (bool) – Whether to include bias terms in the linear layers.

  • activation_function (str) –

    Activation function to use. Options are:

    • silu
    • gelu
    • gelu_tanh
    • relu
    • tanh
    • sigmoid
  • quant_config (QuantConfig | None) – QuantConfig for scaled quantization.

  • is_sharding (bool) – If True, disables child layer creation during sharding.
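As a rough illustration of the shape behavior described above — a `(..., hidden_dim)` input mapped through three linear layers back to `(..., hidden_dim)` — here is a plain-NumPy sketch. It assumes the common gated formulation `down(act(gate(x)) * up(x))` with a SiLU activation and no bias; the actual MAX implementation may differ in details:

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x * (1.0 / (1.0 + np.exp(-x)))

def mlp_forward(x, w_gate, w_up, w_down):
    """Gated-MLP sketch with three linear layers and no bias.

    x:      (..., hidden_dim)
    w_gate: (hidden_dim, feed_forward_length)
    w_up:   (hidden_dim, feed_forward_length)
    w_down: (feed_forward_length, hidden_dim)
    Returns a tensor of the same shape as x: (..., hidden_dim).
    """
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

rng = np.random.default_rng(0)
hidden_dim, ff_len = 8, 32
x = rng.standard_normal((2, 4, hidden_dim))
w_gate = rng.standard_normal((hidden_dim, ff_len))
w_up = rng.standard_normal((hidden_dim, ff_len))
w_down = rng.standard_normal((ff_len, hidden_dim))

y = mlp_forward(x, w_gate, w_up, w_down)
assert y.shape == x.shape  # (..., hidden_dim) in, (..., hidden_dim) out
```

Note how `feed_forward_length` only appears in the intermediate activations: the input and output shapes depend solely on `hidden_dim`.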

shard()

shard(devices)

Creates sharded views of this MLP across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MLP instances, one for each device.

Return type:

list[MLP]
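To see why an MLP can be split into per-device shards, here is a NumPy sketch of the usual tensor-parallel decomposition of a gated MLP. This is an assumption about one common strategy, not necessarily what `shard()` produces: the gate/up weights are split column-wise along `feed_forward_length`, the down weight row-wise, and the per-shard partial outputs are summed:

```python
import numpy as np

def silu(x):
    return x * (1.0 / (1.0 + np.exp(-x)))

def full_mlp(x, w_gate, w_up, w_down):
    # Reference single-device computation.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def sharded_mlp(x, w_gate, w_up, w_down, n_shards):
    # Split the intermediate (feed_forward_length) dimension across shards:
    # gate/up column-wise, down row-wise. Each "device" computes a partial
    # result independently; summing the partials (an all-reduce in a real
    # multi-device setup) reproduces the unsharded output.
    gates = np.split(w_gate, n_shards, axis=1)
    ups = np.split(w_up, n_shards, axis=1)
    downs = np.split(w_down, n_shards, axis=0)
    partials = [
        (silu(x @ g) * (x @ u)) @ d for g, u, d in zip(gates, ups, downs)
    ]
    return sum(partials)

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 8))
w_gate = rng.standard_normal((8, 32))
w_up = rng.standard_normal((8, 32))
w_down = rng.standard_normal((32, 8))

ref = full_mlp(x, w_gate, w_up, w_down)
out = sharded_mlp(x, w_gate, w_up, w_down, n_shards=4)
assert np.allclose(ref, out)  # sharded partials sum to the full result
```

The split along `feed_forward_length` works because SiLU and the elementwise gate multiply act independently per intermediate column, and the final matmul is a sum over that dimension.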

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the MLP sharding strategy.