Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.legacy.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = x W_i^T + b_i for each device i in [0, …, num_devices - 1]:

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = x W^T + b on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256    # example input dimension
out_dim = 1024  # example output dimension
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Parameters:

DistributedGemmConfig

class max.nn.legacy.linear.DistributedGemmConfig(enable_matmul_allreduce)

Configure how distributed GEMM is executed.

Configuration for distributed General Matrix Multiply operations.

Parameters:

enable_matmul_allreduce (bool)

enable_matmul_allreduce

enable_matmul_allreduce: bool

If True, use the matmul + all_reduce kernel.

generate()

static generate()

Returns the default DistributedGemmConfig.

Returns:

A DistributedGemmConfig instance with default settings.

Return type:

DistributedGemmConfig | None
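
A minimal usage sketch based only on the constructor and generate() signatures above; the explicit keyword value shown is illustrative:

from max.nn.legacy.linear import DistributedGemmConfig

# Library default settings (note the return type is DistributedGemmConfig | None).
default_config = DistributedGemmConfig.generate()

# Explicitly enable the fused matmul + all_reduce kernel.
config = DistributedGemmConfig(enable_matmul_allreduce=True)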

GPTQLinear

class max.nn.legacy.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)

A Linear layer for GPTQ encoding.
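
A hedged construction sketch using only the parameters in the signature above. The dimensions are arbitrary, and the QuantizationEncoding import path and its GPTQ member are assumptions rather than confirmed API:

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding  # assumed import path
from max.nn.legacy.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,              # example dimension
    out_dim=4096,             # example dimension
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed enum member
)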

Parameters:

Linear

class max.nn.legacy.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None, is_sharding=False)

Applies a linear transformation to incoming data: y = x W^T + b.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix and a bias vector is optionally added. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.

Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# Apply the layer to a graph tensor (placeholder annotation shown).
input_tensor: TensorValue
output = linear_layer(input_tensor)

Parameters:

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.

shard()

shard(devices)

Creates sharded views of this Linear layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of DeviceRef devices to place the shards on.

Returns:

List of sharded Linear instances, one for each device.

Return type:

list[Linear]
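
A short sketch of calling shard(), continuing the Linear example above; depending on configuration, a sharding strategy may need to be assigned via sharding_strategy before sharding:

from max.graph import DeviceRef

# Create per-device views of the layer (one Linear per device).
shards = linear_layer.shard([DeviceRef.GPU(0), DeviceRef.GPU(1)])
for shard in shards:
    print(shard.device)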

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the weight sharding strategy.

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.

MLP

class max.nn.legacy.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=Linear, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None, is_sharding=False)

Simple multi-layer perceptron composed of three Linear layers.

Defaults to SiLU activation function.
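
A construction sketch using only the parameters in the signature above; the dimension values are illustrative placeholders:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.linear import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,            # example model hidden size
    feed_forward_length=14336,  # example intermediate size
    devices=[DeviceRef.GPU()],
    activation_function="silu",
)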

Parameters:

shard()

shard(devices)

Creates sharded views of this MLP across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MLP instances, one for each device.

Return type:

list[MLP]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the MLP sharding strategy.
