Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = xW_i^T + b_i for each device i in [0, …, num_devices - 1]:

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
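The two steps in the diagrams above can be checked numerically without any devices. The following is a minimal pure-Python sketch, not the MAX implementation: the helper names `linear` and `column_parallel_linear` are hypothetical, devices are simulated as list entries, and `out_dim` is assumed to divide evenly by the number of devices.

```python
def linear(x, w, b):
    """Compute y = x @ w^T + b for x: (batch, in_dim), w: (out_dim, in_dim)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b_j
         for w_row, b_j in zip(w, b)]
        for row in x
    ]


def column_parallel_linear(x, w, b, num_devices):
    """Simulate the sharded computation followed by an allgather."""
    shard = len(w) // num_devices  # assumes out_dim divides evenly
    # Step 1: simulated device i computes y_i from its own slices W_i and b_i.
    partial = [
        linear(x, w[i * shard:(i + 1) * shard], b[i * shard:(i + 1) * shard])
        for i in range(num_devices)
    ]
    # Step 2: the allgather concatenates the y_i along the output dimension.
    gathered = [sum((p[r] for p in partial), []) for r in range(len(x))]
    # Every device ends up holding the same full tensor.
    return [gathered for _ in range(num_devices)]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
b = [0.5, -0.5, 0.0, 1.0]
per_device = column_parallel_linear(x, w, b, num_devices=4)
```

Each entry of `per_device` equals the unsharded result `linear(x, w, b)`, which is the property the Allgather diagram describes.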

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

# Illustrative sizes; out_dim is chosen to divide evenly across the devices.
in_dim = 256
out_dim = 128
num_devices = 4
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
  • tied_weight (Weight | None) – Optional Weight to tie with this layer.

DistributedGemmConfig

class max.nn.linear.DistributedGemmConfig(enable_matmul_allreduce)

Configure how distributed GEMM is executed.

Parameters:

enable_matmul_allreduce (bool)

enable_matmul_allreduce

enable_matmul_allreduce: bool

If True, use the matmul + all_reduce kernel.
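For context, a matmul followed by an all-reduce is the usual pattern when the input dimension, rather than the output dimension, is split across devices: each device multiplies its slice and the partial products are summed. The sketch below illustrates that general pattern in pure Python. It says nothing about how the MAX kernel is implemented; `matmul_t` and `matmul_allreduce` are hypothetical names, and `in_dim` is assumed to divide evenly by the number of devices.

```python
def matmul_t(x, w):
    """Compute x @ w^T for x: (batch, k) and w: (out_dim, k)."""
    return [[sum(a * c for a, c in zip(row, w_row)) for w_row in w] for row in x]


def matmul_allreduce(x, w, num_devices):
    """Split the inner dimension, multiply per device, then sum the partials."""
    step = len(w[0]) // num_devices  # assumes in_dim divides evenly
    partials = []
    for i in range(num_devices):
        lo, hi = i * step, (i + 1) * step
        x_i = [row[lo:hi] for row in x]  # input slice held by device i
        w_i = [row[lo:hi] for row in w]  # weight slice held by device i
        partials.append(matmul_t(x_i, w_i))
    # The all-reduce (sum) gives every device the complete product.
    return [
        [sum(p[r][c] for p in partials) for c in range(len(w))]
        for r in range(len(x))
    ]


x = [[1, 2, 3, 4]]
w = [[1, 0, 1, 0], [2, 2, 2, 2]]
result = matmul_allreduce(x, w, num_devices=2)
```

The summed partial products match the unsharded `matmul_t(x, w)`.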

generate()

static generate()

Returns the default DistributedGemmConfig.

Returns:

A DistributedGemmConfig instance with default settings.

Return type:

DistributedGemmConfig | None

GPTQLinear

class max.nn.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)

A Linear layer for GPTQ encoding.

Initializes the linear layer with weights and optional bias with GPTQ quantization.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding of the weights.
  • quantization_config (QuantizationConfig | None) – Extra QuantizationConfig for the weight quantization.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization (not supported).

GPTQLinearV1

class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization_encoding=None, quantization_config=None, perm_idx=None)

A Linear layer for GPTQ encoding.

Deprecated

Deprecated since version 25.5: Use GPTQLinear instead.

Parameters:

perm_idx

perm_idx: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None = None

Optional permutation indices for GPTQ quantization.
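The source does not say how perm_idx is applied. In GPTQ schemes that reorder input features, the key invariant is that permuting the input and the matching weight entries in the same way leaves the output unchanged. The sketch below demonstrates only that invariant; the `permute` helper and the example permutation are hypothetical.

```python
def dot(a, b):
    """Inner product of two equal-length sequences."""
    return sum(p * q for p, q in zip(a, b))


def permute(values, perm_idx):
    """Reorder values so that position j holds values[perm_idx[j]]."""
    return [values[p] for p in perm_idx]


x = [1, 2, 3, 4]            # one input row with in_dim = 4
w_row = [10, 20, 30, 40]    # one weight row over the same features
perm_idx = [2, 0, 3, 1]     # hypothetical permutation of the input features

x_perm = permute(x, perm_idx)
w_perm = permute(w_row, perm_idx)
```

Because both operands are reordered consistently, `dot(x_perm, w_perm)` equals `dot(x, w_row)`.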

quantization_config

quantization_config: QuantizationConfig | None = None

The QuantizationConfig for GPTQ quantization.

Linear

class max.nn.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer: inputs are multiplied by a weight matrix, and a bias vector is optionally added. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.

Example:

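from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# input_tensor is a TensorValue obtained while building a graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)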

Initializes the linear layer with weights and optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.
  • clip_weight (float | None) – Optional weight clipping threshold.
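The source describes clip_weight only as an optional threshold. One plausible reading, shown here as an assumption rather than documented behavior, is a symmetric clamp of every weight to the range [-clip_weight, clip_weight]. The `clip_weights` helper is hypothetical.

```python
def clip_weights(weight, clip_weight):
    """Clamp each entry to [-clip_weight, clip_weight]; None disables clipping."""
    if clip_weight is None:
        return weight
    return [
        [max(-clip_weight, min(clip_weight, v)) for v in row]
        for row in weight
    ]


weight = [[-3.0, 0.5], [2.0, -0.1]]
clipped = clip_weights(weight, 1.0)
unclipped = clip_weights(weight, None)
```

Values already inside the range pass through unchanged, so a generous threshold only affects outliers.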

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.

shard()

shard(devices)

Creates sharded views of this Linear layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of DeviceRef devices to place the shards on.

Returns:

List of sharded Linear instances, one for each device.

Return type:

list[Linear]
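To make "one shard per device" concrete, the sketch below computes which weight rows each device would receive under a simple row-wise split. This policy is an assumption for illustration: the actual partitioning depends on the layer's sharding_strategy, which the source does not detail, and `shard_bounds` is a hypothetical helper.

```python
def shard_bounds(out_dim, num_devices):
    """Return (start, end) row ranges, giving any remainder to the first devices."""
    base, extra = divmod(out_dim, num_devices)
    bounds = []
    start = 0
    for i in range(num_devices):
        size = base + (1 if i < extra else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


even = shard_bounds(8, 4)
uneven = shard_bounds(10, 4)
```

The ranges are contiguous and cover every row exactly once, so concatenating the shards in device order reconstructs the full weight.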

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the weight sharding strategy.

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.
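The shape convention matters when loading checkpoints: the stored weight is (out_dim, in_dim), so it must be transposed before a plain matrix product with an input of shape (batch, in_dim). The pure-Python sketch below illustrates the shapes only; `transpose` and `matmul` are hypothetical helpers, and the values are made up.

```python
def transpose(m):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Multiply a: (n, k) by b: (k, m) to get (n, m)."""
    cols = list(zip(*b))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in a]


in_dim, out_dim = 3, 2
weight = [[1, 2, 3], [4, 5, 6]]   # shape (out_dim, in_dim), as stored
bias = [10, 20]                    # shape (out_dim,)
x = [[1, 0, 1], [0, 1, 0]]         # shape (batch, in_dim)

weight_t = transpose(weight)       # shape (in_dim, out_dim), computed once
y = [[v + b for v, b in zip(row, bias)] for row in matmul(x, weight_t)]
```

Here `y` has shape (batch, out_dim), matching y = xW^T + b.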

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.
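The two allowed shapes correspond to the common per-tensor and per-channel scaling schemes. The sketch below assumes, without confirmation from the source, that the scale is multiplicative and that N equals the number of weight rows (out_dim). The `dequantize` helper is hypothetical.

```python
def dequantize(w_q, weight_scale):
    """Scale quantized weights by a scalar or by one scale per row."""
    if isinstance(weight_scale, (int, float)):
        # Shape (): a single scale shared by the whole tensor.
        scales = [weight_scale] * len(w_q)
    else:
        # Shape (N,): one scale per output row (assumed N == out_dim).
        scales = list(weight_scale)
        if len(scales) != len(w_q):
            raise ValueError("expected one scale per weight row")
    return [[v * s for v in row] for row, s in zip(w_q, scales)]


w_q = [[1, 2], [3, 4]]
per_tensor = dequantize(w_q, 0.5)
per_channel = dequantize(w_q, [2.0, 0.5])
```

Per-channel scales let each output row use its own dynamic range, at the cost of storing N values instead of one.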

LinearV1

class max.nn.linear.LinearV1(weight, bias=None)

A unified linear layer that delegates to either regular or quantized implementation.

Deprecated

Deprecated since version 25.5: Use Linear instead.

Parameters:

bias

bias: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None = None

Optional bias tensor for the linear transformation.

create()

classmethod create(dtype, quantization_encoding, in_features, out_features, weights, bias=None, quantization_config=None)

Factory method to create a LinearV1 layer with appropriate implementation.

Parameters:

  • dtype (DType) – The DType for the layer.
  • quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding for the weights.
  • in_features (int) – The input feature dimension.
  • out_features (int) – The output feature dimension.
  • weights (Weights | Weight) – The Weights or Weight object for the layer.
  • bias (Weights | Weight | None) – Optional Weights or Weight object for bias.
  • quantization_config (QuantizationConfig | None) – Optional QuantizationConfig for quantization.

Returns:

A LinearV1 instance.

Return type:

LinearV1

weight

weight: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]]

The weight tensor for the linear transformation.

MLP

class max.nn.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None)

Simple multi-layer perceptron composed of three Linear layers. Defaults to SiLU activation function.

Parameters:

  • dtype (DType) – DType to use for the layer weights, which should match the input dtype.

  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.

  • hidden_dim (int) – The last dimension of the layer input.

  • feed_forward_length (int) – Size of dimension used to project the inputs.

  • linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.

  • devices (Sequence[DeviceRef]) – DeviceRef devices on which to run the MLP layer. If more than one is provided, only the first is used; use DistributedMLP to use all devices.

  • has_bias (bool) – Whether to include bias terms in the linear layers.

  • activation_function (str) –

    Activation function to use. Options are:

    • silu
    • gelu
    • gelu_tanh
    • relu
    • tanh
    • sigmoid
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.

  • dist_gemm_config (DistributedGemmConfig | None) – DistributedGemmConfig for distributed GEMM configuration.
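For reference, the activation names listed above correspond to standard functions. The sketch below gives their textbook scalar definitions in pure Python; the `ACTIVATIONS` table and `get_activation` lookup are hypothetical and are not how MAX resolves the string, and `gelu_tanh` is assumed to mean the usual tanh approximation of GELU.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    return x * sigmoid(x)


def relu(x):
    return max(0.0, x)


def gelu(x):
    # Exact GELU using the Gaussian error function.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation of GELU.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


ACTIVATIONS = {
    "silu": silu,
    "gelu": gelu,
    "gelu_tanh": gelu_tanh,
    "relu": relu,
    "tanh": math.tanh,
    "sigmoid": sigmoid,
}


def get_activation(name):
    """Look up an activation by name, rejecting unknown names."""
    try:
        return ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f"unknown activation function: {name!r}") from None
```

The exact and tanh-approximated GELU agree closely for moderate inputs, so the choice between them is mostly about matching the checkpoint being loaded.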

shard()

shard(devices)

Creates sharded views of this MLP across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MLP instances, one for each device.

Return type:

list[MLP]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the MLP sharding strategy.

MLPV1

class max.nn.linear.MLPV1(gate_proj, down_proj, up_proj)

Simple multi-layer perceptron composed of three LinearV1 layers. Uses SiLU activation function.

Deprecated

Deprecated since version 25.5: Use MLP instead.

Parameters:

down_proj

down_proj: LinearV1

The down projection LinearV1 layer.

gate_proj

gate_proj: LinearV1

The gate projection LinearV1 layer.

up_proj

up_proj: LinearV1

The up projection LinearV1 layer.
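The source names the three projections and the SiLU activation but not how they are combined. A common arrangement in gated MLPs, assumed here, is down_proj(silu(gate_proj(x)) * up_proj(x)). The sketch below implements that arrangement for a single input vector without bias; `project` and `gated_mlp` are hypothetical helpers and the weights are made up.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def project(x, w):
    """Apply a bias-free linear layer: x: (in_dim,), w: (out_dim, in_dim)."""
    return [sum(a * b for a, b in zip(x, row)) for row in w]


def gated_mlp(x, gate_w, up_w, down_w):
    """Assumed structure: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    gate = [silu(v) for v in project(x, gate_w)]
    up = project(x, up_w)
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return project(hidden, down_w)


x = [1.0, 2.0]
gate_w = [[1.0, 0.0], [0.0, 1.0]]   # hidden_dim -> feed_forward_length
up_w = [[1.0, 1.0], [1.0, -1.0]]    # hidden_dim -> feed_forward_length
down_w = [[1.0, 1.0]]               # feed_forward_length -> output
out = gated_mlp(x, gate_w, up_w, down_w)
```

The gate and up projections share the same input and output sizes, which is why their results can be multiplied elementwise before the down projection.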

QLinearV1

class max.nn.linear.QLinearV1(weight, bias=None, quantization_encoding=None)

A quantized fully connected layer.

Deprecated

Deprecated since version 25.5: Use Linear instead.

Parameters:

quantization_encoding

quantization_encoding: QuantizationEncoding | None = None

The QuantizationEncoding for the quantized weights.
