
Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y = xW_i^T + b_i for each device i in [0, …, num_devices):

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256    # example dimensions
out_dim = 1024
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
  • tied_weight (Weight | None) – Optional Weight to tie with this layer.

DistributedGemmConfig

class max.nn.linear.DistributedGemmConfig(enable_matmul_allreduce)

Configure how distributed GEMM is executed.

Parameters:

  • enable_matmul_allreduce (bool) – If True, use the matmul + all_reduce kernel.

enable_matmul_allreduce

enable_matmul_allreduce: bool

If True, use the matmul + all_reduce kernel.

generate()

static generate()

Returns the default DistributedGemmConfig.

Returns:

A DistributedGemmConfig instance with default settings.

Return type:

DistributedGemmConfig | None
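
Example (a minimal sketch; since the documented return type is DistributedGemmConfig | None, the result is checked before use):

from max.nn.linear import DistributedGemmConfig

# Request the default distributed GEMM settings.
config = DistributedGemmConfig.generate()
if config is not None:
    # Whether the fused matmul + all_reduce kernel is enabled.
    print(config.enable_matmul_allreduce)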

Float8Config

class max.nn.linear.Float8Config(input_scale, weight_scale, mlp_in_float8, attn_qkv_in_float8, embedding_output_dtype=None, quant_method=None)

Configures float8 quantization settings for a layer or model section.

attn_qkv_in_float8

attn_qkv_in_float8: set[int]

Set of layer indices with attention QKV projections in float8.

QKV projections are treated as either all quantized or all unquantized within a layer: either all of {q,k,v,o}_proj are float8, or all are bfloat16.

embedding_output_dtype

embedding_output_dtype: DType | None = None

The DType of the output from the embedding layer.

input_scale

input_scale: Float8InputScaleSpec

Float8InputScaleSpec for input activation scaling.

is_dynamic

property is_dynamic: bool

Returns True if the input scale is dynamic.

is_static

property is_static: bool

Returns True if the input scale is static.

mlp_in_float8

mlp_in_float8: set[int]

Set of layer indices with MLPs in float8.

MLPs are treated as either all quantized or all unquantized within a layer: either all of gate_proj, down_proj, and up_proj are float8, or all are bfloat16.

quant_method

quant_method: str | None = None

The quantization method used (e.g., “fbgemm_fp8”).

weight_scale

weight_scale: Float8WeightScaleSpec

Float8WeightScaleSpec for weight scaling.

Float8InputScaleSpec

class max.nn.linear.Float8InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None)

Specifies how input activations are scaled for float8 quantization.

activation_scale_ub

activation_scale_ub: float | None = None

An optional upper bound for dynamic activation scaling.

dtype

dtype: DType

The DType of the input scale factor(s).

granularity

granularity: Float8ScaleGranularity

The Float8ScaleGranularity of the input scale factor application.

origin

origin: Float8ScaleOrigin

The Float8ScaleOrigin (static or dynamic) of the input scale factor.

Float8ScaleGranularity

class max.nn.linear.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the granularity of the quantization scale factor.

Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.

BLOCK

BLOCK = 'block'

Per-block scaling.

COLWISE

COLWISE = 'colwise'

Per-column scaling.

ROWWISE

ROWWISE = 'rowwise'

Per-row scaling.

TENSOR

TENSOR = 'tensor'

Per-tensor scaling.

Float8ScaleOrigin

class max.nn.linear.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies whether the quantization scale is determined statically or dynamically.

DYNAMIC

DYNAMIC = 'dynamic'

Scales are computed at runtime based on the input data.

STATIC

STATIC = 'static'

Scales are pre-computed and loaded with the model weights.

Float8WeightScaleSpec

class max.nn.linear.Float8WeightScaleSpec(granularity, dtype)

Specifies how weights are scaled for float8 quantization.

dtype

dtype: DType

The DType of the weight scale factor(s).

granularity

granularity: Float8ScaleGranularity

The Float8ScaleGranularity of the weight scale factor application.

is_block

property is_block: bool

Whether the weight scale granularity is block-wise.

is_colwise

property is_colwise: bool

Whether the weight scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

Whether the weight scale granularity is row-wise.

is_tensor

property is_tensor: bool

Whether the weight scale granularity is per-tensor.
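
Example usage (a minimal sketch combining the float8 classes above; the layer indices, dtypes, and scale granularities are illustrative, and the import paths follow the class names documented on this page):

from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

# Quantize the MLPs and attention QKV projections of layers 0-31 (illustrative).
float8_layers = set(range(32))

float8_config = Float8Config(
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        origin=Float8ScaleOrigin.STATIC,
        dtype=DType.float32,
    ),
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        dtype=DType.float32,
    ),
    mlp_in_float8=float8_layers,
    attn_qkv_in_float8=float8_layers,
    embedding_output_dtype=DType.bfloat16,
    quant_method="fbgemm_fp8",
)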

GPTQLinear

class max.nn.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)

A Linear layer for GPTQ encoding.

Initializes the linear layer with GPTQ-quantized weights and an optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding of the weights.
  • quantization_config (QuantizationConfig | None) – Extra QuantizationConfig for the weight quantization.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization (not supported).

GPTQLinearV1

class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization_encoding=None, quantization_config=None, perm_idx=None)

A Linear layer for GPTQ encoding.

Deprecated

Deprecated since version 25.5: Use GPTQLinear instead.

perm_idx

perm_idx: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer | floating | ndarray | None = None

Optional permutation indices for GPTQ quantization.

quantization_config

quantization_config: QuantizationConfig | None = None

The QuantizationConfig for GPTQ quantization.

Linear

class max.nn.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to the specified device.

Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# `input_tensor` is a TensorValue produced elsewhere in the graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)

Initializes the linear layer with weights and optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights remain on CPU until moved during computation.
  • name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.
  • clip_weight (float | None) – Optional weight clipping threshold.

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.

shard()

shard(devices)

Creates sharded views of this Linear layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of DeviceRef devices to place the shards on.

Returns:

List of sharded Linear instances, one for each device.

Return type:

list[Linear]
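
Example (a minimal sketch; linear_layer is the Linear instance from the example above, and depending on the model a sharding strategy may need to be assigned to the sharding_strategy property described below before calling shard()):

from max.graph import DeviceRef

devices = [DeviceRef.GPU(i) for i in range(4)]
# Returns one sharded Linear view per device.
linear_shards = linear_layer.shard(devices)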

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the weight sharding strategy.

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.

LinearV1

class max.nn.linear.LinearV1(weight, bias=None)

A unified linear layer that delegates to either a regular or a quantized implementation.

Deprecated

Deprecated since version 25.5: Use Linear instead.

bias

bias: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer | floating | ndarray | None = None

Optional bias tensor for the linear transformation.

create()

classmethod create(dtype, quantization_encoding, in_features, out_features, weights, bias=None, quantization_config=None)

Factory method to create a LinearV1 layer with the appropriate implementation.

Parameters:

  • dtype (DType) – The DType for the layer.
  • quantization_encoding (QuantizationEncoding | None) – The QuantizationEncoding for the weights.
  • in_features (int) – The input feature dimension.
  • out_features (int) – The output feature dimension.
  • weights (Weights | Weight) – The Weights or Weight object for the layer.
  • bias (Weights | Weight | None) – Optional Weights or Weight object for bias.
  • quantization_config (QuantizationConfig | None) – Optional QuantizationConfig for quantization.

Returns:

A LinearV1 instance.

Return type:

LinearV1

weight

weight: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer | floating | ndarray

The weight tensor for the linear transformation.

MLP

class max.nn.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None, dist_gemm_config=None)

Simple multi-layer perceptron composed of three Linear layers. Defaults to the SiLU activation function.

Parameters:

  • dtype (DType) – DType to use for the layer weights, which should match the input dtype.

  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding of the layer weights.

  • hidden_dim (int) – The last dimension of the layer input.

  • feed_forward_length (int) – Size of dimension used to project the inputs.

  • linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.

  • devices (Sequence[DeviceRef]) – DeviceRef devices to run the MLP layer on. If multiple devices are provided, only the first is used; use DistributedMLP to run across all devices.

  • has_bias (bool) – Whether to include bias terms in the linear layers.

  • activation_function (str) –

    Activation function to use. Options are:

    • silu
    • gelu
    • gelu_tanh
    • relu
    • tanh
    • sigmoid
  • float8_config (Float8Config | None) – Float8Config for float8 quantization.

  • dist_gemm_config (DistributedGemmConfig | None) – DistributedGemmConfig for distributed GEMM configuration.
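
Example usage (a minimal sketch; the dimensions are illustrative and quantization is disabled):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.linear import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,
    feed_forward_length=14336,
    devices=[DeviceRef.GPU(0)],
    has_bias=False,
    activation_function="silu",
)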

shard()

shard(devices)

Creates sharded views of this MLP across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MLP instances, one for each device.

Return type:

list[MLP]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the MLP sharding strategy.

MLPV1

class max.nn.linear.MLPV1(gate_proj, down_proj, up_proj)

Simple multi-layer perceptron composed of three LinearV1 layers. Uses the SiLU activation function.

Deprecated

Deprecated since version 25.5: Use MLP instead.

down_proj

down_proj: LinearV1

The down projection LinearV1 layer.

gate_proj

gate_proj: LinearV1

The gate projection LinearV1 layer.

up_proj

up_proj: LinearV1

The up projection LinearV1 layer.

QLinearV1

class max.nn.linear.QLinearV1(weight, bias=None, quantization_encoding=None)

A quantized fully connected layer.

Deprecated

Deprecated since version 25.5: Use Linear instead.

quantization_encoding

quantization_encoding: QuantizationEncoding | None = None

The QuantizationEncoding for the quantized weights.
