Python class

Linear

class max.nn.Linear(in_dim, out_dim, *, bias=True)

A unary linear transformation over an input tensor.

Linear is defined as f(x) = x @ W.T + B where W is the weight tensor and B is an optional bias tensor.

If W is not square, the transformation changes the dimensionality of its input. By convention, the weight tensor is stored transposed: it has shape [out_dim, in_dim], which is why the formula multiplies by W.T.
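To make the transposed-storage convention concrete, here is a minimal plain-Python reference sketch of f(x) = x @ W.T + B for a single 1-D input. It is not MAX code: `linear_reference` is a hypothetical helper that treats `weight` as `out_dim` rows of length `in_dim` and accepts either one bias value per output feature or the literal 0, mirroring the documented `Tensor | Literal[0]` bias type.

```python
from typing import Sequence, Union

# Either one bias value per output feature, or the literal 0 (no bias).
Bias = Union[Sequence[float], int]


def linear_reference(
    x: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Bias = 0,
) -> list[float]:
    """Hypothetical reference for f(x) = x @ W.T + B on a 1-D input.

    `weight` is stored transposed: it has out_dim rows, each of length
    in_dim, so output j is the dot product of row j with x.
    """
    in_dim = len(x)
    out = []
    for j, row in enumerate(weight):
        if len(row) != in_dim:
            raise ValueError(f"weight row {j} has length {len(row)}, expected {in_dim}")
        acc = sum(w * xi for w, xi in zip(row, x))
        # A literal 0 bias means "no bias"; otherwise add the j-th entry.
        bias_j = 0.0 if isinstance(bias, int) else bias[j]
        out.append(acc + bias_j)
    return out


x = [1.0, 2.0, 3.0]                          # in_dim = 3
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # shape [out_dim=2, in_dim=3]
print(linear_reference(x, weight, [0.5, -1.0]))  # [1.5, 4.0]
print(linear_reference(x, weight))               # [1.0, 5.0]
```

Because each row of `weight` corresponds to one output feature, a 2 x 3 weight maps a length-3 input to a length-2 output.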

from max.nn import Linear
from max.tensor import Tensor

model = Linear(5, 10)

assert dict(model.parameters) == {
    "weight": model.weight, "bias": model.bias
}

result = model(Tensor.ones([5]))
assert result.shape == [10]

Constructs a linear transformation of the given dimensions with randomly initialized weights.

Parameters:

  • in_dim (DimLike) – The dimensionality of the input to the transformation.
  • out_dim (DimLike) – The dimensionality of the output, that is, the size of the last dimension after the transformation is applied to an input whose last dimension is in_dim.
  • bias (bool) – Whether to use a bias in the transformation. If False, the bias attribute is the literal 0.
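The constructor parameters determine the shapes of the learnable parameters. The sketch below is a hypothetical plain-Python stand-in, not the MAX constructor: `init_linear_params` and `num_parameters` are illustrative names, and the uniform weight initialization and zero-initialized bias are assumptions, since this page does not specify MAX's initialization scheme. The shapes follow from the formula: weight is [out_dim, in_dim] and bias has one entry per output feature.

```python
import random
from typing import Union


def init_linear_params(in_dim: int, out_dim: int, bias: bool = True, seed: int = 0):
    """Hypothetical stand-in for Linear's constructor, using plain lists.

    Returns (weight, bias): weight has shape [out_dim, in_dim]; bias is a
    list of out_dim values, or the literal 0 when bias is disabled.
    """
    rng = random.Random(seed)
    # Assumption: uniform init in [-1, 1]; MAX's actual scheme may differ.
    weight = [[rng.uniform(-1.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]
    # Assumption: bias starts at zero when enabled.
    b: Union[list, int] = [0.0] * out_dim if bias else 0
    return weight, b


def num_parameters(in_dim: int, out_dim: int, bias: bool = True) -> int:
    """Count learnable scalars: in_dim * out_dim weights plus out_dim biases."""
    return in_dim * out_dim + (out_dim if bias else 0)


weight, b = init_linear_params(5, 10)
print(len(weight), len(weight[0]))          # 10 5
print(num_parameters(5, 10))                # 60
print(num_parameters(5, 10, bias=False))    # 50
```

For the Linear(5, 10) example on this page, that gives a 10 x 5 weight (50 scalars) plus 10 bias values, or 50 scalars in total with bias disabled.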

bias

bias: Tensor | Literal[0]

The bias Tensor for the linear transformation (or 0 if bias is disabled).

forward()

forward(x)

Applies a linear transformation to the input tensor.

Linear is defined as f(x) = x @ W.T + B where W is the weight tensor and B is an optional bias tensor.

Parameters:

x (Tensor) – The input tensor. Its last dimension must equal in_dim.

Returns:

The result of applying the linear transformation to x. Its last dimension is out_dim.

Return type:

Tensor
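The shape contract of forward can be summarized with a small hypothetical helper, `linear_output_shape`, written in plain Python rather than MAX. It follows from x @ W.T with W of shape [out_dim, in_dim]: the last axis of x must equal in_dim and is replaced by out_dim. Treating leading (batch) axes as pass-through is an assumption based on ordinary matrix-multiplication semantics; the example on this page only shows a 1-D input.

```python
from typing import Sequence


def linear_output_shape(input_shape: Sequence[int], in_dim: int, out_dim: int) -> list[int]:
    """Hypothetical helper: output shape of x @ W.T + B for x of shape [..., in_dim].

    The last axis is contracted against W (shape [out_dim, in_dim]) and
    replaced by out_dim; leading axes are assumed to pass through unchanged.
    """
    if not input_shape:
        raise ValueError("input must have at least one dimension")
    if input_shape[-1] != in_dim:
        raise ValueError(f"last dimension {input_shape[-1]} does not match in_dim={in_dim}")
    return [*input_shape[:-1], out_dim]


print(linear_output_shape([5], 5, 10))      # [10], matching the example above
print(linear_output_shape([32, 5], 5, 10))  # [32, 10], assuming batch axes pass through
```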

in_dim

property in_dim: Dim

The input dimension for the transformation.

out_dim

property out_dim: Dim

The output dimension for the transformation.

weight

weight: Tensor

The weight Tensor for the linear transformation.
