Python module
linear
Multi-layer Perceptron.
ColumnParallelLinear
class max.nn.linear.ColumnParallelLinear(*args, devices: Sequence[DeviceRef], **kwargs)
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes y_i = x @ W_i^T + b_i for each device i in [0, …, num_devices - 1]:
+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 |  GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 |  GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 |  GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 |  GPU3
+-----+       +-----+       +-----+       +-----+
The values are then collected using an Allgather op, producing the same output tensor on each device:
GPU0  GPU1  GPU2  GPU3                       GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                    | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                    | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather -->  +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                    | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                    | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                    +-----+-----+-----+-----+
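To make the sharding concrete, here is a minimal NumPy sketch of the same computation, with plain arrays standing in for device-resident tensors and illustrative shapes (an analogy only, not the MAX API):

import numpy as np

batch, in_dim, out_dim, num_devices = 2, 8, 16, 4
x = np.random.rand(batch, in_dim).astype(np.float32)
W = np.random.rand(out_dim, in_dim).astype(np.float32)
b = np.random.rand(out_dim).astype(np.float32)

# Shard W and b row-wise: each device owns out_dim // num_devices rows.
W_shards = np.split(W, num_devices)  # each (out_dim // 4, in_dim)
b_shards = np.split(b, num_devices)  # each (out_dim // 4,)

# Per-device partial results: y_i = x @ W_i^T + b_i.
y_shards = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]

# Allgather: concatenate so every device ends up with the full output.
y = np.concatenate(y_shards, axis=1)  # (batch, out_dim)
assert np.allclose(y, x @ W.T + b, atol=1e-5)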
Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256     # example dimensions
out_dim = 1024

distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)
Initializes the linear layer with weights and optional bias.
Parameters:
- in_dim – The dimensionality of the input space.
- out_dim – The dimensionality of the output space.
- dtype – The data type for both weights and bias.
- device – The target device for computation. Weights remain on CPU until moved during computation.
- name – Base name for weights (appended with .weight and .bias if applicable).
- has_bias – When True, adds a bias vector to the layer. Defaults to False.
DistributedMLP
class max.nn.linear.DistributedMLP(*args, **kwargs)
A distributed multi-layer perceptron.
This class has the same state keys as the non-distributed MLP Layer.
Parameters:
- dtype – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding – Quantization encoding of the layer weights.
- hidden_dim – The last dimension of the layer input.
- feed_forward_length – Size of the dimension used to project the inputs.
- linear_cls – Linear class to use to create the projection layers.
- devices – Devices to run the MLP layer on; the computation is distributed across all provided devices.
- activation_function – Activation function to use. Options are:
  - "silu"
  - "gelu"
  - "gelu_tanh"
  - "relu"
  - "tanh"
  - "sigmoid"
build_subgraph()
build_subgraph(name: str, x_type: list[max.graph.type.TensorType]) → Module
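A minimal construction sketch, assuming DistributedMLP is exported from max.nn and accepts the same constructor arguments as MLP (the dimensions are illustrative):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import DistributedMLP

mlp = DistributedMLP(
    dtype=DType.bfloat16,
    quantization_encoding=None,
    hidden_dim=4096,            # illustrative model width
    feed_forward_length=14336,  # illustrative FFN width
    devices=[DeviceRef.GPU(i) for i in range(4)],
)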
Float8Config
class max.nn.linear.Float8Config(input_scale: Float8InputScaleSpec, weight_scale: Float8WeightScaleSpec, attn_in_float8: bool, embedding_output_dtype: DType | None = None)
Configures float8 quantization settings for a layer or model section.
attn_in_float8
attn_in_float8: bool
Whether attention projection inputs are in float8.
embedding_output_dtype
The data type of the output from the embedding layer.
input_scale
input_scale: Float8InputScaleSpec
Specification for input activation scaling.
weight_scale
weight_scale: Float8WeightScaleSpec
Specification for weight scaling.
Float8InputScaleSpec
class max.nn.linear.Float8InputScaleSpec(granularity: Float8ScaleGranularity, origin: Float8ScaleOrigin, dtype: DType, activation_scale_ub: float | None = None)
Specifies how input activations are scaled for float8 quantization.
activation_scale_ub
An optional upper bound for dynamic activation scaling.
dtype
dtype: DType
The data type of the input scale factor(s). Must be provided if origin is Float8ScaleOrigin.STATIC.
granularity
granularity: Float8ScaleGranularity
The granularity of the input scale factor application.
origin
origin: Float8ScaleOrigin
The origin (static or dynamic) of the input scale factor.
Float8ScaleGranularity
class max.nn.linear.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies the granularity of the quantization scale factor.
Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.
BLOCK
BLOCK = 'block'
COLWISE
COLWISE = 'colwise'
ROWWISE
ROWWISE = 'rowwise'
TENSOR
TENSOR = 'tensor'
Float8ScaleOrigin
class max.nn.linear.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies whether the quantization scale is determined statically or dynamically.
STATIC scales are pre-computed and loaded with the model weights. DYNAMIC scales are computed at runtime based on the input data.
DYNAMIC
DYNAMIC = 'dynamic'
STATIC
STATIC = 'static'
Float8WeightScaleSpec
class max.nn.linear.Float8WeightScaleSpec(granularity: Float8ScaleGranularity, dtype: DType)
Specifies how weights are scaled for float8 quantization.
dtype
dtype: DType
The data type of the weight scale factor(s).
granularity
granularity: Float8ScaleGranularity
The granularity of the weight scale factor application.
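Putting the float8 pieces together, here is a sketch of a static, per-tensor input / row-wise weight configuration built from the classes above (the particular choices are illustrative, not a recommended recipe):

from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

# Static input scaling: scales are pre-computed and loaded with the
# model weights, so dtype must be provided.
input_scale = Float8InputScaleSpec(
    granularity=Float8ScaleGranularity.TENSOR,
    origin=Float8ScaleOrigin.STATIC,
    dtype=DType.float32,
)

# Row-wise weight scaling: one scale factor per weight row.
weight_scale = Float8WeightScaleSpec(
    granularity=Float8ScaleGranularity.ROWWISE,
    dtype=DType.float32,
)

config = Float8Config(
    input_scale=input_scale,
    weight_scale=weight_scale,
    attn_in_float8=False,
)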
GPTQLinear
class max.nn.linear.GPTQLinear(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, float8_config: Float8Config | None = None)
A Linear layer for GPTQ encoding.
Initializes the linear layer with weights and optional bias with GPTQ quantization.
Parameters:
- in_dim – The dimensionality of the input space.
- out_dim – The dimensionality of the output space.
- dtype – The data type for both weights and bias.
- device – The target device for computation. Weights remain on CPU until moved during computation.
- has_bias – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding – The quantization encoding of the weights.
- quantization_config – Extra config for the weight quantization.
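A construction sketch; it assumes QuantizationEncoding.GPTQ is a member of max.graph.quantization.QuantizationEncoding in your MAX version, and the dimensions are illustrative:

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding
from max.nn.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,
    out_dim=4096,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(0),
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed member
)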
GPTQLinearV1
class max.nn.linear.GPTQLinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)
A Linear layer for GPTQ encoding.
perm_idx
perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None
quantization_config
quantization_config: QuantizationConfig | None = None
Linear
class max.nn.linear.Linear(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, float8_config: Float8Config | None = None, name: str | None = None, clip_weight: float | None = None)
Applies a linear transformation to incoming data: y = xW^T + b.
This layer implements a fully connected layer where inputs are multiplied
by a weight matrix and optionally added with a bias vector.
Both weights and bias initially reside on CPU, and the model init phase moves them to device.
Example:
from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# Apply to an input tensor available in the surrounding graph:
input_tensor: TensorValue
output = linear_layer(input_tensor)
Initializes the linear layer with weights and optional bias.
Parameters:
- in_dim – The dimensionality of the input space.
- out_dim – The dimensionality of the output space.
- dtype – The data type for both weights and bias.
- device – The target device for computation. Weights remain on CPU until moved during computation.
- name – Base name for weights (appended with .weight and .bias if applicable).
- has_bias – When True, adds a bias vector to the layer. Defaults to False.
bias
The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.
device
device: DeviceRef
The device where matrix operations are performed.
input_scale
The optional input scale stored on CPU with shape (). Model init moves the input_scale to device if present.
weight
weight: Weight
The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.
weight_scale
The optional weight scale stored on CPU with shape (). Model init moves the weight_scale to device if present.
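As a quick shape check, with plain NumPy standing in for graph tensors: because weight is stored as (out_dim, in_dim) and transposed at init, the forward pass is equivalent to x @ weight.T + bias:

import numpy as np

batch, in_dim, out_dim = 4, 256, 128
x = np.random.rand(batch, in_dim).astype(np.float32)
weight = np.random.rand(out_dim, in_dim).astype(np.float32)
bias = np.random.rand(out_dim).astype(np.float32)

y = x @ weight.T + bias  # same math as the layer's forward pass
assert y.shape == (batch, out_dim)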
LinearV1
class max.nn.linear.LinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)
A unified linear layer that delegates to either regular or quantized implementation.
Deprecated: Use Linear instead.
bias
bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None
create()
classmethod create(dtype: DType, quantization_encoding: QuantizationEncoding | None, in_features: int, out_features: int, weights: Weights | Weight, bias: Weights | Weight | None = None, quantization_config: QuantizationConfig | None = None) → LinearV1
Factory method to create a Linear layer with appropriate implementation.
weight
weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
MLP
class max.nn.linear.MLP(dtype: DType, quantization_encoding: QuantizationEncoding | None, hidden_dim: int, feed_forward_length: int, devices: Sequence[DeviceRef], linear_cls: Callable[..., Linear] = Linear, has_bias: bool = False, activation_function: str = 'silu', float8_config: Float8Config | None = None)
Simple multi-layer perceptron composed of three linear layers. Defaults to the SiLU activation function.
Parameters:
- dtype – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding – Quantization encoding of the layer weights.
- hidden_dim – The last dimension of the layer input.
- feed_forward_length – Size of the dimension used to project the inputs.
- linear_cls – Linear class to use to create the projection layers.
- devices – Devices to run the MLP layer on. If multiple devices are provided, only the first is used; use DistributedMLP to run on all of them.
- activation_function – Activation function to use. Options are:
  - "silu"
  - "gelu"
  - "gelu_tanh"
  - "relu"
  - "tanh"
  - "sigmoid"
build_subgraph()
build_subgraph(name: str, x_type: TensorType) → Module
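A minimal construction sketch (dimensions illustrative). With the default "silu" activation, the three projections (gate, up, and down, as in MLPV1 below) implement the usual gated feed-forward of the form down_proj(silu(gate_proj(x)) * up_proj(x)):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import MLP

mlp = MLP(
    dtype=DType.bfloat16,
    quantization_encoding=None,
    hidden_dim=4096,            # illustrative model width
    feed_forward_length=14336,  # illustrative FFN width
    devices=[DeviceRef.GPU(0)],
)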
MLPV1
class max.nn.linear.MLPV1(gate_proj: LinearV1, down_proj: LinearV1, up_proj: LinearV1)
Simple multi-layer perceptron composed of three linear layers. Uses the SiLU activation function.
down_proj
down_proj: LinearV1
gate_proj
gate_proj: LinearV1
up_proj
up_proj: LinearV1
QLinearV1
class max.nn.linear.QLinearV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None)
A quantized fully connected layer.
quantization_encoding
quantization_encoding: QuantizationEncoding | None = None