Python module
module_v3
Module implementation using eager tensors.
Embedding
class max.nn.module_v3.Embedding(vocab_size, *, dim=None, dims=None)
A vector embedding.
An embedding can be thought of as a lookup table for vectors by index. Given an input tensor of indices into the embedding, the result of the embedding lookup is a tensor with the input's shape followed by the embedding's dimensions, where each index is replaced by the vector stored at that location in the embedding table.
The common case for embeddings is a 1-dimensional embedding:
```python
from max.dtype import DType
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding

embedding = Embedding(vocab_size=1000, dim=128)
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 128]
```
However, they support multi-dimensional embeddings just as easily:
```python
from max.dtype import DType
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding

embedding = Embedding(vocab_size=1000, dims=[16, 128])
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 16, 128]
```
Creates a randomly initialized embedding of the specified size.
Parameters:
- vocab_size (DimLike) – The number of elements in the lookup table. Indices outside the range [0, vocab_size) are illegal in the resulting embedding operation.
- dim (DimLike | None) – The embedding dimension, if there is exactly one. Equivalent to dims=[dim].
- dims (ShapeLike | None) – The shape of the vectors in the embedding, for specifying multi-dimensional embeddings.
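To make the lookup semantics concrete, here is a conceptual sketch in plain Python rather than MAX: a list of rows stands in for the weight table, and the hypothetical helper embedding_lookup replaces each index with its row, rejecting indices outside [0, vocab_size). It only illustrates the behavior described above and is not how Embedding is implemented.

```python
# Conceptual sketch only: plain Python lists stand in for MAX tensors,
# and `embedding_lookup` is a hypothetical helper, not part of max.nn.


def embedding_lookup(table, indices):
    """Replace each index in `indices` with the row stored at table[index]."""
    vocab_size = len(table)
    result = []
    for index in indices:
        # Indices must lie in [0, vocab_size).
        if not 0 <= index < vocab_size:
            raise IndexError(f"index {index} is outside [0, {vocab_size})")
        result.append(table[index])
    return result


# A toy table with vocab_size=4 and dim=3: row r holds [10r, 10r+1, 10r+2].
table = [[float(10 * row + col) for col in range(3)] for row in range(4)]

tokens = [1, 1, 3]  # input shape [3]
embedded = embedding_lookup(table, tokens)
# Output shape is the input shape followed by the embedding dims: [3, 3].
print(len(embedded), len(embedded[0]))  # 3 3
```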
dim
property dim: Dim
The dimension of the vectors in the embedding, for a 1-dimensional embedding.
Raises: If the embedding is 0-dimensional or has more than one dimension.
dims
The dimensions of the vectors in the embedding.
vocab_size
property vocab_size: Dim
The vocab size of the embedding.
Indices outside the range [0, vocab_size) are illegal.
weight
weight: Tensor
Linear
class max.nn.module_v3.Linear(in_dim, out_dim, *, bias=True)
A unary linear transformation over an input tensor.
Linear is defined as f(x) = x @ W.T + B, where W is the weight tensor and B is an optional bias tensor.
If W is not square, the transformation changes the dimensionality of its input. By convention, the weight tensor is stored transposed, with shape [out_dim, in_dim].
```python
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear

model = Linear(5, 10)
assert dict(model.parameters) == {
    "weight": model.weight, "bias": model.bias
}
result = model(Tensor.ones([5]))
assert result.shape == [10]
```
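As a plain-Python sanity check of the formula f(x) = x @ W.T + B and the transposed storage convention, the sketch below computes one output element per row of W, where W has shape [out_dim, in_dim]. The function name linear_forward is hypothetical, and lists stand in for tensors.

```python
# Conceptual sketch of f(x) = x @ W.T + b using plain Python lists.
# `linear_forward` is a hypothetical helper, not part of max.nn.


def linear_forward(x, weight, bias):
    """Apply x @ weight.T + bias, with weight stored as [out_dim, in_dim]."""
    in_dim = len(x)
    result = []
    for row, b in zip(weight, bias):
        # Each row of the stored weight produces one output element,
        # which is why the weight is kept transposed.
        assert len(row) == in_dim, "weight rows must have length in_dim"
        result.append(sum(xi * wi for xi, wi in zip(x, row)) + b)
    return result


# in_dim=2, out_dim=3, so the stored weight has shape [3, 2].
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.5, 0.5, 0.5]
y = linear_forward([2.0, 3.0], weight, bias)
print(y)  # [2.5, 3.5, 5.5]
```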
Constructs a random linear transformation of the given dimensions.
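Parameters:
- in_dim – The input dimension of the transformation.
- out_dim – The output dimension of the transformation.
- bias (bool) – Whether to include a bias tensor. Defaults to True.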
bias
bias: Tensor | Literal[0]
in_dim
property in_dim: Dim
The input dimension for the transformation.
out_dim
property out_dim: Dim
The output dimension for the transformation.
weight
weight: Tensor
Module
class max.nn.module_v3.Module
The core unit of composition for modeling in MAX.
Informally, a Module is a container class. It can contain other Module instances, Tensors (the Module’s “local parameters”) or other arbitrary Python data.
A Module also has a __call__ which applies that Module to some input. In the simplest case this is a function from one Tensor to another Tensor.
Formally, Modules form a tree, and subtrees of Modules can be manipulated directly. A Module may also be thought of as a closure, where the parameters form the data of the closure and __call__ is the application of the closure.
Terms:
- A “child” of a Module is a sub-Module stored directly on that Module.
- A “descendent” of a Module is one of its children, or one of their descendents.
- A “parameter” is a Tensor storing data on the Module or one of its descendents.
- The “qualified path” of a descendent is a period-separated string of the names of the child module attributes which lead to that descendent module, for instance child.sub.last.
- The “qualified path” of a parameter is the qualified path of the descendent directly holding that parameter, followed by a final path component for the attribute name of the tensor. For instance weight for a local parameter, or child.sub.last.weight for a descendent’s parameter.
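The naming rules above can be illustrated with a toy tree. The sketch below is plain Python, not the MAX implementation: the hypothetical Node class stands in for Module, Param stands in for Tensor, and named_parameters walks the tree to produce qualified paths, skipping arbitrary non-parameter data.

```python
# Conceptual sketch of qualified paths in a module tree.
# `Param`, `Node`, and `named_parameters` are hypothetical, not MAX APIs.


class Param:
    """Stand-in for a Tensor parameter."""

    def __init__(self, value):
        self.value = value


class Node:
    """Stand-in for a Module: attributes may be Params, Nodes, or other data."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def named_parameters(node, prefix=""):
    """Yield (qualified_path, param) for every Param in the tree."""
    for name, value in vars(node).items():
        if isinstance(value, Node):
            # Descend into the child, extending the path with its attribute name.
            yield from named_parameters(value, f"{prefix}{name}.")
        elif isinstance(value, Param):
            yield f"{prefix}{name}", value
        # Any other attribute is ordinary Python data and is skipped.


root = Node(
    weight=Param(1.0),
    child=Node(sub=Node(last=Node(weight=Param(2.0)))),
    note="arbitrary Python data is ignored",
)
paths = [name for name, _ in named_parameters(root)]
print(paths)  # ['weight', 'child.sub.last.weight']
```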
```python
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor | int = 0

    def __call__(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

linear = Linear(Tensor.zeros([5, 4]))
print(linear)
print(linear(Tensor.constant([1, 2, 3, 4])))
```
apply_to_local_parameters()
apply_to_local_parameters(f)
Applies a transformation to each local parameter tensor on the Module.
The transformation is applied in place, updating the module’s values. It is not applied to the parameters of descendents.
```python
from max.driver import Accelerator
from max.nn.module_v3 import Linear

model = Linear(2, 3)
# Only parameters stored directly on `model` are moved.
model.apply_to_local_parameters(lambda _, t: t.to(Accelerator()))
```
Parameters:
f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each local parameter. The transformation takes two arguments, a name and a tensor:
- The name is the attribute name of the parameter on the module.
- The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name.
Return type:
None
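The local-only, in-place contract can be modeled in a few lines of plain Python. In this sketch, Param, Node, and apply_to_local are hypothetical stand-ins rather than MAX classes; the point is that parameters on child nodes are never visited.

```python
# Conceptual sketch of in-place, local-only parameter updates.
# `Param`, `Node`, and `apply_to_local` are hypothetical stand-ins, not MAX APIs.


class Param:
    """Stand-in for a Tensor parameter."""

    def __init__(self, value):
        self.value = value


class Node:
    """Stand-in for a Module: attributes may be Params, Nodes, or other data."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def apply_to_local(node, f):
    """Replace each Param stored directly on `node` with f(name, param)."""
    for name, value in list(vars(node).items()):
        if isinstance(value, Param):
            setattr(node, name, f(name, value))
    # Child Nodes are not visited, so their Params are left unchanged.


model = Node(weight=Param(1.0), child=Node(weight=Param(2.0)))
apply_to_local(model, lambda name, p: Param(p.value * 10))
print(model.weight.value, model.child.weight.value)  # 10.0 2.0
```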
apply_to_parameters()
apply_to_parameters(f)
Applies a transformation to each parameter tensor on the Module and its descendents.
The transformation is applied in-place, updating the module’s values and those of its descendents.
```python
from max.driver import Accelerator
from max.nn.module_v3 import Linear

model = Linear(2, 3)
# Moves every parameter in the module tree to the accelerator, in place.
model.apply_to_parameters(lambda _, t: t.to(Accelerator()))
```
Parameters:
f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each parameter. The transformation takes two arguments, a name and a tensor:
- The name is the qualified name of the parameter with respect to the module on which apply_to_parameters was called.
- The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name in the module tree.
Return type:
None
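apply_to_parameters differs from apply_to_local_parameters in two ways: it recurses into descendents, and the names passed to f are qualified paths. A plain-Python sketch of that behavior, using hypothetical Param and Node stand-ins rather than MAX classes:

```python
# Conceptual sketch of in-place updates over a whole module tree.
# `Param`, `Node`, and `apply_to_all` are hypothetical stand-ins, not MAX APIs.


class Param:
    """Stand-in for a Tensor parameter."""

    def __init__(self, value):
        self.value = value


class Node:
    """Stand-in for a Module."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def apply_to_all(node, f, prefix=""):
    """Replace every Param in the tree with f(qualified_name, param), in place."""
    for name, value in list(vars(node).items()):
        qualified = prefix + name
        if isinstance(value, Node):
            apply_to_all(value, f, qualified + ".")
        elif isinstance(value, Param):
            setattr(node, name, f(qualified, value))


seen = []


def record_and_scale(name, param):
    seen.append(name)  # record the qualified name passed to f
    return Param(param.value * 10)


model = Node(weight=Param(1.0), child=Node(weight=Param(2.0)))
apply_to_all(model, record_and_scale)
print(seen)  # ['weight', 'child.weight']
print(model.weight.value, model.child.weight.value)  # 10.0 20.0
```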
children
Iterates over the direct child modules of the Module.
Yields:
(name, module) pairs, where name is the attribute name of the child on the module.
compile()
compile(*input_types)
Compiles the module to a model operating on the given input types.
Example:
```python
from max.dtype import DType
from max.experimental.tensor import Tensor, TensorType, defaults
from max.nn.module_v3 import Linear

linear = Linear(2, 3)
_, device = defaults()
# "batch" is a symbolic dimension, so the compiled model accepts any batch size.
input_type = TensorType(DType.float32, ["batch", 2], device=device)
model = linear.compile(input_type)
print(model(Tensor.ones([3, 2], dtype=DType.float32)))
print(model(Tensor.ones([10, 2], dtype=DType.float32)))
```
descendents
Iterates over the Module’s descendent modules.
Yields:
(name, module) pairs, where name is the qualified path of the descendent with respect to the module.
load_state()
load_state(lookup)
Replaces each parameter in the module and its descendents.
The transformation is applied in-place, updating the module’s values and those of its descendents.
Example:
```python
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear

model = Linear(2, 3)
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
# dict.__getitem__ serves as a name -> tensor lookup function.
model.load_state(weights.__getitem__)
```
The lookup is defined as a function rather than a dictionary, which allows names to be remapped functionally during loading to account for differences in common weight naming and storage conventions.
For instance, certain representations may not store weights transposed, may need to be quantized or split out of a shared QKV block, or may simply use slightly different names or paths.
This can also be used, for instance, to provide default values when initializing LoRA weights.
Parameters:
lookup (Callable[[str], DLPackArray]) – The lookup function for each parameter:
- The argument to the lookup function is the qualified name of the parameter with respect to the module on which load_state was called.
- The return value of this function is the new value that will replace the value at that name in the module tree.
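Because lookup is an ordinary function, remapping logic can live in a closure. The sketch below is plain Python and does not call MAX: make_lookup is a hypothetical helper that adds a checkpoint prefix and renames attributes before indexing a dictionary, and nested lists stand in for DLPack arrays. A function of the same shape that returns real arrays is the kind of callable load_state expects.

```python
# Conceptual sketch of a name-remapping lookup function.
# `make_lookup` and the checkpoint names below are hypothetical; lists stand
# in for DLPack-compatible arrays.


def make_lookup(checkpoint, prefix="", renames=None):
    """Return a lookup(name) that maps module names to checkpoint entries."""
    renames = renames or {}

    def lookup(qualified_name):
        # Rename the final attribute (e.g. "weight" -> "kernel") if needed.
        *path, attr = qualified_name.split(".")
        attr = renames.get(attr, attr)
        key = prefix + ".".join([*path, attr])
        return checkpoint[key]  # a KeyError signals a missing weight

    return lookup


checkpoint = {
    "encoder.proj.kernel": [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
    "encoder.proj.bias": [0.0, 0.0, 0.0],
}
lookup = make_lookup(checkpoint, prefix="encoder.", renames={"weight": "kernel"})
print(lookup("proj.weight"))  # [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
```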
load_state_dict()
load_state_dict(state, strict=True)
Replaces each parameter in the module and its descendents.
The transformation is applied in-place, updating the module’s values and those of its descendents.
Example:
```python
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear

model = Linear(2, 3)
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
model.load_state_dict(weights)
```
Parameters:
- state (Mapping[str, DLPackArray]) – A mapping from qualified name to weight.
- strict (bool) – If true, verify that every value in state is loaded at least once.
Raises:
If strict is set (the default) and not all weights in state were loaded.
Return type:
None
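The strict check can be read as bookkeeping over which entries in state were consumed. The plain-Python sketch below models that reading with a hypothetical load_dict helper: it takes each expected name from the mapping, then, when strict is true, raises if any entry was never used. The exception type (ValueError here) and the KeyError on a missing name are assumptions of this sketch; the documentation above names neither.

```python
# Conceptual sketch of load_state_dict's strict check, in plain Python.
# `load_dict` is hypothetical; ValueError is an assumed exception type.


def load_dict(param_names, state, strict=True):
    """Return {name: state[name]} for each expected parameter name."""
    loaded = {}
    used = set()
    for name in param_names:
        loaded[name] = state[name]  # a missing name raises KeyError here
        used.add(name)
    if strict:
        unused = sorted(set(state) - used)
        if unused:
            raise ValueError(f"weights in state were not loaded: {unused}")
    return loaded


state = {"weight": [[0.0, 0.0]] * 3, "bias": [0.0] * 3, "extra": [1.0]}
print(sorted(load_dict(["weight", "bias"], state, strict=False)))  # ['bias', 'weight']

try:
    load_dict(["weight", "bias"], state)  # strict by default
except ValueError as err:
    print(err)  # weights in state were not loaded: ['extra']
```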
local_parameters
Iterates over the local parameters of the Module.
Yields:
(name, tensor) pairs, where name is the attribute name of the tensor on the module.
map_parameters()
map_parameters(f)
Creates a new Module with its parameters transformed by the function.
The transformation is functional rather than in-place. The module is deep-copied; its descendents are also replaced via the same transform without affecting the original module.
```python
from max.driver import Accelerator
from max.nn.module_v3 import Linear

model = Linear(2, 3)
# `model` is left unchanged; `model_on_gpu` is a new copy on the accelerator.
model_on_gpu = model.map_parameters(lambda _, t: t.to(Accelerator()))
```
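The key difference from the in-place apply_to_parameters is that the original tree is left untouched. A plain-Python sketch of that contract, using copy.deepcopy on hypothetical Node and Param stand-ins (not MAX classes):

```python
import copy

# Conceptual sketch of a functional parameter map over a module tree.
# `Param`, `Node`, and `map_params` are hypothetical stand-ins, not MAX APIs.


class Param:
    """Stand-in for a Tensor parameter."""

    def __init__(self, value):
        self.value = value


class Node:
    """Stand-in for a Module."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def map_params(node, f):
    """Return a deep copy of `node` with every Param replaced by f(name, param)."""
    new = copy.deepcopy(node)  # the original tree is never mutated

    def visit(current, prefix):
        for name, value in list(vars(current).items()):
            if isinstance(value, Node):
                visit(value, prefix + name + ".")
            elif isinstance(value, Param):
                setattr(current, name, f(prefix + name, value))

    visit(new, "")
    return new


model = Node(weight=Param(1.0), child=Node(bias=Param(2.0)))
doubled = map_params(model, lambda _, p: Param(p.value * 2))
print(model.child.bias.value, doubled.child.bias.value)  # 2.0 4.0
```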
Parameters:
f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each parameter. The transformation takes two arguments, a name and a tensor:
- The name is the qualified name of the parameter with respect to the module on which map_parameters was called.
- The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name in the module tree.
Returns:
A new module tree of the same type resulting from mapping the transformation over all model parameters.
Return type:
parameters
Iterates over the parameters of the Module and its descendents.
Yields:
(name, tensor) pairs, where name is the qualified path of the parameter with respect to the module.
to()
to(device)
Updates the module’s parameters, transferring them to the specified device.
```python
from max.driver import CPU
from max.nn.module_v3 import Linear

model = Linear(2, 3)
model.to(CPU())
```
Sequential
class max.nn.module_v3.Sequential(*modules)
A Module subclass which holds a sequence of unary Modules.
A unary Module is one whose __call__ method has the signature:
```python
def __call__(self, x: Tensor) -> Tensor: ...
```
Sequential is itself a unary Module. Its __call__ method computes the result of applying each of its child modules in sequence to its input.
The following example will apply a linear transformation up to a dimension of 10, apply a LayerNorm, and then apply a final linear transformation to reduce back to the input dimension of 5.
```python
from max.experimental.tensor import Tensor
from max.nn.module_v3 import LayerNorm, Linear, Sequential

model = Sequential(
    Linear(5, 10),
    LayerNorm(10),
    Linear(10, 5),
)
result = model(Tensor.ones([5]))
assert result.shape == [5]
```
Constructs a sequential from a sequence of modules.
Following PyTorch, Sequential takes its modules as variadic arguments rather than as an iterable. To build a Sequential from an iterable, unpack it with the * operator (Sequential(*seq)).
```python
from max.nn.module_v3 import Linear, Sequential

hidden_dims = [5, 10, 15, 20]
model = Sequential(*(
    Linear(in_dim, out_dim)
    for in_dim, out_dim in zip(hidden_dims, hidden_dims[1:])
))
```
Parameters:
modules (Module) – The sequence of contained Modules in the order of desired application.
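Applying child modules in sequence amounts to a left fold over the list of callables. The sketch below uses plain Python functions in place of unary Modules; the class name SimpleSequential is hypothetical and only illustrates the composition order and the variadic constructor, not the MAX implementation.

```python
from functools import reduce

# Conceptual sketch: plain callables stand in for unary Modules.
# `SimpleSequential` is hypothetical and not part of max.nn.


class SimpleSequential:
    def __init__(self, *modules):
        # Variadic arguments, mirroring Sequential(*modules).
        self.modules = modules

    def __call__(self, x):
        # Feed the output of each module into the next one, left to right.
        return reduce(lambda acc, module: module(acc), self.modules, x)


def add_one(x):
    return x + 1


def double(x):
    return x * 2


model = SimpleSequential(add_one, double)
print(model(3))  # 8, since double(add_one(3)) == 8

# Building from an iterable requires unpacking, as with Sequential(*seq).
steps = [add_one, add_one, double]
print(SimpleSequential(*steps)(0))  # 4
```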
module_dataclass()
max.nn.module_v3.module_dataclass(cls=None, /, *, repr=False, **kwargs)
Decorate a Module subclass as a dataclass.
Classes decorated with module_dataclass are regular Python dataclasses and also Modules. Using the builtin dataclass decorator works fine, but it overrides Module’s __repr__, which may degrade the experience of debugging and printing modules.
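The __repr__ concern comes from the standard library: dataclasses.dataclass generates a __repr__ by default, which shadows one inherited from a base class, while repr=False leaves the inherited method in place. The sketch below shows this with a hypothetical Base class standing in for Module; that module_dataclass relies on exactly this mechanism is an assumption suggested by its repr=False default.

```python
from dataclasses import dataclass

# Standard-library behavior only; `Base` is a hypothetical stand-in for Module.


class Base:
    def __repr__(self):
        return f"<custom repr of {type(self).__name__}>"


@dataclass  # repr=True by default: generates a __repr__ that overrides Base's
class WithDefault(Base):
    weight: float = 0.0


@dataclass(repr=False)  # no generated __repr__, so Base.__repr__ is inherited
class WithoutRepr(Base):
    weight: float = 0.0


print(repr(WithDefault()))  # WithDefault(weight=0.0)
print(repr(WithoutRepr()))  # <custom repr of WithoutRepr>
```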