
Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.legacy.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = x W_i^T + b_i for each device i in [0, num_devices):

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
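
As a rough illustration of the math above (using NumPy rather than the MAX API), sharding W and b along the output dimension, computing each partial product, and concatenating the results reproduces the full y = xW^T + b:

import numpy as np

num_devices = 4
x = np.random.rand(2, 8)    # (batch, in_dim)
W = np.random.rand(16, 8)   # (out_dim, in_dim)
b = np.random.rand(16)

# Shard W and b along the output dimension, one shard per "device".
W_shards = np.split(W, num_devices, axis=0)
b_shards = np.split(b, num_devices, axis=0)

# Each device computes its partial output y_i = x @ W_i^T + b_i.
partials = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]

# Allgather: concatenating the partial outputs recovers y = x @ W^T + b.
y = np.concatenate(partials, axis=-1)
assert np.allclose(y, x @ W.T + b)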

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256    # example input dimension
out_dim = 128   # example output dimension
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the column-parallel linear layer.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
  • tied_weight (Weight | None) – Optional Weight to tie with this layer.
  • **kwargs – Additional keyword arguments passed to the Linear initializer.

DistributedGemmConfig

class max.nn.legacy.linear.DistributedGemmConfig(enable_matmul_allreduce)

Configuration for how distributed GEMM (General Matrix Multiply) operations are executed.

Parameters:

enable_matmul_allreduce (bool)

enable_matmul_allreduce

enable_matmul_allreduce: bool

If True, use the matmul + all_reduce kernel.

generate()

static generate()

Returns the default DistributedGemmConfig.

Returns:

A DistributedGemmConfig instance with default settings.

Return type:

DistributedGemmConfig | None
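
For reference, a configuration can be obtained from the generate() helper or constructed directly from the documented field; a minimal sketch:

from max.nn.legacy.linear import DistributedGemmConfig

# Use the library defaults (generate() may return None).
config = DistributedGemmConfig.generate()

# Or construct explicitly with the documented field.
config = DistributedGemmConfig(enable_matmul_allreduce=True)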

GPTQLinear

class max.nn.legacy.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)

A Linear layer for GPTQ encoding.

Initializes the linear layer with GPTQ-quantized weights and an optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding of the weights.
  • quantization_config (QuantizationConfig | None) – Extra QuantizationConfig for the weight quantization.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization (not supported).
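
A minimal construction sketch based on the signature above; the dimensions and dtype are placeholder values, and a real GPTQ checkpoint would also supply a matching quantization_encoding and quantization_config:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.linear import GPTQLinear

# Placeholder dimensions and dtype; these must match the quantized checkpoint.
gptq_linear = GPTQLinear(
    in_dim=4096,
    out_dim=4096,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    has_bias=False,
)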

Linear

class max.nn.legacy.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None, is_sharding=False)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.

Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.legacy.linear import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# `input_tensor` is a TensorValue produced elsewhere in the graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)

Initializes the linear layer with weights and optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.
  • clip_weight (float | None) – Optional weight clipping threshold.
  • is_sharding (bool) – When True, disables child layer creation during sharding.

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.

shard()

shard(devices)

Creates sharded views of this Linear layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of DeviceRef devices to place the shards on.

Returns:

List of sharded Linear instances, one for each device.

Return type:

list[Linear]
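
Continuing the Linear example above, a sketch of sharding the layer across two GPUs (this assumes the layer's sharding strategy has already been set via the sharding_strategy property):

# One sharded Linear view is returned per device.
devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]
shards = linear_layer.shard(devices)
assert len(shards) == len(devices)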

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the weight sharding strategy.

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.

MLP

class max.nn.legacy.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.legacy.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None, is_sharding=False)

Simple multi-layer perceptron composed of three Linear layers.

Defaults to SiLU activation function.

Initializes the MLP layer.

Parameters:

  • dtype (DType) – DType to use for the layer weights, which should match the input dtype.

  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.

  • hidden_dim (int) – The last dimension of the layer input.

  • feed_forward_length (int) – Size of the dimension used to project the inputs.

  • linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.

  • devices (Sequence[DeviceRef]) – DeviceRef devices on which to run the MLP layer. If multiple devices are provided, only the first is used; use DistributedMLP to run on all devices.

  • has_bias (bool) – Whether to include bias terms in the linear layers.

  • activation_function (str) –

    Activation function to use. Options are:

    • silu
    • gelu
    • gelu_tanh
    • relu
    • tanh
    • sigmoid
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.

  • dist_gemm_config (DistributedGemmConfig | None) – DistributedGemmConfig for distributed GEMM configuration.

  • is_sharding (bool) – When True, disables child layer creation during sharding.
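
A minimal construction sketch based on the signature above, assuming a single GPU device; the dimension sizes are placeholder values:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.linear import MLP

# Placeholder sizes; hidden_dim is the input's last dimension and
# feed_forward_length is the intermediate projection size.
mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,
    feed_forward_length=14336,
    devices=[DeviceRef.GPU()],
    activation_function="silu",
)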

shard()

shard(devices)

Creates sharded views of this MLP across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MLP instances, one for each device.

Return type:

list[MLP]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the MLP sharding strategy.
