Python class
MLP
class max.nn.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', quant_config=None, is_sharding=False)
Simple multi-layer perceptron composed of three Linear layers.
When called, MLP accepts a TensorValueLike of shape
(..., hidden_dim) and returns a TensorValue of
the same shape (..., hidden_dim).
Initializes the MLP layer.
Parameters:

- dtype (DType) – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.
- hidden_dim (int) – The last dimension of the layer input.
- feed_forward_length (int) – Size of the dimension used to project the inputs.
- linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.
- devices (Sequence[DeviceRef]) – DeviceRef devices to run the MLP layer.
- has_bias (bool) – Whether to include bias terms in the linear layers.
- activation_function (str) – Activation function to use. Options are: silu, gelu, gelu_tanh, relu, tanh, sigmoid.
- quant_config (QuantConfig | None) – QuantConfig for scaled quantization.
- is_sharding (bool) – If True, skip child layer creation during sharding.
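To illustrate the computation, the sketch below implements a gated three-layer MLP with the default silu activation in plain NumPy. This is an illustrative approximation, not the MAX implementation: the gate/up/down arrangement and the weight names (w_gate, w_up, w_down) are assumptions based on the three-Linear-layer structure described above. It also demonstrates the documented shape contract: input of shape (..., hidden_dim) in, output of the same shape out.

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x * (1.0 / (1.0 + np.exp(-x)))

def mlp_forward(x, w_gate, w_up, w_down):
    # Gated MLP sketch (an assumption about the layer arrangement):
    # project up to feed_forward_length twice, gate with silu,
    # then project back down to hidden_dim.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

rng = np.random.default_rng(0)
hidden_dim, feed_forward_length = 8, 32

x = rng.standard_normal((2, 5, hidden_dim))              # shape (..., hidden_dim)
w_gate = rng.standard_normal((hidden_dim, feed_forward_length))
w_up = rng.standard_normal((hidden_dim, feed_forward_length))
w_down = rng.standard_normal((feed_forward_length, hidden_dim))

out = mlp_forward(x, w_gate, w_up, w_down)
assert out.shape == x.shape                              # output shape matches input shape
```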
shard()
shard(devices)
Creates sharded views of this MLP across multiple devices.
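The arithmetic behind sharding an MLP across devices can be sketched as tensor parallelism: column-shard the up/gate projections and row-shard the down projection, so each device computes a partial output and the partials sum to the unsharded result. The NumPy sketch below stands in for actual devices and is not the MAX `shard()` API; the weight names and gated arrangement are assumptions.

```python
import numpy as np

def silu(x):
    return x * (1.0 / (1.0 + np.exp(-x)))

rng = np.random.default_rng(1)
hidden_dim, feed_forward_length, n_devices = 8, 32, 2

x = rng.standard_normal((4, hidden_dim))
w_gate = rng.standard_normal((hidden_dim, feed_forward_length))
w_up = rng.standard_normal((hidden_dim, feed_forward_length))
w_down = rng.standard_normal((feed_forward_length, hidden_dim))

# Unsharded reference output.
reference = (silu(x @ w_gate) * (x @ w_up)) @ w_down

# Column-shard the gate/up weights, row-shard the down weights;
# each "device" holds one slice of each.
gate_shards = np.split(w_gate, n_devices, axis=1)
up_shards = np.split(w_up, n_devices, axis=1)
down_shards = np.split(w_down, n_devices, axis=0)

# Each device computes a partial output independently (silu is elementwise,
# so it commutes with the column split); summing partials gives the full result.
partials = [(silu(x @ g) * (x @ u)) @ d
            for g, u, d in zip(gate_shards, up_shards, down_shards)]
sharded = sum(partials)

assert np.allclose(reference, sharded)
```

This decomposition is why only a single sum (all-reduce) of the partial outputs is needed per MLP, rather than communication after every projection.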
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the MLP sharding strategy.