
Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.linear.ColumnParallelLinear(*args, devices: Sequence[DeviceRef], **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = xW_i^T + b_i for each device i in [0, …, num_devices):

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+ -- Allgather -->  +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
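
To make the sharding concrete, here is a minimal NumPy sketch of the same compute-then-allgather pattern, independent of MAX; the sizes and device count are illustrative assumptions.

import numpy as np

num_devices = 4
batch, in_dim, out_dim = 2, 8, 16  # illustrative; out_dim must divide num_devices

x = np.random.randn(batch, in_dim).astype(np.float32)
W = np.random.randn(out_dim, in_dim).astype(np.float32)
b = np.random.randn(out_dim).astype(np.float32)

# Shard W and b row-wise: device i holds W_i of shape (out_dim // num_devices, in_dim).
W_shards = np.split(W, num_devices, axis=0)
b_shards = np.split(b, num_devices, axis=0)

# Each "device" computes its slice y_i = x @ W_i.T + b_i.
y_shards = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]

# Allgather: concatenate the slices so every device ends up with the full y.
y = np.concatenate(y_shards, axis=1)
assert np.allclose(y, x @ W.T + b, atol=1e-4)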

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

in_dim = 256    # illustrative sizes
out_dim = 1024
num_devices = 4
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the linear layer with weights and optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • devices – The devices across which the weight and bias are sharded. Weights remain on CPU until moved to each device during model initialization.
    • name – Base name for weights (appended with .weight and .bias if applicable).
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.

DistributedMLP

class max.nn.linear.DistributedMLP(*args, **kwargs)

A distributed multi-layer perceptron.

This class has the same state keys as the non-distributed MLP Layer.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – The last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension the inputs are projected to.
    • linear_cls – Linear class used to create the projection layers.
    • devices – Devices to run the MLP layer on; the layer is sharded across all provided devices (see the construction sketch after this list).
    • activation_function – Activation function to use. Options are:
      • “silu”
      • “gelu”
      • “gelu_tanh”
      • “relu”
      • “tanh”
      • “sigmoid”
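
Below is a hedged construction sketch. DistributedMLP takes *args and **kwargs, so this assumes it forwards the same positional arguments as MLP and is importable from max.nn like the other layers on this page; the sizes are illustrative.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import DistributedMLP

num_devices = 4
distributed_mlp = DistributedMLP(
    DType.float32,   # dtype
    None,            # quantization_encoding
    4096,            # hidden_dim (illustrative)
    14336,           # feed_forward_length (illustrative)
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)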

build_subgraph()

build_subgraph(name: str, x_type: list[max.graph.type.TensorType]) → Module

Float8Config

class max.nn.linear.Float8Config(input_scale: Float8InputScaleSpec, weight_scale: Float8WeightScaleSpec, attn_in_float8: bool, embedding_output_dtype: DType | None = None)

Configures float8 quantization settings for a layer or model section.

attn_in_float8

attn_in_float8*: bool*

Whether attention projection inputs are in float8.

embedding_output_dtype

embedding_output_dtype*: DType | None* = None

The data type of the output from the embedding layer.

input_scale

input_scale*: Float8InputScaleSpec*

Specification for input activation scaling.

weight_scale

weight_scale*: Float8WeightScaleSpec*

Specification for weight scaling.
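
Here is a construction example that assembles a Float8Config from the spec classes documented below; the chosen granularities and dtypes are illustrative, not a recommended configuration.

from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

# Dynamic per-tensor input scaling paired with rowwise weight scaling.
config = Float8Config(
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        origin=Float8ScaleOrigin.DYNAMIC,
        dtype=DType.float32,
    ),
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        dtype=DType.float32,
    ),
    attn_in_float8=False,
    embedding_output_dtype=DType.bfloat16,
)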

Float8InputScaleSpec

class max.nn.linear.Float8InputScaleSpec(granularity: Float8ScaleGranularity, origin: Float8ScaleOrigin, dtype: DType, activation_scale_ub: float | None = None)

Specifies how input activations are scaled for float8 quantization.

activation_scale_ub

activation_scale_ub*: float | None* = None

An optional upper bound for dynamic activation scaling.

dtype

dtype*: DType*

The data type of the input scale factor(s). Must be provided if origin is Float8ScaleOrigin.STATIC.

granularity

granularity*: Float8ScaleGranularity*

The granularity of the input scale factor application.

origin

origin*: Float8ScaleOrigin*

The origin (static or dynamic) of the input scale factor.

Float8ScaleGranularity

class max.nn.linear.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the granularity of the quantization scale factor.

Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.
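
The difference between the granularities is easiest to see in a small NumPy sketch; the float8 maximum used here (448, as in an e4m3-style format) is an illustrative assumption.

import numpy as np

FP8_MAX = 448.0  # assumed e4m3-style representable maximum, for illustration
W = np.random.randn(4, 8).astype(np.float32)

# TENSOR: one scale for the whole tensor.
tensor_scale = np.abs(W).max() / FP8_MAX

# ROWWISE: one scale per row (often used for weights), shape (4, 1).
row_scales = np.abs(W).max(axis=1, keepdims=True) / FP8_MAX

# COLWISE: one scale per column, shape (1, 8).
col_scales = np.abs(W).max(axis=0, keepdims=True) / FP8_MAX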

BLOCK

BLOCK = 'block'

COLWISE

COLWISE = 'colwise'

ROWWISE

ROWWISE = 'rowwise'

TENSOR

TENSOR = 'tensor'

Float8ScaleOrigin

class max.nn.linear.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies whether the quantization scale is determined statically or dynamically.

STATIC scales are pre-computed and loaded with the model weights. DYNAMIC scales are computed at runtime based on the input data.
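
A short sketch of the distinction, reusing the illustrative scale formula from the granularity example above:

import numpy as np

FP8_MAX = 448.0  # illustrative, as above

def dynamic_scale(x: np.ndarray) -> float:
    # DYNAMIC: derived from the live activation values at runtime.
    return float(np.abs(x).max()) / FP8_MAX

# STATIC: a pre-computed scale loaded with the model weights
# (a hypothetical checkpoint value, not computed from x).
static_scale = 0.023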

DYNAMIC

DYNAMIC = 'dynamic'

STATIC

STATIC = 'static'

Float8WeightScaleSpec

class max.nn.linear.Float8WeightScaleSpec(granularity: Float8ScaleGranularity, dtype: DType)

Specifies how weights are scaled for float8 quantization.

dtype

dtype*: DType*

The data type of the weight scale factor(s).

granularity

granularity*: Float8ScaleGranularity*

The granularity of the weight scale factor application.

GPTQLinear

class max.nn.linear.GPTQLinear(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, float8_config: Float8Config | None = None)

A Linear layer for GPTQ encoding.

Initializes the linear layer with GPTQ-quantized weights and an optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved to the device during model initialization.
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.
    • quantization_encoding – The quantization encoding of the weights.
    • quantization_config – Extra config for the weight quantization.
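
A hedged construction sketch based on the signature above. The QuantizationEncoding.GPTQ member shown here is an assumption; in practice the encoding and config typically come from the checkpoint being loaded.

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding
from max.nn.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,
    out_dim=4096,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(),
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed member name
)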

GPTQLinearV1

class max.nn.linear.GPTQLinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A Linear layer for GPTQ encoding.

perm_idx

perm_idx*: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None* = None

quantization_config

quantization_config*: QuantizationConfig | None* = None

Linear

class max.nn.linear.Linear(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, float8_config: Float8Config | None = None, name: str | None = None, clip_weight: float | None = None)

Applies a linear transformation to incoming data: y=xWT+by = xW^T + b.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix, with an optional bias vector added. Both weights and bias initially reside on CPU, and the model init phase moves them to the target device.

Example:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# input_tensor: TensorValue produced elsewhere in the graph
output = linear_layer(input_tensor)

Initializes the linear layer with weights and optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved to the device during model initialization.
    • name – Base name for weights (appended with .weight and .bias if applicable).
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.

bias

bias*: Weight | None* = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.

device

device*: DeviceRef*

The device where matrix operations are performed.

input_scale

input_scale*: Weight | None* = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to device if present.

weight

weight*: Weight*

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.

weight_scale

weight_scale*: Weight | None* = None

The optional weight scale stored on CPU with shape (). Model init moves the weight_scale to device if present.
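
As a shape check on the stored (out_dim, in_dim) layout, here is a minimal NumPy sketch of the transform y = xW^T + b that this layer performs:

import numpy as np

batch, in_dim, out_dim = 4, 256, 128
x = np.random.randn(batch, in_dim).astype(np.float32)
W = np.random.randn(out_dim, in_dim).astype(np.float32)  # stored layout
b = np.zeros(out_dim, dtype=np.float32)

y = x @ W.T + b  # shape (batch, out_dim)
assert y.shape == (batch, out_dim)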

LinearV1

class max.nn.linear.LinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A unified linear layer that delegates to either regular or quantized implementation.

Deprecated: Use Linear instead.

bias

bias*: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None* = None

create()

classmethod create(dtype: DType, quantization_encoding: QuantizationEncoding | None, in_features: int, out_features: int, weights: Weights | Weight, bias: Weights | Weight | None = None, quantization_config: QuantizationConfig | None = None) → LinearV1

Factory method that creates a LinearV1 with the appropriate implementation (regular or quantized) for the given encoding.

weight

weight*: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray*

MLP

class max.nn.linear.MLP(dtype: DType, quantization_encoding: QuantizationEncoding | None, hidden_dim: int, feed_forward_length: int, devices: Sequence[DeviceRef], linear_cls: Callable[..., Linear] = Linear, has_bias: bool = False, activation_function: str = 'silu', float8_config: Float8Config | None = None)

Simple multi-layer perceptron composed of three linear layers. Defaults to SiLU activation function.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – The last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension the inputs are projected to.
    • linear_cls – Linear class used to create the projection layers.
    • devices – Devices to run the MLP layer on. If multiple devices are provided, only the first is used; use DistributedMLP to run on all devices (see the construction example after this list).
    • activation_function – Activation function to use. Options are:
      • “silu”
      • “gelu”
      • “gelu_tanh”
      • “relu”
      • “tanh”
      • “sigmoid”
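
A construction example based on the signature above, assuming MLP is importable from max.nn like the other layers on this page; the sizes are illustrative.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,            # illustrative
    feed_forward_length=14336,  # illustrative
    devices=[DeviceRef.GPU()],
    activation_function="silu",
)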

build_subgraph()

build_subgraph(name: str, x_type: TensorType) → Module

MLPV1

class max.nn.linear.MLPV1(gate_proj: LinearV1, down_proj: LinearV1, up_proj: LinearV1)

Simple multi-layer perceptron composed of three linear layers (gate, up, and down projections). Uses the SiLU activation function; a sketch of the computation follows the attributes below.

down_proj

down_proj*: LinearV1*

gate_proj

gate_proj*: LinearV1*

up_proj

up_proj*: LinearV1*
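
The three projections compose into a gated feed-forward block. Here is a NumPy sketch of the computation, assuming the standard SiLU-gated formulation down_proj(silu(gate_proj(x)) * up_proj(x)):

import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))

batch, hidden_dim, ff_len = 2, 8, 32
x = np.random.randn(batch, hidden_dim).astype(np.float32)

# One (out_dim, in_dim) weight per projection, mirroring Linear's stored layout.
W_gate = np.random.randn(ff_len, hidden_dim).astype(np.float32)
W_up = np.random.randn(ff_len, hidden_dim).astype(np.float32)
W_down = np.random.randn(hidden_dim, ff_len).astype(np.float32)

y = (silu(x @ W_gate.T) * (x @ W_up.T)) @ W_down.T
assert y.shape == (batch, hidden_dim)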

QLinearV1

class max.nn.linear.QLinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None)

A quantized fully connected layer.

quantization_encoding

quantization_encoding*: QuantizationEncoding | None* = None