Python module

module

Base classes and decorators for building neural network modules in MAX.
Module
class max.experimental.nn.module.Module
The core unit of composition for modeling in MAX.
Informally, a Module is a container class. It can contain other Module instances, tensors (the Module's "local parameters"), or other arbitrary Python data.
A Module also has a forward() method, which defines how the Module computes its output. In the simplest case this is a function from one tensor to another tensor. Users call the module via __call__(), which internally invokes forward().

Formally, modules form a tree, and subtrees of modules can be manipulated directly. A Module may also be thought of as a closure, where the parameters form the data of the closure and forward() is the application of the closure.
Users who do not use a Python type checker, or use lax settings for their
type checker, may inherit from Module without parameters. Users who use
a type checker with stricter settings (including MAX internal code) should
specify explicit types for full type checking:
```python
class Linear(Module[[Tensor], Tensor]):
    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias
```

Terminology:
- A "child" of a Module is a sub-Module stored directly on that Module.
- A "descendant" of a Module is one of its children, or one of their descendants.
- A "parameter" is a tensor storing data on the Module or one of its descendants.
- The "qualified path" of a descendant is a period-separated string of the names of the child module attributes which lead to that descendant module, for instance child.sub.last.
- The "qualified path" of a parameter is the qualified path of the descendant directly holding that parameter, followed by a final path component for the attribute name of the tensor. For instance, weight for a local parameter, or child.sub.last.weight for a descendant's parameter.
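To make the terminology concrete, the following sketch builds a two-level module tree and prints each parameter's qualified path (assuming, as in the examples on this page, that Linear stores weight and bias tensors):

```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, Linear, module_dataclass

@module_dataclass
class Block(Module):
    sub: Linear

    def forward(self, x: Tensor) -> Tensor:
        return self.sub(x)

@module_dataclass
class Model(Module):
    child: Block

    def forward(self, x: Tensor) -> Tensor:
        return self.child(x)

model = Model(child=Block(sub=Linear(4, 4)))
for name, _ in model.parameters:
    print(name)  # "child.sub.weight", then "child.sub.bias"
```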
For example:

```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor | int = 0

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

linear = Linear(Tensor.zeros([5, 4]))
print(linear)
print(linear(Tensor([1, 2, 3, 4])))
```

Device placement:
MAX uses a compiled graph model that separates weight storage from computation placement. Understanding this distinction is essential for running models on GPU.
to() is the single pre-compilation entry point for device placement.
It moves all weight tensors to the target device and records it on the
module via the device property. input_types() implementations
should reference self.device when constructing
TensorType objects, so a single to() call drives
both weight placement and computation placement:
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(10, 5)
model.to(Accelerator())  # sets device, moves weights
compiled = model.compile(*model.input_types())  # computation runs on GPU
```

For CPU (the default), calling to() is optional; the device property defaults to CPU:

```python
model = Linear(10, 5)
compiled = model.compile(*model.input_types())  # runs on CPU
```

Because device is tracked per-module instance, sub-modules can be placed on different devices independently:

```python
encoder.to(Accelerator(0))
decoder.to(Accelerator(1))
```

For graph-level tensor routing inside forward() (e.g., pulling an activation back to CPU at the end of the graph), use transfer_to() or to() instead; those insert transfer nodes into the compiled graph and are unrelated to pre-compilation weight placement.
| API | When it runs | What it moves |
|---|---|---|
| Module.to(device) | Python host, before compile() | Stored weight tensors; records module.device |
| ops.transfer_to(x, d) / TensorValue.to(d) | Graph execution time (inside forward()) | Activation tensors within the compiled graph |
| Tensor.to(device) | Eager runtime (outside a graph) | Concrete eager tensors (e.g., staging inputs) |
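As a hedged sketch of the second row, assuming CPU is importable from max.driver and that calling to() on a tensor traced inside forward() lowers to a graph transfer node, per the table above:

```python
from max.driver import CPU
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, Linear, module_dataclass

@module_dataclass
class Head(Module):
    proj: Linear

    def forward(self, x: Tensor) -> Tensor:
        y = self.proj(x)
        # Graph-level routing: inserts a transfer node into the compiled
        # graph; it does not move any stored weights.
        return y.to(CPU())
```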
apply_to_local_parameters()
apply_to_local_parameters(f)
Applies a transformation to each local parameter tensor on the Module.
The transformation is applied in-place, updating the module's values. It is not applied to descendants' parameters.
For example:
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(2, 3)
model.apply_to_local_parameters(lambda _, t: t.to(Accelerator()))
```
Parameters:

f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each local parameter. The transformation takes two arguments, a name and a tensor:

- The name is the attribute name of the parameter on the module.
- The tensor is the current value of that parameter.

The return value of this function is the new value that will replace the value at that name.

Return type:

None
apply_to_parameters()
apply_to_parameters(f)
Applies a transformation to all parameters in the module hierarchy.
This method traverses the module tree and applies the transformation function to each parameter in-place, updating both the current module’s parameters and all nested sub-module parameters. The transformation receives the parameter’s qualified name (dot-separated path) and current tensor value.
Transfer all parameters to accelerator:
```python
from max.driver import Accelerator
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, module_dataclass, Linear

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.fc1(x))

model = MLP(
    fc1=Linear(10, 20),
    fc2=Linear(20, 5),
)
model.apply_to_parameters(lambda name, t: t.to(Accelerator()))
```
Parameters:

f (Callable[[str, Tensor], Tensor]) – Transformation function taking (name, tensor) and returning the transformed tensor:

- name (str): Qualified dot-separated path of the parameter (e.g., "fc1.weight", "encoder.layer2.bias").
- tensor (Tensor): Current value of the parameter.

The return value of this function is the new tensor value that will replace the parameter.

Return type:

None
children
Iterates over the direct child modules of the Module.
Yields:

(name, module) pairs, where name is the attribute name of the child on the module.
compile()
compile(*input_types, weights=None)
Compiles the module to an optimized executable through graph tracing.
This method performs symbolic tracing of the module’s forward method
to construct a MAX Graph, which is then compiled and optimized for
efficient execution on CPU, GPU, or other accelerators.
The compilation process:

1. Creates symbolic Tensor instances based on the provided type specifications.
2. Executes forward with symbolic tensors to record operations.
3. Constructs a Graph representing the computation.
4. Includes all module parameters as weights in the graph.
5. Compiles and optimizes the graph for the target hardware.
6. Returns an executable function with the same signature as forward.
The input type specifications must match the signature of forward.
Use positional arguments for positional parameters.
Device placement: The canonical pattern is to call to()
before compile. to() sets device, moves all weights
to that device, and causes input_types() to return
TensorType objects annotated with that device. This
means a single to() call drives both weight placement and
computation placement:
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(10, 5)
model.to(Accelerator())  # sets device, moves weights to GPU
# input_types() uses self.device, so computation runs on GPU:
compiled = model.compile(*model.input_types())
```

Basic compilation with fixed shapes:
```python
from max.dtype import DType
from max.experimental.tensor import Tensor, TensorType, defaults
from max.experimental.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

linear = Linear(
    weight=Tensor.zeros([10, 5]),
    bias=Tensor.zeros([10]),
)

# Compile with a fixed input shape
_, device = defaults()
input_type = TensorType(DType.float32, [3, 5], device=device)
model = linear.compile(input_type)

# Execute the compiled model
input_data = Tensor.ones([3, 5], dtype=DType.float32)
result = model(input_data)
print(result)
```
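compile() also accepts a weights mapping (described under Parameters below) for supplying parameter values at compile time instead of the module's own parameters. A small sketch continuing the example above, with hypothetical externally loaded arrays:

```python
# Names must match those yielded by linear.parameters ("weight", "bias").
external_weights = {
    "weight": Tensor.zeros([10, 5]),
    "bias": Tensor.zeros([10]),
}
model = linear.compile(input_type, weights=external_weights)
```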
Parameters:

- *input_types (Type[Any]) – Type specifications for each positional argument to forward. Must match the number and order of arguments. Each should be a max.graph.Type (typically TensorType) describing the shape and dtype. The device field on each TensorType determines where activations are computed; use to() to set this consistently across weights and inputs.
- weights (Mapping[str, DLPackArray] | None) – Mapping of parameter names to weight data. Weights should be on CPU and will be transferred to the target device as part of model initialization. If not passed, the model's parameters are used as the weights.

Returns:

Callable[…, Any] – A compiled executable function with the same signature as forward. This function runs the optimized graph and returns results with the same structure as forward (a single Tensor or a tuple of tensors).

Raises:

- TypeError – If input types don't match the forward signature or if operations in forward cannot be traced.
- RuntimeError – If graph construction fails due to incompatible operations or parameter access issues.

Return type:

Callable[…, Any]
descendants
property descendants: Iterable[tuple[str, Module[..., Any]]]
Iterates over the Module’s descendant modules.
Yields:

(name, module) pairs, where name is the qualified path of the descendant with respect to the module.
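A quick sketch contrasting children with descendants, continuing the Model/Block example from the terminology section above (order shown assumes a depth-first traversal):

```python
model = Model(child=Block(sub=Linear(4, 4)))
print([name for name, _ in model.children])     # e.g. ["child"]
print([name for name, _ in model.descendants])  # e.g. ["child", "child.sub"]
```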
device
property device: Device
The canonical device for this module’s weights and computation.
Set by calling to() or by assigning self.device in a
subclass __init__. When neither has been called the property
returns CPU as a safe default so that modules
without an explicit device placement still compile and run on CPU.
input_types() implementations should reference self.device
when constructing TensorType objects so that a
single to() call drives both weight placement and computation
placement.
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(2, 3)
print(model.device)  # CPU(), the default
model.to(Accelerator())
print(model.device)  # Accelerator(id=0)
```

forward()
forward(*args, **kwargs)
Defines the computation performed by the module.
Users must override this method in their subclass to define the module’s computation.
Parameters:

- *args (~_P) – Positional arguments for the computation.
- **kwargs (~_P) – Keyword arguments for the computation.

Returns:

The result of applying the module to the input.

Raises:

NotImplementedError – If the subclass does not override this method.

Return type:

_R
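forward() is not restricted to a single tensor argument; any signature works, and __call__() forwards its arguments to forward() unchanged. A minimal sketch with a hypothetical two-input module:

```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, module_dataclass

@module_dataclass
class Residual(Module):
    def forward(self, x: Tensor, skip: Tensor) -> Tensor:
        return x + skip

res = Residual()
y = res(Tensor.ones([3]), Tensor.zeros([3]))  # __call__ invokes forward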
load_state()
load_state(lookup)
Replaces each parameter in the module and its descendants.
The transformation is applied in-place, updating the module’s values and those of its descendants.
For example, if we have a model with two parameters, weight and
bias, we can load the state of the model from a dictionary with the
following code:
```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Linear

model = Linear(2, 3)
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
model.load_state(lambda name, _: weights[name])
```

The lookup is defined as a function rather than a dictionary, allowing for functional remapping of names during this process to account for differences in common weight naming and storage conventions.
For instance, some formats store weights untransposed; others require weights to be quantized or split out from a shared QKV block, or simply use slightly different names or paths.
This can also be used for instance to provide a default value for initializing LoRA weights.
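For instance, a minimal sketch of name remapping, continuing the example above and assuming a hypothetical checkpoint whose keys carry a "model." prefix:

```python
checkpoint = {
    "model.weight": Tensor.zeros([3, 2]),
    "model.bias": Tensor.zeros([3]),
}
# Bridge the naming difference by remapping names in the lookup:
model.load_state(lambda name, _: checkpoint[f"model.{name}"])
```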
Parameters:

lookup (Callable[[str, Tensor], DLPackArray]) – The lookup function for each parameter:

- The first argument is the qualified name of the parameter with respect to the module on which load_state() was called.
- The second argument is the existing tensor value.
- The return value of this function is the new value that will replace the value at that name in the module tree.
load_state_dict()
load_state_dict(state, strict=True)
Loads parameter values from a dictionary into the module hierarchy.
This method updates all module parameters in-place by loading values from
the provided state dictionary. The dictionary maps qualified parameter names
(dot-separated paths like "fc1.weight") to tensor values.
The strict mode (default) ensures all weights in the dictionary are
actually used, catching errors from mismatched architectures or incorrect
weight names.
For example, the following loads weights from a dictionary into a model:
```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, module_dataclass

@module_dataclass
class Linear(Module):
    weight: Tensor
    bias: Tensor

    def forward(self, x: Tensor) -> Tensor:
        return x @ self.weight.T + self.bias

model = Linear(
    weight=Tensor.zeros([10, 5]),
    bias=Tensor.zeros([10]),
)

# Load weights from a dictionary
weights = {
    "weight": Tensor.zeros([10, 5]),
    "bias": Tensor.zeros([10]),
}
model.load_state_dict(weights)
```
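If a state dictionary carries extra entries that are not module parameters, strict=False skips them instead of raising; a small sketch continuing the example above with a hypothetical extra key:

```python
weights["optimizer.step"] = Tensor.zeros([1])  # extra key, not a parameter
model.load_state_dict(weights, strict=False)   # extra key is ignored
```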
Parameters:

- state (Mapping[str, DLPackArray]) – Dictionary mapping qualified parameter names to tensor values. Keys should match the names from the Module.parameters property. Values should be DLPack-compatible arrays or Tensor objects. Their shapes and dtypes must match the existing parameters with the corresponding names, but they may be on a different device; in that case, the new value is copied to the device of the existing value, and the parameter is set to the new copy.
- strict (bool) – If True (default), verify that all keys in state are used (i.e., match actual parameters). If False, silently ignore extra keys that don't match any parameters.

Raises:

- ValueError – If strict=True and some weights in state don't match any model parameters (indicating an architecture mismatch or incorrect weight names).
- ValueError – If a loaded tensor has a different dtype or shape than the existing parameter.
- KeyError – If a required parameter name in the model is missing from state (regardless of the strict setting).

Return type:

None
local_parameters
Iterates over the local parameters of the Module.
Yields:

(name, tensor) pairs, where name is the attribute name of the tensor on the module.
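To see the difference between local_parameters and parameters, consider a container module whose tensors all live on its children; a minimal sketch, reusing the MLP instance from the apply_to_parameters example above:

```python
# model is the MLP(fc1, fc2) instance from the apply_to_parameters example.
print([name for name, _ in model.local_parameters])
# [] (the MLP stores no tensors directly)
print([name for name, _ in model.parameters])
# ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
```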
map_parameters()
map_parameters(f)
Creates a new Module with its parameters transformed by the function.
The transformation is functional rather than in-place. The module is deep-copied; its descendants are also replaced via the same transform without affecting the original module.
For example:
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(2, 3)
model_on_gpu = model.map_parameters(lambda _, t: t.to(Accelerator()))
```
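Because the transformation is functional, the original module keeps its parameters; a quick check, assuming Tensor exposes a device attribute:

```python
for name, t in model.parameters:
    print(name, t.device)  # original parameters remain on CPU
for name, t in model_on_gpu.parameters:
    print(name, t.device)  # mapped copies live on the accelerator
```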
Parameters:
-
f (Callable[[str, Tensor], Tensor]) –
The transformation to apply to each parameter. The transformation takes two arguments, a name and a tensor:
- The name is the qualified name of the parameter
with respect to the module on which
map_parameters()was called. - The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name in the module tree.
- The name is the qualified name of the parameter
with respect to the module on which
-
Returns:
-
A new module tree of the same type resulting from mapping the transformation over all model parameters.
-
Return type:
parameters
Iterates over all parameters in this module and its sub-modules.
This property performs a depth-first traversal of the module hierarchy,
yielding each parameter tensor with its qualified name. The qualified name
uses dot-notation to represent the module tree structure (e.g.,
"encoder.layer1.weight").
Parameters are yielded in depth-first order: first the current module’s direct parameters, then recursively each sub-module’s parameters.
Counting total parameters:
```python
from max.experimental.tensor import Tensor
from max.experimental.nn import Module, Linear, module_dataclass

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.fc1(x))

model = MLP(
    fc1=Linear(10, 20),
    fc2=Linear(20, 5),
)

# Count parameters
total_params = sum(
    param.num_elements()
    for name, param in model.parameters
)
print(f"Total parameters: {total_params}")
```
Yields:
-
(name, parameter)tuples wherenameis the dot-separated qualified path of the parameter andparameteris theTensor.
to()
to(device)
Sets this module’s device and transfers all weight parameters to it.
This is the single entry point for device placement. After calling
to(device), both weight storage and input_types() reflect the
target device, so compile(*self.input_types()) works correctly
without any additional device configuration:
```python
from max.driver import Accelerator
from max.experimental.nn import Linear

model = Linear(2, 3)
model.to(Accelerator())
# input_types() uses self.device, so computation runs on GPU:
compiled = model.compile(*model.input_types())
```

Unlike PyTorch's eager mode, where weights and computation are inseparable, MAX uses a compiled graph model. to() handles the weight side; input_types() implementations use self.device to handle the computation side. Together they form one coherent mechanism.
For graph-level tensor routing at execution time (inside
forward()), use transfer_to() or
to() instead — those insert transfer ops
into the compiled graph and are unrelated to pre-compilation device
placement.
Parameters:

device (Device) – The device to which all model parameters will be transferred and which input_types() will use as the computation device.

Returns:

A reference to the model. The transfer is applied mutably; the module's device property and all internal parameters are updated in place.

Return type:

Module
module_dataclass()
max.experimental.nn.module.module_dataclass(cls=None, /, *, repr=False, **kwargs)
Converts a class into a MAX module with automatic parameter tracking.
This decorator enables a regular Python class to function as a Module,
providing automatic discovery and registration of parameters (Tensor fields)
and nested modules. The decorated class gains all capabilities of Module,
including parameter iteration, graph compilation via Module.compile(),
and hierarchical module composition.
The decorator applies Python’s @dataclass decorator internally while
preserving Module’s specialized __repr__ method for better
debugging experience when printing module structures.
```python
from max.experimental.nn import Module, Linear, module_dataclass
from max.experimental.tensor import Tensor
from max.experimental import functional as F

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# Create a module with automatic parameter tracking
mlp = MLP(
    fc1=Linear(128, 256),
    fc2=Linear(256, 128),
)

# All parameters are automatically tracked
print(dict(mlp.parameters).keys())
# {'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'}

# Use the module
x = Tensor.randn([4, 128])
output = mlp(x)
print(output.shape)  # (4, 128)
```
Parameters:

- cls (type[Module[..., Any]] | None) – The class to decorate. Must define a forward method. When None, returns a decorator function (supports using @module_dataclass with or without parentheses).
- repr (bool) – If True, use dataclass's default __repr__ instead of Module's rich representation. Defaults to False.
- **kwargs – Additional keyword arguments forwarded to Python's @dataclass decorator (e.g., frozen, eq).

Returns:

The decorated class as a Module subclass with automatic parameter tracking and graph compilation capabilities. When cls is None, returns a decorator function.