
Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = xW_i^T + b_i for each device i in [0, …, num_devices - 1]:

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 |  GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 |  GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 |  GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 |  GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

 GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                  | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                  | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+ -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                  | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                  | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                  +-----+-----+-----+-----+
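
The arithmetic can be checked with plain NumPy (a sketch of the math only, not the MAX API): splitting W and b into row blocks, computing each block's partial output, and concatenating the results reproduces the full y = xW^T + b.

import numpy as np

x = np.random.rand(2, 8)    # input: batch of 2, in_dim 8
W = np.random.rand(16, 8)   # full weight: (out_dim, in_dim)
b = np.random.rand(16)      # full bias: (out_dim,)

num_devices = 4
W_shards = np.split(W, num_devices, axis=0)  # row blocks W_0 .. W_3
b_shards = np.split(b, num_devices)          # b_0 .. b_3

# Per-device partial outputs y_i = x @ W_i.T + b_i, then
# concatenation (the Allgather step) recovers y = x @ W.T + b.
y_shards = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]
y = np.concatenate(y_shards, axis=1)
assert np.allclose(y, x @ W.T + b)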

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

in_dim = 256    # illustrative input dimension
out_dim = 1024  # illustrative output dimension
num_devices = 4
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The data type for both weights and bias.
  • devices (Sequence[DeviceRef]) – The target devices for computation. Weights remain on CPU until sharded and moved to device during computation.
  • tied_weight (Weight | None)

DistributedMLP

class max.nn.linear.DistributedMLP(*args, **kwargs)

A distributed multi-layer perceptron.

This class has the same state keys as the non-distributed MLP layer.

Parameters:

  • dtype – DType to use for the layer weights, which should match the input dtype.
  • quantization_encoding – Quantization encoding of the layer weights.
  • hidden_dim – The last dimension of the layer input.
  • feed_forward_length – Size of the dimension used to project the inputs.
  • linear_cls – Linear class to use to create the projection layers.
  • devices – Devices to run the MLP layer on; computation is distributed across all provided devices.
  • activation_function – Activation function to use. Options are:
    • “silu”
    • “gelu”
    • “gelu_tanh”
    • “relu”
    • “tanh”
    • “sigmoid”
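
A minimal construction sketch, assuming DistributedMLP forwards the same constructor arguments as MLP (its signature above is *args, **kwargs); the dimensions are illustrative:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.linear import DistributedMLP

# Weights are sharded across all listed devices.
distributed_mlp = DistributedMLP(
    dtype=DType.bfloat16,
    quantization_encoding=None,
    hidden_dim=4096,
    feed_forward_length=14336,
    devices=[DeviceRef.GPU(i) for i in range(4)],
)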

Float8Config

class max.nn.linear.Float8Config(input_scale, weight_scale, mlp_in_float8, attn_qkv_in_float8, embedding_output_dtype=None, quant_method=None)

Configures float8 quantization settings for a layer or model section.

attn_qkv_in_float8

attn_qkv_in_float8: set[int]

Set of layer indices with attention QKV projections in float8.

QKV projections are quantized all-or-nothing per layer: either all of {q,k,v,o}_proj are float8, or all are bfloat16.

embedding_output_dtype

embedding_output_dtype: DType | None = None

The data type of the output from the embedding layer.

input_scale

input_scale: Float8InputScaleSpec

Specification for input activation scaling.

is_dynamic

property is_dynamic: bool

Returns True if the input scale is dynamic.

is_static

property is_static: bool

Returns True if the input scale is static.

mlp_in_float8

mlp_in_float8: set[int]

Set of layer indices with MLPs in float8.

MLPs are quantized all-or-nothing per layer: either all of gate_proj, down_proj, and up_proj are float8, or all are bfloat16.

quant_method

quant_method: str | None = None

The quantization method used (e.g., “fbgemm_fp8”).

weight_scale

weight_scale: Float8WeightScaleSpec

Specification for weight scaling.
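
Putting the pieces together, a hedged example of a config for a model whose first three layers run their MLPs and QKV projections in float8, with dynamic row-wise input scales and row-wise weight scales (all values are illustrative):

from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

config = Float8Config(
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        origin=Float8ScaleOrigin.DYNAMIC,
        dtype=DType.bfloat16,
    ),
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        dtype=DType.float32,
    ),
    mlp_in_float8={0, 1, 2},       # layers 0-2 run MLPs in float8
    attn_qkv_in_float8={0, 1, 2},  # same layers quantize QKV projections
    quant_method="fbgemm_fp8",
)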

Float8InputScaleSpec

class max.nn.linear.Float8InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None)

Specifies how input activations are scaled for float8 quantization.

activation_scale_ub

activation_scale_ub: float | None = None

An optional upper bound for dynamic activation scaling.

dtype

dtype: DType

The data type of the input scale factor(s).

granularity

granularity: Float8ScaleGranularity

The granularity of the input scale factor application.

origin

origin: Float8ScaleOrigin

The origin (static or dynamic) of the input scale factor.

Float8ScaleGranularity

class max.nn.linear.Float8ScaleGranularity(value)

Specifies the granularity of the quantization scale factor.

Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.

BLOCK

BLOCK = 'block'

COLWISE

COLWISE = 'colwise'

ROWWISE

ROWWISE = 'rowwise'

TENSOR

TENSOR = 'tensor'

Float8ScaleOrigin

class max.nn.linear.Float8ScaleOrigin(value)

Specifies whether the quantization scale is determined statically or dynamically.

STATIC scales are pre-computed and loaded with the model weights. DYNAMIC scales are computed at runtime based on the input data.

DYNAMIC

DYNAMIC = 'dynamic'

STATIC

STATIC = 'static'
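
As an illustration, the two origins as they would appear in a Float8InputScaleSpec (field values here are made up for the example):

from max.dtype import DType
from max.nn.linear import (
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
)

# STATIC: the scale is pre-computed and ships with the checkpoint.
static_spec = Float8InputScaleSpec(
    granularity=Float8ScaleGranularity.TENSOR,
    origin=Float8ScaleOrigin.STATIC,
    dtype=DType.float32,
)

# DYNAMIC: the scale is computed from each input at runtime,
# optionally clamped from above by activation_scale_ub.
dynamic_spec = Float8InputScaleSpec(
    granularity=Float8ScaleGranularity.ROWWISE,
    origin=Float8ScaleOrigin.DYNAMIC,
    dtype=DType.bfloat16,
    activation_scale_ub=1200.0,
)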

Float8WeightScaleSpec

class max.nn.linear.Float8WeightScaleSpec(granularity, dtype)

Specifies how weights are scaled for float8 quantization.

dtype

dtype: DType

The data type of the weight scale factor(s).

granularity

granularity: Float8ScaleGranularity

The granularity of the weight scale factor application.

is_block

property is_block: bool

Whether the weight scale granularity is block-wise.

is_colwise

property is_colwise: bool

Whether the weight scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

Whether the weight scale granularity is row-wise.

is_tensor

property is_tensor: bool

Whether the weight scale granularity is per-tensor.
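
A small sketch of the granularity helpers (values illustrative):

from max.dtype import DType
from max.nn.linear import Float8ScaleGranularity, Float8WeightScaleSpec

spec = Float8WeightScaleSpec(
    granularity=Float8ScaleGranularity.ROWWISE,
    dtype=DType.float32,
)
assert spec.is_rowwise
assert not spec.is_tensor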

GPTQLinear

class max.nn.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)

A Linear layer for GPTQ encoding.

Initializes the linear layer with weights and optional bias with GPTQ quantization.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The data type for both weights and bias.
  • device (DeviceRef) – The target device for computation. Weights remain on CPU until moved during computation.
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – The quantization encoding of the weights.
  • quantization_config (QuantizationConfig | None) – Extra config for the weight quantization.
  • float8_config (Float8Config | None)
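
A construction sketch; the encoding and config types are assumed to come from max.graph.quantization, and the 4-bit, group-size-128 settings are illustrative:

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationConfig, QuantizationEncoding
from max.nn.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,
    out_dim=4096,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(),
    quantization_encoding=QuantizationEncoding.GPTQ,
    quantization_config=QuantizationConfig(
        quant_method="gptq",
        bits=4,          # 4-bit quantized weights (illustrative)
        group_size=128,  # scale shared per group of 128 weights (illustrative)
    ),
)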

GPTQLinearV1

class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization_encoding=None, quantization_config=None, perm_idx=None)

A Linear layer for GPTQ encoding.

perm_idx

perm_idx: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None

quantization_config

quantization_config: QuantizationConfig | None = None

Linear

class max.nn.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix and a bias vector is optionally added. Both weights and bias initially reside on CPU, and the model init phase moves them to device.

Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# Apply the layer to a graph tensor value.
input_tensor: TensorValue
output = linear_layer(input_tensor)

Initializes the linear layer with weights and optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The data type for both weights and bias.
  • device (DeviceRef) – The target device for computation. Weights remain on CPU until moved during computation.
  • name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None)
  • float8_config (Float8Config | None)
  • clip_weight (float | None)

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to device if present.

set_sharding()

set_sharding(strategy)

Sets the weight sharding for this linear layer.

Parameters:

strategy (ShardingStrategy) – The strategy describing the weight sharding.

Return type:

None
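
A usage sketch, assuming a row-wise helper on max.nn's ShardingStrategy as used by other distributed MAX layers; linear_layer is the Linear instance from the example above:

from max.nn import ShardingStrategy

# Shard this layer's weight row-wise across four devices
# before building the distributed graph.
linear_layer.set_sharding(ShardingStrategy.rowwise(4))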

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to device if present.

LinearV1

class max.nn.linear.LinearV1(weight, bias=None)

A unified linear layer that delegates to either regular or quantized implementation.

Deprecated: Use Linear instead.

bias

bias: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None

create()

classmethod create(dtype, quantization_encoding, in_features, out_features, weights, bias=None, quantization_config=None)

Factory method to create a Linear layer with appropriate implementation.

Return type:

LinearV1

weight

weight: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

MLP

class max.nn.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=Linear, has_bias=False, activation_function='silu', float8_config=None)

Simple multi-layer perceptron composed of three linear layers. Defaults to SiLU activation function.

Parameters:

  • dtype (DType) – DType to use for the layer weights, which should match the input dtype.
  • quantization_encoding (QuantizationEncoding | None) – Quantization encoding of the layer weights.
  • hidden_dim (int) – The last dimension of the layer input.
  • feed_forward_length (int) – Size of the dimension used to project the inputs.
  • linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.
  • devices (Sequence[DeviceRef]) – Devices to run the MLP layer on. If multiple devices are provided, only the first is used; use DistributedMLP to run on all devices.
  • activation_function (str) – Activation function to use. Options are:
    • “silu”
    • “gelu”
    • “gelu_tanh”
    • “relu”
    • “tanh”
    • “sigmoid”
  • has_bias (bool)
  • float8_config (Float8Config | None)
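
A single-device construction sketch with illustrative dimensions:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,           # last dimension of the layer input
    feed_forward_length=14336, # projection dimension
    devices=[DeviceRef.GPU()],
    activation_function="silu",
)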

MLPV1

class max.nn.linear.MLPV1(gate_proj, down_proj, up_proj)

Simple multi-layer perceptron composed of three linear layers. Uses SiLU activation function.

Parameters:

down_proj

down_proj: LinearV1

gate_proj

gate_proj: LinearV1

up_proj

up_proj: LinearV1

QLinearV1

class max.nn.linear.QLinearV1(weight, bias=None, quantization_encoding=None)

A quantized fully connected layer.

quantization_encoding

quantization_encoding: QuantizationEncoding | None = None