Python module
linear
Multi-layer Perceptron.
ColumnParallelLinear
class max.nn.legacy.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes y_i = x @ W_i^T + b_i for each device i in [0, …, num_devices - 1]:
+-----+     +-----+ T   +-----+     +-----+
|     |     | W_0 |     | b_0 |     | y_0 |  GPU0
|     |     +-----+     +-----+     +-----+
|     |     | W_1 |     | b_1 |     | y_1 |  GPU1
|  x  |  @  +-----+  +  +-----+  =  +-----+
|     |     | W_2 |     | b_2 |     | y_2 |  GPU2
|     |     +-----+     +-----+     +-----+
|     |     | W_3 |     | b_3 |     | y_3 |  GPU3
+-----+     +-----+     +-----+     +-----+

The values are then collected using an Allgather op, producing the same output tensor on each device:
  GPU0  GPU1  GPU2  GPU3                          GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                     +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                     | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                     +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                     | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather -->   +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                     | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                     +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                     | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                     +-----+-----+-----+-----+

Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256    # illustrative input dimension
out_dim = 1024  # illustrative output dimension

distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the column-parallel linear layer.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
- tied_weight (Weight | None) – Optional Weight to tie with this layer.
- **kwargs – Additional keyword arguments passed to the Linear initializer.
DistributedGemmConfig
class max.nn.legacy.linear.DistributedGemmConfig(enable_matmul_allreduce)
Configure how distributed GEMM is executed.
Configuration for distributed General Matrix Multiply operations.
Parameters:

- enable_matmul_allreduce (bool)
enable_matmul_allreduce
enable_matmul_allreduce: bool
If True, use the matmul + all_reduce kernel.
generate()
static generate()
Returns the default DistributedGemmConfig.
Returns:

A DistributedGemmConfig instance with default settings.

Return type:

DistributedGemmConfig | None
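A minimal usage sketch (the flag value below is illustrative, not a recommendation): a configuration can be obtained from generate() or constructed directly.

from max.nn.legacy.linear import DistributedGemmConfig

# Default configuration; may be None if no default applies.
default_config = DistributedGemmConfig.generate()

# Explicit configuration that enables the fused matmul + all_reduce path.
config = DistributedGemmConfig(enable_matmul_allreduce=True)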
GPTQLinear
class max.nn.legacy.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)
A Linear layer for GPTQ encoding.
Initializes the linear layer with weights and an optional bias, using GPTQ quantization.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding of the weights.
- quantization_config (QuantizationConfig | None) – Extra QuantizationConfig for the weight quantization.
- float8_config (Float8Config | None) – Float8Config for float8 quantization (not supported).
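A minimal construction sketch. The dimensions are illustrative, and the use of QuantizationEncoding.GPTQ imported from max.graph.quantization is an assumption; in practice the encoding and quantization_config come from the model's GPTQ checkpoint configuration.

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding  # assumed import path
from max.nn.legacy.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,    # illustrative
    out_dim=4096,   # illustrative
    dtype=DType.float32,
    device=DeviceRef.GPU(0),
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed enum member
)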
Linear
class max.nn.legacy.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None, is_sharding=False)
Applies a linear transformation to incoming data: y = xW^T + b.
This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.
Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.legacy.linear import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# Apply the layer to an input tensor built inside a graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)

Initializes the linear layer with weights and optional bias.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
- name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
- float8_config (Float8Config | None) – Float8Config for float8 quantization.
- clip_weight (float | None) – Optional weight clipping threshold.
- is_sharding (bool) – Disable child layer creation during sharding.
bias
The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.
device
device: DeviceRef
The device where matrix operations are performed.
input_scale
The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.
shard()
shard(devices)
Creates sharded views of this Linear layer across multiple devices.
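A minimal sketch of calling shard(), assuming linear_layer was constructed as in the example above and that an appropriate sharding strategy has already been assigned to the layer:

from max.graph import DeviceRef

# Produce one sharded view of the layer per target device (hypothetical usage).
shards = linear_layer.shard([DeviceRef.GPU(0), DeviceRef.GPU(1)])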
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the weight sharding strategy.
weight
weight: Weight
The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.
weight_scale
The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.
MLP
class max.nn.legacy.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.legacy.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None, is_sharding=False)
Simple multi-layer perceptron composed of three Linear layers.
Defaults to SiLU activation function.
Initializes the MLP layer.
Parameters:

- dtype (DType) – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.
- hidden_dim (int) – The last dimension of the layer input.
- feed_forward_length (int) – Size of the dimension used to project the inputs.
- linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.
- devices (Sequence[DeviceRef]) – DeviceRef devices to run the MLP layer. If multiple are provided, only the first device is used; use DistributedMLP to use all devices.
- has_bias (bool) – Whether to include bias terms in the linear layers.
- activation_function (str) – Activation function to use. Options are: silu, gelu, gelu_tanh, relu, tanh, sigmoid.
- float8_config (Float8Config | None) – Float8Config for float8 quantization.
- dist_gemm_config (DistributedGemmConfig | None) – DistributedGemmConfig for distributed GEMM configuration.
- is_sharding (bool) – Disable child layer creation during sharding.
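A minimal construction sketch (the dimensions below are illustrative, not defaults):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.linear import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,            # last dimension of the layer input
    feed_forward_length=14336,  # projection dimension
    devices=[DeviceRef.GPU(0)],
    activation_function="silu",
)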
shard()
shard(devices)
Creates sharded views of this MLP across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the MLP sharding strategy.