
Python function

module_dataclass


max.experimental.nn.module_dataclass(cls=None, /, *, repr=False, **kwargs)


Converts a class into a MAX module with automatic parameter tracking.

This decorator turns a Module subclass into a dataclass and automatically discovers and registers its parameters (Tensor fields) and nested modules (Module fields). The decorated class supports all Module capabilities, including parameter iteration, graph compilation via Module.compile(), and hierarchical module composition.
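The discovery of Tensor fields and nested modules described above can be pictured with a pure-Python sketch that walks an instance's fields and prefixes nested names with a dot. `FakeTensor`, `ToyModule`, `ToyLinear`, and `ToyMLP` are hypothetical stand-ins, not MAX classes, and the real Module may discover parameters differently; the sketch only shows how dotted names such as `fc1.weight` can arise.

```python
import dataclasses
from typing import Iterator


class FakeTensor:
    """Hypothetical stand-in for max.experimental.tensor.Tensor."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape


class ToyModule:
    """Hypothetical stand-in for Module: finds parameters by walking attributes."""

    @property
    def parameters(self) -> Iterator[tuple[str, FakeTensor]]:
        # vars() preserves insertion order, i.e. the dataclass field order.
        for name, value in vars(self).items():
            if isinstance(value, FakeTensor):
                # A tensor field is a parameter of this module.
                yield name, value
            elif isinstance(value, ToyModule):
                # A nested module contributes its parameters under a dotted prefix.
                for sub_name, tensor in value.parameters:
                    yield f"{name}.{sub_name}", tensor


@dataclasses.dataclass
class ToyLinear(ToyModule):
    weight: FakeTensor
    bias: FakeTensor


@dataclasses.dataclass
class ToyMLP(ToyModule):
    fc1: ToyLinear
    fc2: ToyLinear


# Shapes are arbitrary placeholders for the stand-in tensors.
mlp = ToyMLP(
    fc1=ToyLinear(FakeTensor((256, 128)), FakeTensor((256,))),
    fc2=ToyLinear(FakeTensor((128, 256)), FakeTensor((128,))),
)
print([name for name, _ in mlp.parameters])
# ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

Because the walk recurses into any field that is itself a module, arbitrarily deep hierarchies produce fully qualified names without any manual registration.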

The decorator applies Python’s @dataclass decorator internally while preserving Module’s specialized __repr__ method, which gives a more readable view when printing module structures for debugging.
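The __repr__ behavior described above follows from how the standard @dataclass decorator works: with repr=False it does not generate a __repr__, so the one inherited from the base class stays in effect. The sketch below demonstrates this with plain Python; `ReprBase` is a hypothetical stand-in for Module, and its output format is invented for illustration.

```python
import dataclasses


class ReprBase:
    """Hypothetical stand-in for Module, which defines its own __repr__."""

    def __repr__(self) -> str:
        # A custom, structure-oriented representation listing field names.
        fields = ", ".join(f.name for f in dataclasses.fields(self))
        return f"<{type(self).__name__} with fields: {fields}>"


# repr=False: @dataclass adds no __repr__, so ReprBase.__repr__ is inherited.
@dataclasses.dataclass(repr=False)
class Kept(ReprBase):
    hidden: int


# repr=True (the @dataclass default): the generated __repr__ overrides the base one.
@dataclasses.dataclass(repr=True)
class Replaced(ReprBase):
    hidden: int


print(repr(Kept(8)))      # <Kept with fields: hidden>
print(repr(Replaced(8)))  # Replaced(hidden=8)
```

This mirrors the role of the repr parameter: leaving it False keeps the base class representation, while True opts into the field-by-field dataclass format.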

from max.experimental.nn import Module, Linear, module_dataclass
from max.experimental.tensor import Tensor
from max.experimental import functional as F

@module_dataclass
class MLP(Module):
    fc1: Linear
    fc2: Linear

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

mlp = MLP(
    fc1=Linear(128, 256),
    fc2=Linear(256, 128)
)

# All parameters are automatically tracked
print(dict(mlp.parameters).keys())
# dict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

# Use the module
x = Tensor.randn([4, 128])
output = mlp(x)
print(output.shape)  # (4, 128)

Parameters:

  • cls (type[Module[..., Any]] | None) – The class to decorate. Must define a forward method. When None (as when the decorator is written with parentheses, e.g. @module_dataclass(repr=True)), returns a decorator function, so @module_dataclass works both with and without parentheses.
  • repr (bool) – If True, use dataclass’s default __repr__ instead of Module’s rich representation. Defaults to False.
  • **kwargs – Additional keyword arguments forwarded to Python’s @dataclass decorator (e.g., frozen, eq).
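The parameters above follow a common Python pattern for decorators that accept optional arguments. The sketch below shows that pattern with the standard library only; `toy_module_dataclass` is a hypothetical name, and this is an assumed illustration of the calling convention, not MAX's actual implementation (which also integrates with Module).

```python
import dataclasses


def toy_module_dataclass(cls=None, /, *, repr=False, **kwargs):
    """Hypothetical sketch of a decorator usable as @deco or @deco(...)."""

    def wrap(c):
        # Forward repr and any other @dataclass options (frozen, eq, ...) unchanged.
        return dataclasses.dataclass(c, repr=repr, **kwargs)

    if cls is None:
        # Used with parentheses, e.g. @toy_module_dataclass(frozen=True):
        # return the decorator, which receives the class next.
        return wrap
    # Used bare, e.g. @toy_module_dataclass: decorate the class directly.
    return wrap(cls)


@toy_module_dataclass
class Plain:
    width: int


@toy_module_dataclass(frozen=True)
class Frozen:
    width: int


f = Frozen(4)
try:
    f.width = 8  # frozen=True was forwarded to @dataclass, so this raises
except dataclasses.FrozenInstanceError:
    print("Frozen is immutable")
```

Checking `cls is None` is what distinguishes the two call forms: a bare decorator receives the class immediately, while the parenthesized form is first called with only keyword arguments.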

Returns:

The decorated class as a Module subclass with automatic parameter tracking and graph compilation capabilities. When cls is None, returns a decorator function.