Python module
linear
Multi-layer Perceptron.
ColumnParallelLinear
class max.nn.legacy.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes, for each device i in [0, …, num_devices - 1]:
+-----+     +-----+ T   +-----+     +-----+
|     |     | W_0 |     | b_0 |     | y_0 | GPU0
|     |     +-----+     +-----+     +-----+
|     |     | W_1 |     | b_1 |     | y_1 | GPU1
|  x  |  @  +-----+  +  +-----+  =  +-----+
|     |     | W_2 |     | b_2 |     | y_2 | GPU2
|     |     +-----+     +-----+     +-----+
|     |     | W_3 |     | b_3 |     | y_3 | GPU3
+-----+     +-----+     +-----+     +-----+

The values are then collected using an Allgather op, producing the same output tensor on each device:
 GPU0  GPU1  GPU2  GPU3                        GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                    | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                    | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather -->  +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                    | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                    | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+

Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256   # illustrative dimensions; choose values to match your model
out_dim = 512
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

DistributedGemmConfig
class max.nn.legacy.linear.DistributedGemmConfig(enable_matmul_allreduce)
Configure how distributed GEMM is executed.
Configuration for distributed General Matrix Multiply operations.
Parameters:
- enable_matmul_allreduce (bool)
enable_matmul_allreduce
enable_matmul_allreduce: bool
If True, use the matmul + all_reduce kernel.
generate()
static generate()
Returns the default DistributedGemmConfig.
Returns:
A DistributedGemmConfig instance with default settings.
Return type:
DistributedGemmConfig | None
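
For illustration only, here is a minimal sketch of creating this configuration using just the constructor and generate() documented above; note that generate() is typed as returning DistributedGemmConfig | None, so callers may need a None check:

from max.nn.legacy.linear import DistributedGemmConfig

# Default settings (typed as DistributedGemmConfig | None per the docs above).
default_config = DistributedGemmConfig.generate()

# Explicitly request the fused matmul + all_reduce kernel.
explicit_config = DistributedGemmConfig(enable_matmul_allreduce=True)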
GPTQLinear
class max.nn.legacy.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)
A Linear layer for GPTQ encoding.
Parameters:
- in_dim (int)
- out_dim (int)
- dtype (DType)
- device (DeviceRef)
- has_bias (bool)
- quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- float8_config (Float8Config | None)
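
As a rough sketch only: the constructor call below follows the signature documented above, but the dimensions are hypothetical and the QuantizationEncoding.GPTQ member (imported here from max.graph.quantization) is an assumption about the wider MAX API rather than something documented on this page:

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding  # assumed location of the encoding enum
from max.nn.legacy.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,           # hypothetical dimensions; real values come from the checkpoint
    out_dim=4096,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(0),
    has_bias=False,
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed enum member for GPTQ
)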
Linear
class max.nn.legacy.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None, is_sharding=False)
Applies a linear transformation to incoming data: y = xW^T + b.
This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.
Example:
from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.legacy.linear import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

input_tensor: TensorValue  # a tensor produced elsewhere in the graph, shape [..., 256]
output = linear_layer(input_tensor)
Parameters:
- in_dim (int)
- out_dim (int)
- dtype (DType)
- device (DeviceRef)
- has_bias (bool)
- quantization_encoding (QuantizationEncoding | None)
- float8_config (Float8Config | None)
- name (str | None)
- clip_weight (float | None)
- is_sharding (bool)
bias
The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.
device
device: DeviceRef
The device where matrix operations are performed.
input_scale
The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.
shard()
shard(devices)
Creates sharded views of this Linear layer across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the weight sharding strategy.
weight
weight: Weight
The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.
weight_scale
The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.
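
To illustrate shard() and sharding_strategy together, here is a sketch that assumes ShardingStrategy is importable from max.nn, exposes a rowwise(num_devices) constructor, and that the sharding_strategy property accepts assignment; none of these assumptions are documented on this page:

from max.graph import DeviceRef
from max.nn import ShardingStrategy  # assumed import path and API

devices = [DeviceRef.GPU(i) for i in range(4)]

# Assign a strategy (assumed setter), then create one Linear view per device.
linear_layer.sharding_strategy = ShardingStrategy.rowwise(len(devices))
shards = linear_layer.shard(devices)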
MLP
class max.nn.legacy.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.legacy.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None, is_sharding=False)
Simple multi-layer perceptron composed of three Linear layers.
Defaults to SiLU activation function.
Parameters:
- dtype (DType)
- quantization_encoding (QuantizationEncoding | None)
- hidden_dim (int)
- feed_forward_length (int)
- devices (Sequence[DeviceRef])
- linear_cls (Callable[..., Linear])
- has_bias (bool)
- activation_function (str)
- float8_config (Float8Config | None)
- dist_gemm_config (DistributedGemmConfig | None)
- is_sharding (bool)
shard()
shard(devices)
Creates sharded views of this MLP across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the MLP sharding strategy.
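
For illustration, an MLP built from the parameters documented above; the dimensions and single-GPU device list are illustrative, and dist_gemm_config reuses the DistributedGemmConfig sketch from earlier:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.linear import MLP, DistributedGemmConfig

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,              # illustrative model width
    feed_forward_length=14336,    # illustrative intermediate size
    devices=[DeviceRef.GPU(0)],
    has_bias=False,
    activation_function="silu",
    dist_gemm_config=DistributedGemmConfig.generate(),
)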