Linear
class max.nn.Linear(in_dim, out_dim, *, bias=True)
A unary linear transformation over an input tensor.
Linear is defined as f(x) = x @ W.T + B, where W is the weight tensor and B is an optional bias tensor.
If W is not square, the transformation changes the dimensionality of its input. By convention, the weight tensor is stored transposed, with shape [out_dim, in_dim], which is why the formula uses W.T.
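To make the formula and the transposed storage concrete, here is a minimal pure-Python sketch of f(x) = x @ W.T + B for a single input vector. It has no MAX dependency: the helper name `linear` and the nested-list representation are assumptions for illustration, not part of the MAX API. `weight` is laid out as out_dim rows of in_dim columns, so each output element is the dot product of x with one weight row.

```python
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]  # row-major: matrix[row][col]


def linear(x: Vector, weight: Matrix, bias: Optional[Vector] = None) -> Vector:
    """Hypothetical helper: compute f(x) = x @ W.T + B for one input vector.

    `weight` is stored transposed, with shape [out_dim][in_dim].
    """
    in_dim = len(x)
    if any(len(row) != in_dim for row in weight):
        raise ValueError("every weight row must have length in_dim")
    # (x @ W.T)[j] = sum_i x[i] * W.T[i][j] = sum_i x[i] * W[j][i]
    out = [sum(xi * wji for xi, wji in zip(x, row)) for row in weight]
    if bias is not None:
        if len(bias) != len(out):
            raise ValueError("bias length must equal out_dim")
        out = [o + b for o, b in zip(out, bias)]
    return out


# A 3 -> 2 transformation: weight has 2 rows (out_dim) of 3 columns (in_dim).
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
B = [0.5, -0.5]
print(linear([1.0, 2.0, 3.0], W, B))  # [7.5, -1.5]
```

Because W is not square here, the 3-element input becomes a 2-element output.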
```python
from max.nn import Linear
from max.tensor import Tensor

model = Linear(5, 10)
assert dict(model.parameters) == {
    "weight": model.weight, "bias": model.bias
}
result = model(Tensor.ones([5]))
assert result.shape == [10]
```
Constructs a random linear transformation of the given dimensions.
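Parameters:
- in_dim: The input dimension for the transformation.
- out_dim: The output dimension for the transformation.
- bias: Whether to include a bias term. Defaults to True; if False, the bias is 0.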
bias
The bias Tensor for the linear transformation (or 0 if the layer was constructed with bias=False).
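The constructor's documented behavior (random parameters of the given dimensions, and a bias of 0 when bias is disabled) can be sketched in plain Python as follows. `init_linear_params` is a hypothetical helper, not MAX code, and the uniform range of plus or minus 1/sqrt(in_dim) is an assumption borrowed from PyTorch's default for its own Linear layer; this page does not say how MAX draws its random values. The sketch is only meant to illustrate the shapes and the bias convention.

```python
import math
import random
from typing import List, Optional, Tuple, Union

Matrix = List[List[float]]


def init_linear_params(
    in_dim: int,
    out_dim: int,
    *,
    bias: bool = True,
    rng: Optional[random.Random] = None,
) -> Tuple[Matrix, Union[List[float], int]]:
    """Hypothetical helper: random parameters for an in_dim -> out_dim map.

    The weight is stored transposed (out_dim rows of in_dim columns).
    With bias disabled, the bias is the scalar 0.
    """
    rng = rng or random.Random()
    # Assumed initialization (PyTorch-style); MAX's actual scheme may differ.
    bound = 1.0 / math.sqrt(in_dim)
    weight = [[rng.uniform(-bound, bound) for _ in range(in_dim)]
              for _ in range(out_dim)]
    b: Union[List[float], int] = (
        [rng.uniform(-bound, bound) for _ in range(out_dim)] if bias else 0
    )
    return weight, b


weight, b = init_linear_params(5, 10, rng=random.Random(0))
print(len(weight), len(weight[0]), len(b))  # 10 5 10
_, no_bias = init_linear_params(5, 10, bias=False)
print(no_bias)  # 0
```

Using the scalar 0 for a disabled bias keeps the formula x @ W.T + B valid without a separate code path, since adding 0 leaves x @ W.T unchanged.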
forward()
forward(x)
Applies the linear transformation to the input tensor x, whose last dimension must equal in_dim; the last dimension of the result is out_dim.
Linear is defined as f(x) = x @ W.T + B where W is the weight tensor and B is an optional bias tensor.
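Inputs often carry a leading batch dimension. The sketch below assumes the matmul semantics of x @ W.T, in which each row of a [batch, in_dim] input maps independently to a row of length out_dim. `linear_batched` is a hypothetical pure-Python helper, not MAX code, and whether MAX's forward broadcasts over more than one leading dimension is not stated on this page.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear_batched(
    x: Sequence[Sequence[float]],
    weight: Matrix,
    bias: Sequence[float],
) -> Matrix:
    """Hypothetical helper: apply x @ W.T + B to each row of a [batch, in_dim] input.

    The batch dimension is kept; only the last dimension changes
    from in_dim to out_dim.
    """
    in_dim = len(weight[0])
    out = []
    for row in x:
        if len(row) != in_dim:
            raise ValueError(f"expected last dimension {in_dim}, got {len(row)}")
        # One output feature per weight row (the weight is stored transposed).
        out.append([sum(a * w for a, w in zip(row, w_row)) + b
                    for w_row, b in zip(weight, bias)])
    return out


# batch=3, in_dim=2, out_dim=4
x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
bias = [0.0, 0.0, 0.0, 0.5]
y = linear_batched(x, weight, bias)
print(len(y), len(y[0]))  # 3 4
```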
in_dim
property in_dim: Dim
The input dimension for the transformation.
out_dim
property out_dim: Dim
The output dimension for the transformation.
weight
weight: Tensor
The weight Tensor for the linear transformation.
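Because the weight is stored transposed, in_dim and out_dim correspond to its column and row counts respectively. The dataclass below is a hypothetical illustration of that relationship, not MAX's implementation: `LinearParams` is an invented name, and MAX reports these properties as Dim values, whereas this sketch returns plain ints.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class LinearParams:
    """Hypothetical container mirroring the documented attributes."""
    weight: List[List[float]]          # stored transposed: [out_dim][in_dim]
    bias: Union[List[float], int] = 0  # 0 when bias is disabled

    @property
    def in_dim(self) -> int:
        # Columns of the transposed weight correspond to input features.
        return len(self.weight[0])

    @property
    def out_dim(self) -> int:
        # Rows of the transposed weight correspond to output features.
        return len(self.weight)


p = LinearParams(weight=[[0.0] * 5 for _ in range(10)])
print(p.in_dim, p.out_dim)  # 5 10
```

These values match the Linear(5, 10) example: 5 input features and 10 output features.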