
Python class

Linear

class max.nn.Linear(in_dim, out_dim, dtype, device, has_bias=False, quantization_encoding=None, quant_config=None, name=None, clip_weight=None, is_sharding=False)

Bases: Module, Shardable

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer: inputs are multiplied by a weight matrix, and a bias vector is optionally added. When called, Linear accepts a TensorValue of shape (..., in_dim) and returns a TensorValue of shape (..., out_dim).

Both weights and bias initially reside on CPU and are moved to the specified device during model initialization.

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# Apply the layer to a tensor of shape (..., 256) inside a graph.
input_tensor: TensorValue
output = linear_layer(input_tensor)  # shape (..., 128)

Initializes the linear layer with weights and optional bias.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • device (DeviceRef) – The target DeviceRef for computation. Weights are created on CPU and moved to this device during model initialization.
  • name (str | None) – Base name for weights (appended with .weight and .bias if applicable).
  • has_bias (bool) – When True, adds a bias vector to the layer. Defaults to False.
  • quantization_encoding (QuantizationEncoding | None) – QuantizationEncoding for the weights.
  • quant_config (QuantConfig | None) – QuantConfig for scaled quantization.
  • clip_weight (float | None) – Optional weight clipping threshold.
  • is_sharding (bool) – When True, disables child layer creation during sharding.
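
As a quick illustration of these parameters, a bias-free layer with weight clipping might be constructed as in the sketch below. The layer name, dimensions, and clipping threshold are arbitrary choices for the example.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import Linear

# No bias (has_bias defaults to False); clip_weight sets a clipping
# threshold for the weights (the value 6.0 here is illustrative).
projection = Linear(
    in_dim=512,
    out_dim=512,
    dtype=DType.bfloat16,
    device=DeviceRef.GPU(),
    name="projection",
    clip_weight=6.0,
)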

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to the target device if present.

device

device: DeviceRef

The device where matrix operations are performed.

input_scale

input_scale: Weight | None = None

The optional input scale stored on CPU with shape (). Model init moves the input_scale to the target device if present.

shard()

shard(devices)

Creates sharded views of this Linear layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of DeviceRef devices to place the shards on.

Returns:

List of sharded Linear instances, one for each device.

Return type:

list[Linear]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the weight sharding strategy.
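
A minimal sharding sketch follows. It assumes a two-GPU setup and that a row-wise strategy constructor such as ShardingStrategy.rowwise is available from max.nn; the strategy choice and device IDs are illustrative, not the only valid configuration.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import Linear, ShardingStrategy

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(0),
    name="linear",
)

# Assumption: ShardingStrategy.rowwise(n) describes splitting the weight
# across n devices. Set the strategy before creating the shards.
linear_layer.sharding_strategy = ShardingStrategy.rowwise(2)
shards = linear_layer.shard([DeviceRef.GPU(0), DeviceRef.GPU(1)])
# shards is a list[Linear] with one sharded view per device.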

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to the target device.

weight_scale

weight_scale: Weight | None = None

The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight_scale to the target device if present.

weight_scale_2

weight_scale_2: Weight | None = None

The optional weight scale 2 used for fp4 quantization.