Python module
linear
Multi-layer Perceptron.
ColumnParallelLinear
class max.nn.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes $y_i = xW_i^T + b_i$ for each device i in [0, …, num_devices - 1]:
+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+
The values are then collected using an Allgather op, producing the same output tensor on each device:
  GPU0  GPU1  GPU2  GPU3                     GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                  | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                  | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+ -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                  | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                  | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
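The two diagrams above can be checked with a small pure-Python simulation. This is only a sketch of the arithmetic, not the MAX implementation: the helper names (`matmul_t`, `linear`, `column_parallel_linear`) are hypothetical, "devices" are plain list slices, and the weight is assumed to split evenly along `out_dim`.

```python
def matmul_t(x, w):
    """Return x @ w^T for row-major nested lists."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def linear(x, w, b):
    """Compute y = x @ w^T + b, with w of shape (out_dim, in_dim)."""
    return [[v + bias for v, bias in zip(row, b)] for row in matmul_t(x, w)]


def column_parallel_linear(x, w, b, num_devices):
    """Simulate the per-device computation followed by an allgather."""
    out_dim = len(w)
    assert out_dim % num_devices == 0, "sketch assumes an even split"
    rows = out_dim // num_devices
    # Each simulated device i holds W_i and b_i and computes its slice y_i.
    shards = [
        linear(x, w[i * rows:(i + 1) * rows], b[i * rows:(i + 1) * rows])
        for i in range(num_devices)
    ]
    # Allgather: every device ends up with the concatenation of all y_i.
    return [sum((shard[r] for shard in shards), []) for r in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]                            # (batch=2, in_dim=2)
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]   # (out_dim=4, in_dim=2)
b = [0.5, -0.5, 0.0, 1.0]

# The sharded result matches the unsharded layer.
assert column_parallel_linear(x, w, b, num_devices=4) == linear(x, w, b)
```

Because each output column depends on only one row of the weight, splitting the weight along `out_dim` changes where the work happens but not the result.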
Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

# Illustrative sizes; choose values that match your model.
in_dim = 256
out_dim = 128
num_devices = 4
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)
Parameters:
- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
- tied_weight (Weight | None) – Optional Weight to tie with this layer.
DistributedGemmConfig
class max.nn.linear.DistributedGemmConfig(enable_matmul_allreduce)
Configure how distributed GEMM is executed.
Parameters:
- enable_matmul_allreduce (bool)
enable_matmul_allreduce
enable_matmul_allreduce: bool
If True, use the matmul + all_reduce kernel.
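To show why a matmul is naturally paired with an all-reduce, here is a minimal sketch of the unfused arithmetic. It assumes the inner (contraction) dimension is split across devices, so each device produces a partial result of the full output shape and an all-reduce sums them. Whether the MAX kernel partitions the work this way is an assumption; the helper names are hypothetical and not part of the library.

```python
def matmul(a, b):
    """Plain a @ b for nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def all_reduce_sum(partials):
    """Element-wise sum of same-shaped matrices, as every device would see it."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def matmul_then_all_reduce(x, w, num_devices):
    """Split the inner dimension, multiply per device, then sum the partials."""
    inner = len(w)
    assert inner % num_devices == 0, "sketch assumes an even split"
    step = inner // num_devices
    partials = []
    for d in range(num_devices):
        lo, hi = d * step, (d + 1) * step
        x_shard = [row[lo:hi] for row in x]  # columns of x held by device d
        w_shard = w[lo:hi]                   # matching rows of w
        partials.append(matmul(x_shard, w_shard))
    return all_reduce_sum(partials)


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]

# Summing the per-device partial products reproduces the full matmul.
assert matmul_then_all_reduce(x, w, num_devices=2) == matmul(x, w)
```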
generate()
static generate()
Returns the default DistributedGemmConfig.
Returns:
A DistributedGemmConfig instance with default settings.
Return type:
DistributedGemmConfig | None
GPTQLinear
class max.nn.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)
A Linear layer for GPTQ encoding.
Initializes the linear layer with GPTQ-quantized weights and an optional bias.
Parameters:
- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding of the weights.
- quantization_config (QuantizationConfig | None) – Extra QuantizationConfig for the weight quantization.
- float8_config (Float8Config | None) – Float8Config for float8 quantization (not supported).
GPTQLinearV1
class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization_encoding=None, quantization_config=None, perm_idx=None)
A Linear layer for GPTQ encoding.
Deprecated since version 25.5: Use GPTQLinear instead.
Parameters:
- weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]])
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- perm_idx (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
perm_idx
perm_idx: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None = None
Optional permutation indices for GPTQ quantization.
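As a hedged illustration of what permutation indices can mean for a linear layer: if the input features and the weight's input columns are reordered by the same permutation, the product is unchanged. The sketch below shows only that identity. It assumes `perm_idx` reorders the input dimension, which this page does not state, and the helpers are hypothetical.

```python
def matmul_t(x, w):
    """Return x @ w^T for row-major nested lists."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def permute_columns(matrix, perm):
    """Reorder the columns of a matrix so that column j becomes matrix[:, perm[j]]."""
    return [[row[p] for p in perm] for row in matrix]


perm_idx = [2, 0, 3, 1]                  # hypothetical permutation of in_dim=4
x = [[1, 2, 3, 4]]                       # (batch=1, in_dim=4)
w = [[1, 0, 2, -1], [0, 3, 1, 1]]        # (out_dim=2, in_dim=4)

reference = matmul_t(x, w)
permuted = matmul_t(permute_columns(x, perm_idx), permute_columns(w, perm_idx))

# Applying the same permutation to both operands leaves the output unchanged.
assert permuted == reference
```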
quantization_config
quantization_config: QuantizationConfig | None = None
The QuantizationConfig for GPTQ quantization.
Linear
class max.nn.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None)
Applies a linear transformation to incoming data: $y = xW^T + b$.
This layer implements a fully connected layer where inputs are multiplied by a weight matrix and, optionally, a bias vector is added. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.
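The shapes involved can be made concrete with a small pure-Python sketch. It assumes the weight layout documented on this page, (out_dim, in_dim), and uses a hypothetical `linear_forward` helper on nested lists rather than MAX tensors.

```python
def linear_forward(x, weight, bias=None):
    """Compute x @ weight^T (+ bias); weight has shape (out_dim, in_dim)."""
    in_dim = len(weight[0])
    assert all(len(row) == in_dim for row in x), "input last dim must equal in_dim"
    y = [[sum(a * w for a, w in zip(row, w_row)) for w_row in weight] for row in x]
    if bias is not None:
        # The bias has shape (out_dim,) and is added to every row of the output.
        y = [[v + b for v, b in zip(row, bias)] for row in y]
    return y


x = [[1.0, 2.0, 3.0]]                         # (batch=1, in_dim=3)
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]   # (out_dim=2, in_dim=3)
bias = [10.0, 20.0]                           # (out_dim=2,)

print(linear_forward(x, weight))         # [[1.0, 5.0]]
print(linear_forward(x, weight, bias))   # [[11.0, 25.0]]
```

The output has shape (batch, out_dim), and leaving `bias` as `None` corresponds to `has_bias=False`.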
Example:
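from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# input_tensor stands for a TensorValue produced while building a graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)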
Initializes the linear layer with weights and optional bias.
Parameters:
- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
- name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
- float8_config (Float8Config | None) – Float8Config for float8 quantization.
- clip_weight (float | None) – Optional weight clipping threshold.
bias
The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.
device
device: DeviceRef
The device where matrix operations are performed.
input_scale
The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.
shard()
shard(devices)
Creates sharded views of this Linear layer across multiple devices.
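As a rough picture of what sharding a weight matrix involves, the sketch below splits a (out_dim, in_dim) matrix along either axis into one piece per device. It is an illustration only: the `shard_weight` helper is hypothetical, the handling of uneven splits is an assumption, and in MAX the actual split is governed by the layer's sharding strategy.

```python
def split_sizes(total, num_devices):
    """Sizes of num_devices contiguous chunks; earlier chunks absorb the remainder."""
    base, extra = divmod(total, num_devices)
    return [base + (1 if d < extra else 0) for d in range(num_devices)]


def shard_weight(weight, num_devices, axis):
    """Split a nested-list matrix along axis 0 (rows) or axis 1 (columns)."""
    total = len(weight) if axis == 0 else len(weight[0])
    shards, start = [], 0
    for size in split_sizes(total, num_devices):
        stop = start + size
        if axis == 0:
            shards.append(weight[start:stop])
        else:
            shards.append([row[start:stop] for row in weight])
        start = stop
    return shards


weight = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]   # (out_dim=3, in_dim=4)

by_rows = shard_weight(weight, num_devices=2, axis=0)    # split along out_dim
by_cols = shard_weight(weight, num_devices=2, axis=1)    # split along in_dim

print(by_rows)   # [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12]]]
print(by_cols)   # [[[1, 2], [5, 6], [9, 10]], [[3, 4], [7, 8], [11, 12]]]
```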
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the weight sharding strategy.
weight
weight: Weight
The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.
weight_scale
The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.
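The two documented scale shapes, () and (N,), can be read as one scalar for the whole tensor versus one value per group of weights. The sketch below assumes N equals the number of weight rows (out_dim) and that the scale multiplies the stored values; both are assumptions made for illustration, and `apply_weight_scale` is a hypothetical helper.

```python
def apply_weight_scale(q_weight, scale):
    """Rescale stored weights with a scalar scale or one scale per row."""
    if isinstance(scale, (int, float)):
        # Shape (): a single scale shared by the whole tensor.
        return [[v * scale for v in row] for row in q_weight]
    # Shape (N,): one scale per row, assuming N == out_dim.
    assert len(scale) == len(q_weight), "expected one scale per weight row"
    return [[v * s for v in row] for row, s in zip(q_weight, scale)]


q_weight = [[2, -4], [6, 8]]   # stored (quantized) values, shape (out_dim=2, in_dim=2)

print(apply_weight_scale(q_weight, 0.5))           # [[1.0, -2.0], [3.0, 4.0]]
print(apply_weight_scale(q_weight, [0.5, 0.25]))   # [[1.0, -2.0], [1.5, 2.0]]
```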
LinearV1
class max.nn.linear.LinearV1(weight, bias=None)
A unified linear layer that delegates to either regular or quantized implementation.
Deprecated since version 25.5: Use Linear instead.
Parameters:
- weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]])
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
bias
bias: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None = None
Optional bias tensor for the linear transformation.
create()
classmethod create(dtype, quantization_encoding, in_features, out_features, weights, bias=None, quantization_config=None)
Factory method to create a LinearV1 layer with appropriate implementation.
Parameters:
- dtype (DType) – The DType for the layer.
- quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding for the weights.
- in_features (int) – The input feature dimension.
- out_features (int) – The output feature dimension.
- weights (Weights | Weight) – The Weights or Weight object for the layer.
- bias (Weights | Weight | None) – Optional Weights or Weight object for bias.
- quantization_config (QuantizationConfig | None) – Optional QuantizationConfig for quantization.
Returns:
A LinearV1 instance.
Return type:
LinearV1
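The "delegates to either regular or quantized implementation" idea is a standard factory pattern. The sketch below shows the pattern with stand-in classes. `PlainLinear`, `QuantizedLinear`, `create_linear`, and the encoding string are all hypothetical and do not reflect how LinearV1.create is actually implemented.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PlainLinear:
    """Hypothetical stand-in for an unquantized linear layer."""
    weight: list
    bias: Optional[list] = None


@dataclass
class QuantizedLinear(PlainLinear):
    """Hypothetical stand-in for a quantized linear layer."""
    quantization_encoding: Optional[str] = None


def create_linear(weight, bias=None, quantization_encoding=None):
    """Pick the implementation based on whether an encoding was supplied."""
    if quantization_encoding is None:
        return PlainLinear(weight, bias)
    return QuantizedLinear(weight, bias, quantization_encoding)


plain = create_linear([[1.0, 2.0]])
quantized = create_linear([[1.0, 2.0]], quantization_encoding="example-encoding")

print(type(plain).__name__)       # PlainLinear
print(type(quantized).__name__)   # QuantizedLinear
```

Callers use one entry point and receive whichever implementation matches the encoding they passed.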
weight
weight: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]]
The weight tensor for the linear transformation.
MLP
class max.nn.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None)
Simple multi-layer perceptron composed of three Linear layers.
Defaults to SiLU activation function.
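A three-projection MLP with gate, up, and down layers (the names used by MLPV1 on this page) is commonly computed as down(act(gate(x)) * up(x)). The sketch below assumes that gated form, omits bias terms, and works on a single input vector with nested lists; it is not taken from the MAX source.

```python
import math


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def project(x, weight):
    """x @ weight^T for one input vector; weight has shape (out_dim, in_dim)."""
    return [sum(a * w for a, w in zip(x, row)) for row in weight]


def gated_mlp(x, gate_w, up_w, down_w, act=silu):
    """Assumed gated form: down(act(gate(x)) * up(x)), without bias."""
    gate = [act(v) for v in project(x, gate_w)]    # (feed_forward_length,)
    up = project(x, up_w)                          # (feed_forward_length,)
    hidden = [g * u for g, u in zip(gate, up)]     # element-wise product
    return project(hidden, down_w)                 # back to (hidden_dim,)


x = [1.0, -1.0]                                    # hidden_dim = 2
gate_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # (feed_forward_length=3, hidden_dim=2)
up_w = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]        # (3, 2)
down_w = [[1.0, 1.0, 1.0], [1.0, 0.0, -1.0]]       # (hidden_dim=2, 3)

y = gated_mlp(x, gate_w, up_w, down_w)
assert len(y) == len(x)  # the output keeps the input's last dimension
```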
Parameters:
- dtype (DType) – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.
- hidden_dim (int) – The last dimension of the layer input.
- feed_forward_length (int) – Size of dimension used to project the inputs.
- linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.
- devices (Sequence[DeviceRef]) – DeviceRef devices to run the MLP layer. If multiple are provided, only the first device is used. Use DistributedMLP to use all devices.
- has_bias (bool) – Whether to include bias terms in the linear layers.
- activation_function (str) – Activation function to use. Options are: silu, gelu, gelu_tanh, relu, tanh, sigmoid.
- float8_config (Float8Config | None) – Float8Config for float8 quantization.
- dist_gemm_config (DistributedGemmConfig | None) – DistributedGemmConfig for distributed GEMM configuration.
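For reference, the six activation names accepted by `activation_function` correspond to well-known scalar functions. The sketch below gives their textbook definitions in pure Python. The `ACTIVATIONS` table and `get_activation` helper are hypothetical, and the MAX kernels may differ in numerical detail.

```python
import math

# Textbook scalar definitions keyed by the option names listed above.
ACTIVATIONS = {
    "silu": lambda v: v / (1.0 + math.exp(-v)),
    "gelu": lambda v: 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))),
    "gelu_tanh": lambda v: 0.5 * v * (
        1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3))
    ),
    "relu": lambda v: max(0.0, v),
    "tanh": math.tanh,
    "sigmoid": lambda v: 1.0 / (1.0 + math.exp(-v)),
}


def get_activation(name):
    """Look up an activation by name, rejecting unknown names."""
    try:
        return ACTIVATIONS[name]
    except KeyError:
        options = ", ".join(sorted(ACTIVATIONS))
        raise ValueError(f"unknown activation {name!r}; expected one of: {options}") from None


relu = get_activation("relu")
print(relu(-2.0), relu(3.0))            # 0.0 3.0
print(get_activation("sigmoid")(0.0))   # 0.5
```

`gelu_tanh` is the tanh-based approximation of `gelu`, so the two agree closely but not exactly.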
shard()
shard(devices)
Creates sharded views of this MLP across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the MLP sharding strategy.
MLPV1
class max.nn.linear.MLPV1(gate_proj, down_proj, up_proj)
Simple multi-layer perceptron composed of three LinearV1 layers.
Uses SiLU activation function.
Deprecated since version 25.5: Use MLP instead.
down_proj
down_proj: LinearV1
The down projection LinearV1 layer.
gate_proj
gate_proj: LinearV1
The gate projection LinearV1 layer.
up_proj
up_proj: LinearV1
The up projection LinearV1 layer.
QLinearV1
class max.nn.linear.QLinearV1(weight, bias=None, quantization_encoding=None)
A quantized fully connected layer.
Deprecated since version 25.5: Use Linear instead.
Parameters:
- weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]])
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- quantization_encoding (QuantizationEncoding | None)
quantization_encoding
quantization_encoding: QuantizationEncoding | None = None
The QuantizationEncoding for the quantized weights.