Python module
max.nn.linear
Multi-layer Perceptron.
ColumnParallelLinear
class max.nn.linear.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes, for each device i in [0, …, num_devices - 1]:
+-----+     +-----+ T   +-----+     +-----+
|     |     | W_0 |     | b_0 |     | y_0 |  GPU0
|     |     +-----+     +-----+     +-----+
|     |     | W_1 |     | b_1 |     | y_1 |  GPU1
|  x  |  @  +-----+  +  +-----+  =  +-----+
|     |     | W_2 |     | b_2 |     | y_2 |  GPU2
|     |     +-----+     +-----+     +-----+
|     |     | W_3 |     | b_3 |     | y_3 |  GPU3
+-----+     +-----+     +-----+     +-----+
The values are then collected using an Allgather op, producing the same output tensor on each device:
 GPU0  GPU1  GPU2  GPU3                       GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+ -- Allgather -->  +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim, out_dim = 256, 1024  # example dimensions (not part of the original snippet)
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)
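To make the per-device arithmetic and the allgather concrete, here is a minimal NumPy sketch of the same computation (illustrative only; devices are simulated with plain arrays, and the dimensions are made up):

import numpy as np

num_devices, in_dim, out_dim = 4, 8, 16
x = np.random.randn(2, in_dim).astype(np.float32)
W = np.random.randn(out_dim, in_dim).astype(np.float32)
b = np.random.randn(out_dim).astype(np.float32)

# Shard W and b along the output dimension, one shard per device.
W_shards = np.split(W, num_devices, axis=0)
b_shards = np.split(b, num_devices, axis=0)

# Each device i computes its own slice: y_i = x @ W_i.T + b_i.
partials = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]

# Allgather: concatenating the slices reproduces the full output on every device.
y = np.concatenate(partials, axis=-1)
assert np.allclose(y, x @ W.T + b, atol=1e-5)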
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The data type for both weights and bias.
- devices (Sequence[DeviceRef]) – The target devices for computation. Weights remain on CPU until sharded and moved to device during computation.
- tied_weight (Weight | None)
DistributedMLP
class max.nn.linear.DistributedMLP(*args, **kwargs)
A distributed multi-layer perceptron.
This class has the same state keys as the non-distributed MLP Layer.
Parameters:

- dtype – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding – Quantization encoding of the layer weights.
- hidden_dim – The last dimension of the layer input.
- feed_forward_length – Size of dimension used to project the inputs.
- linear_cls – Linear class to use to create the projection layers.
- devices – Devices to run the MLP layer on; unlike MLP, all provided devices are used.
- activation_function – Activation function to use. Options are:
  - “silu”
  - “gelu”
  - “gelu_tanh”
  - “relu”
  - “tanh”
  - “sigmoid”
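Because the constructor forwards *args and **kwargs to the underlying MLP, construction presumably mirrors MLP with multiple devices. A hedged sketch (dimension values are illustrative):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.linear import DistributedMLP

distributed_mlp = DistributedMLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,            # illustrative model width
    feed_forward_length=14336,  # illustrative projection size
    devices=[DeviceRef.GPU(i) for i in range(4)],
)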
Float8Config
class max.nn.linear.Float8Config(input_scale, weight_scale, mlp_in_float8, attn_qkv_in_float8, embedding_output_dtype=None, quant_method=None)
Configures float8 quantization settings for a layer or model section.
Parameters:

- input_scale (Float8InputScaleSpec)
- weight_scale (Float8WeightScaleSpec)
- mlp_in_float8 (set[int])
- attn_qkv_in_float8 (set[int])
- embedding_output_dtype (DType | None)
- quant_method (str | None)
attn_qkv_in_float8
Set of layer indices with attention QKV projections in float8.
QKV projections are considered to be either “all quantized” or all not quantized per layer. So either all of {q,k,v,o}_proj are float8, or all bfloat16.
embedding_output_dtype
The data type of the output from the embedding layer.
input_scale
input_scale: Float8InputScaleSpec
Specification for input activation scaling.
is_dynamic
property is_dynamic: bool
Returns true if this input scale is dynamic.
is_static
property is_static: bool
Returns true if this input scale is static.
mlp_in_float8
Set of layer indices with MLPs in float8.
MLPs are considered to be either “all quantized” or all not quantized per layer. So either all of gate proj, down proj, and up proj are float8, or all bfloat16.
quant_method
The quantization method used (e.g., “fbgemm_fp8”).
weight_scale
weight_scale: Float8WeightScaleSpec
Specification for weight scaling.
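Putting the pieces together, a hedged construction sketch using the spec classes and enums documented below; the dtype and granularity choices are illustrative assumptions, not recommendations from this page:

from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

# Quantize the MLPs and attention QKV projections of layers 0-3 to float8.
config = Float8Config(
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        origin=Float8ScaleOrigin.DYNAMIC,  # scales computed at runtime
        dtype=DType.float32,               # illustrative scale dtype
    ),
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        dtype=DType.float32,
    ),
    mlp_in_float8={0, 1, 2, 3},
    attn_qkv_in_float8={0, 1, 2, 3},
    quant_method="fbgemm_fp8",
)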
Float8InputScaleSpec
class max.nn.linear.Float8InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None)
Specifies how input activations are scaled for float8 quantization.
Parameters:

- granularity (Float8ScaleGranularity)
- origin (Float8ScaleOrigin)
- dtype (DType)
- activation_scale_ub (float | None)
activation_scale_ub
An optional upper bound for dynamic activation scaling.
dtype
dtype: DType
The data type of the input scale factor(s).
granularity
granularity: Float8ScaleGranularity
The granularity of the input scale factor application.
origin
origin: Float8ScaleOrigin
The origin (static or dynamic) of the input scale factor.
Float8ScaleGranularity
class max.nn.linear.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies the granularity of the quantization scale factor.
Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.
BLOCK
BLOCK = 'block'
COLWISE
COLWISE = 'colwise'
ROWWISE
ROWWISE = 'rowwise'
TENSOR
TENSOR = 'tensor'
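As an illustration, for a weight matrix of shape (out_dim, in_dim) the four granularities imply scale factors of roughly these shapes (the exact shape conventions are an assumption, not stated on this page):

out_dim, in_dim, block = 16, 32, 8
scale_shapes = {
    "tensor": (),                                   # one scalar for the whole tensor
    "rowwise": (out_dim, 1),                        # one scale per output row
    "colwise": (1, in_dim),                         # one scale per input column
    "block": (out_dim // block, in_dim // block),   # one scale per block tile
}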
Float8ScaleOrigin
class max.nn.linear.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies whether the quantization scale is determined statically or dynamically.
STATIC scales are pre-computed and loaded with the model weights. DYNAMIC scales are computed at runtime based on the input data.
DYNAMIC
DYNAMIC = 'dynamic'
STATIC
STATIC = 'static'
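A minimal NumPy sketch of the distinction: a DYNAMIC per-tensor scale is computed from the data at runtime, whereas a STATIC scale would be loaded with the model weights. Float8 arithmetic is only simulated here:

import numpy as np

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value
x = np.random.randn(4, 8).astype(np.float32)

# DYNAMIC: derive the scale from the current activation tensor.
scale = np.abs(x).max() / FLOAT8_E4M3_MAX

# Quantize/dequantize round trip (simulated; real kernels also round to the float8 grid).
x_q = np.clip(x / scale, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX)
x_dq = x_q * scale  # approximate reconstruction of x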
Float8WeightScaleSpec
class max.nn.linear.Float8WeightScaleSpec(granularity, dtype)
Specifies how weights are scaled for float8 quantization.
Parameters:

- granularity (Float8ScaleGranularity)
- dtype (DType)
dtype
dtype: DType
The data type of the weight scale factor(s).
granularity
granularity: Float8ScaleGranularity
The granularity of the weight scale factor application.
is_block
property is_block: bool
Whether the weight scale granularity is block-wise.
is_colwise
property is_colwise: bool
Whether the weight scale granularity is column-wise.
is_rowwise
property is_rowwise: bool
Whether the weight scale granularity is row-wise.
is_tensor
property is_tensor: bool
Whether the weight scale granularity is per-tensor.
GPTQLinear
class max.nn.linear.GPTQLinear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quantization_config=None, float8_config=None)
A Linear layer for GPTQ encoding.
Initializes the linear layer with weights and optional bias with GPTQ quantization.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The data type for both weights and bias.
- device (DeviceRef) – The target device for computation. Weights remain on CPU until moved during computation.
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None) – The quantization encoding of the weights.
- quantization_config (QuantizationConfig | None) – Extra config for the weight quantization.
- float8_config (Float8Config | None)
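A hedged construction sketch; the QuantizationEncoding import path and its GPTQ member are assumptions not confirmed by this page:

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding  # assumed import path
from max.nn.linear import GPTQLinear

gptq_linear = GPTQLinear(
    in_dim=4096,
    out_dim=4096,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(0),
    quantization_encoding=QuantizationEncoding.GPTQ,  # assumed enum member
)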
GPTQLinearV1
class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization_encoding=None, quantization_config=None, perm_idx=None)
A Linear layer for GPTQ encoding.
Parameters:

- weight (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- bias (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None)
- quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- perm_idx (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None)
perm_idx
perm_idx: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None
quantization_config
quantization_config: QuantizationConfig | None = None
Linear
class max.nn.linear.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, float8_config=None, name=None, clip_weight=None)
Applies a linear transformation to incoming data: y = xWᵀ + b.
This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to device.
Example:
from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

input_tensor: TensorValue  # an existing graph tensor of shape (..., 256)
output = linear_layer(input_tensor)
Initializes the linear layer with weights and optional bias.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The data type for both weights and bias.
- device (DeviceRef) – The target device for computation. Weights remain on CPU until moved during computation.
- name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
- has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
- quantization_encoding (QuantizationEncoding | None)
- float8_config (Float8Config | None)
- clip_weight (float | None)
bias
The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.
device
device: DeviceRef
The device where matrix operations are performed.
input_scale
The optional input scale stored on CPU with shape (). Model init moves the input_scale to device if present.
set_sharding()
set_sharding(strategy)
Sets the weight sharding for this linear layer.

Parameters:

- strategy (ShardingStrategy) – The strategy describing the weight sharding.

Return type:
None
weight
weight: Weight
The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.
weight_scale
The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to device if present.
LinearV1
class max.nn.linear.LinearV1(weight, bias=None)
A unified linear layer that delegates to either regular or quantized implementation.
Deprecated: Use Linear instead.
Parameters:

- weight (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- bias (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None)

bias
bias: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None
create()
classmethod create(dtype, quantization_encoding, in_features, out_features, weights, bias=None, quantization_config=None)
Factory method to create a Linear layer with appropriate implementation.

Parameters:

- dtype (DType)
- quantization_encoding (QuantizationEncoding | None)
- in_features (int)
- out_features (int)
- weights (Weights | Weight)
- bias (Weights | Weight | None)
- quantization_config (QuantizationConfig | None)

Return type:
LinearV1
weight
weight: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
MLP
class max.nn.linear.MLP(dtype, quantization_encoding, hidden_dim, feed_forward_length, devices, linear_cls=<class 'max.nn.linear.Linear'>, has_bias=False, activation_function='silu', float8_config=None)
Simple multi-layer perceptron composed of three linear layers. Defaults to SiLU activation function.
Parameters:

- dtype (DType) – DType to use for the layer weights, which should match the input dtype.
- quantization_encoding (QuantizationEncoding | None) – Quantization encoding of the layer weights.
- hidden_dim (int) – The last dimension of the layer input.
- feed_forward_length (int) – Size of dimension used to project the inputs.
- linear_cls (Callable[..., Linear]) – Linear class to use to create the projection layers.
- devices (Sequence[DeviceRef]) – Devices to run the MLP layer. If multiple are provided, only the first device is used; use DistributedMLP to use all devices.
- activation_function (str) – Activation function to use. Options are:
  - “silu”
  - “gelu”
  - “gelu_tanh”
  - “relu”
  - “tanh”
  - “sigmoid”
- has_bias (bool)
- float8_config (Float8Config | None)
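A hedged construction sketch (dimension values are illustrative):

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.linear import MLP

mlp = MLP(
    dtype=DType.float32,
    quantization_encoding=None,
    hidden_dim=4096,            # illustrative model width
    feed_forward_length=14336,  # illustrative projection size
    devices=[DeviceRef.GPU(0)],
    activation_function="silu",
)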
MLPV1
class max.nn.linear.MLPV1(gate_proj, down_proj, up_proj)
Simple multi-layer perceptron composed of three linear layers. Uses SiLU activation function.
down_proj
down_proj: LinearV1
gate_proj
gate_proj: LinearV1
up_proj
up_proj: LinearV1
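Given the gate/up/down naming and the SiLU activation, the forward pass presumably follows the common SwiGLU pattern. A NumPy sketch of that assumed computation:

import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))

hidden_dim, ff_len = 8, 32
x = np.random.randn(2, hidden_dim).astype(np.float32)
gate_w = np.random.randn(ff_len, hidden_dim).astype(np.float32)
up_w = np.random.randn(ff_len, hidden_dim).astype(np.float32)
down_w = np.random.randn(hidden_dim, ff_len).astype(np.float32)

# Assumed composition (not confirmed by this page): down(silu(gate(x)) * up(x)).
out = (silu(x @ gate_w.T) * (x @ up_w.T)) @ down_w.T  # shape (2, hidden_dim)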
QLinearV1
class max.nn.linear.QLinearV1(weight, bias=None, quantization_encoding=None)
A quantized fully connected layer.
Parameters:

- weight
- bias
- quantization_encoding (QuantizationEncoding | None)

quantization_encoding
quantization_encoding: QuantizationEncoding | None = None