
Build a model graph with Module

A Module defines your model's layers, weights, and forward computation. A Graph captures that computation as a static data structure the compiler can optimize. You need both to build a runnable model.

You'll use modules to build a model definition from the ground up, but you'll also use them to extend an existing MAX-supported model with custom layers. This page shows you how to compose built-in modules, write custom modules with explicit weights, load checkpoint data, and construct a graph from the result. By the end, you'll understand how to build a complete model definition and verify its outputs locally.

Key concepts

If you have experience with PyTorch and torch.nn.Module, MAX's Module class will look familiar. In both cases, a module defines your model's layers, weights, and computation. The key difference is that calling a PyTorch module runs computation immediately on real tensors, while calling a MAX module records the computation into a graph that must be compiled before it can execute. As such, you must pass modules to the graph constructor, as you'll see below.

As a result, you define MAX modules using symbolic value placeholders that represent tensors and weights rather than working with real data. This programming pattern relies on the following key APIs:

  • Graph: A data structure that represents your model's computation. It stores what operations to perform on tensors and in what order. You can only run the computation after you compile the graph into an executable model.
  • TensorValue: A symbolic value that stands in for a tensor during graph construction.
  • Weight: A symbolic value that stands in for a learned model parameter during graph construction. Unlike TensorValue, a Weight maps to known data loaded from a checkpoint.
  • ops: Operations that consume and return TensorValue objects.
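
To make this recording model concrete, here's a toy sketch in plain Python. This is an analogy only, not MAX code: the `SymbolicValue` class and its names are hypothetical, but it mimics how a symbolic value like TensorValue records operations into a shared graph structure instead of computing anything:

```python
# Toy illustration: a symbolic value that records ops instead of
# executing them, mimicking how TensorValue behaves inside a Graph
# context. These class and attribute names are hypothetical.

class SymbolicValue:
    def __init__(self, name, graph_ops):
        self.name = name
        self.graph_ops = graph_ops  # shared op list: the "graph"

    def __mul__(self, other):
        # Record a node rather than multiplying real data
        self.graph_ops.append(("mul", self.name, other.name))
        return SymbolicValue(f"mul({self.name}, {other.name})", self.graph_ops)

ops_recorded = []
x = SymbolicValue("x", ops_recorded)
w = SymbolicValue("w", ops_recorded)
y = x * w  # no math happens; a "mul" node is recorded

print(ops_recorded)  # [('mul', 'x', 'w')]
print(y.name)        # mul(x, w)
```

In MAX, the same idea applies at a much larger scale: every op you call on a TensorValue appends a node to the active graph, and nothing executes until the graph is compiled.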

Define and compose modules

A module declares which layers and weights it contains, then defines how data flows through them. In most cases, you can compose a module entirely with prebuilt modules, such as Linear, Embedding, RMSNorm, LayerNorm, and others.

Here's a simple module that composes two built-in Linear modules:

from max import nn
from max.graph import ops, DeviceRef
from max.dtype import DType

class FeedForward(nn.Module):
    """Two linear projections with SiLU activation."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim, DType.float32, DeviceRef.GPU())
        self.fc2 = nn.Linear(hidden_dim, dim, DType.float32, DeviceRef.GPU())

    def __call__(self, x):
        return self.fc2(ops.silu(self.fc1(x)))

Unlike PyTorch's Linear, MAX's built-in Linear module doesn't include a bias by default. If your checkpoint includes bias terms, pass has_bias=True to the constructor.

The __init__() method

The __init__() method declares the module's structure by assigning sub-modules and weights to instance (self) attributes. The __call__() method can then use these attributes.

Each built-in module has its own constructor signature: Linear takes input and output dimensions, Embedding takes vocabulary size and hidden dimension, RMSNorm takes a single normalization dimension, and so on. When you declare a built-in module in the __init__() method, consult the max.nn API reference for the specific parameters each module expects.

The __call__() method

The __call__() method defines the module's computation. Inside this method, you call sub-modules declared in __init__(), apply operations, and return the result. If you're familiar with PyTorch, __call__() serves the same role as forward().

However, calling a MAX module (which invokes __call__()) records operations into a graph rather than executing them on real data. This means you can only call a module from inside a Graph context. We'll explore this more in the section below, Turn a module into a graph.

Chain multiple modules

When you want to stack multiple modules in a straight chain with no intermediate operations, use Sequential. It stores an ordered sequence of modules and chains their __call__() methods automatically. The output of each module feeds directly into the next. For example:

layers = nn.Sequential([
    nn.Linear(5, 10, DType.float32, DeviceRef.GPU()),
    nn.Linear(10, 5, DType.float32, DeviceRef.GPU()),
])

Sequential is especially useful when the number of layers is a runtime parameter. Here's how you use it to compose a block of multiple FeedForward modules:

class FeedForwardBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, n_layers: int) -> None:
        super().__init__()
        self.layers = nn.Sequential([
            FeedForward(dim, hidden_dim) for _ in range(n_layers)
        ])

    def __call__(self, x):
        return self.layers(x)

Calling self.layers(x) passes x through each FeedForward module in order and returns the final output.

Write a module with explicit weights

Built-in modules like Linear create their own Weight objects internally. You can also declare weights yourself, such as when your architecture needs a layer that the built-in modules don't provide or when you need to add a parameter like a scale or gate to a module. This is analogous to creating a torch.nn.Module with torch.nn.Parameter objects in PyTorch.

Here's a module that combines a Linear layer with a learnable per-output scale using an explicit Weight:

from max.graph import Weight

class ScaledLinear(nn.Module):
    """Built-in Linear with a learnable output scale."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, DType.float32, DeviceRef.GPU())
        self.scale = Weight(
            "scale", DType.float32, [out_dim], DeviceRef.GPU()
        )

    def __call__(self, x):
        return self.linear(x) * self.scale

Load weights into the module

Once your module represents all the layers in your model, you can load the pretrained weights using load_state_dict(). This method accepts a dictionary that maps weight names to tensor values.

# Instantiate the module
model = FeedForwardBlock(dim=512, hidden_dim=1024, n_layers=4)

# Load weights into the module
model.load_state_dict(my_state_dict)

The load_state_dict() method does two things:

  • Finalizes weight names. In a module, each Weight object is created with a local name (such as "weight"). The method walks the module tree and updates each Weight with its fully-qualified name (such as "layers.0.fc1.weight"). If you skip this step and pass the module directly to Graph, weights from different modules that share the same local name collide and graph construction raises an error.

  • Stores the dictionary. The method stores the dictionary in the module's internal registry so you can pass it to the compiler later on.
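
A state dict is just a mapping from fully-qualified weight names to tensor data. Here's a minimal sketch of what one might look like for the FeedForwardBlock above, using placeholder NumPy arrays. The [out_dim, in_dim] weight layout is an assumption; verify the shapes against your actual checkpoint:

```python
import numpy as np

dim, hidden_dim, n_layers = 512, 1024, 4

# Hypothetical state dict for FeedForwardBlock(dim=512, hidden_dim=1024,
# n_layers=4). Keys follow the fully-qualified names that Module
# assembles: layers.<i>.fc1.weight, layers.<i>.fc2.weight, and so on.
my_state_dict = {}
for i in range(n_layers):
    # Assumes Linear stores its weight as [out_dim, in_dim]
    my_state_dict[f"layers.{i}.fc1.weight"] = np.zeros((hidden_dim, dim), dtype=np.float32)
    my_state_dict[f"layers.{i}.fc2.weight"] = np.zeros((dim, hidden_dim), dtype=np.float32)

print(sorted(my_state_dict)[:2])
# ['layers.0.fc1.weight', 'layers.0.fc2.weight']
```

In practice you'd populate these entries from a checkpoint file rather than constructing them by hand, but the shape of the mapping is the same.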

To load weights from a checkpoint file in safetensors or GGUF format, you'll need a weight adapter that translates the checkpoint's key names into the names your module expects. Learn more in the Map weight names section of the Model bring-up workflow.

If you want to see your module weights' fully-qualified names to make sure you're writing your weight adapter correctly, run the following after loading weights:

print(model.state_dict().keys())
# dict_keys(["layers.0.fc1.weight", "layers.0.fc2.weight",
#            "layers.1.fc1.weight", ... "layers.3.fc2.weight"])

These hierarchical names exist because of how Module works. When you assign a Module instance as an attribute of another Module, the parent automatically registers it as a sub-module. Weight objects in each sub-module are also registered. This registration is what enables Module to discover every weight in the hierarchy and assemble fully-qualified names, mirroring how PyTorch's named_parameters() builds hierarchical names.
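
The name-assembly idea itself is simple. Here's a toy sketch (plain Python, not the actual Module implementation) of walking a module tree and joining attribute names with dots, the way load_state_dict() finalizes each Weight's fully-qualified name:

```python
# Toy sketch of fully-qualified name assembly. A nested dict stands
# in for the module tree; None leaves stand in for Weight objects.
# This is an illustration, not MAX internals.

def qualify(module_tree, prefix=""):
    """Flatten a nested module tree into dotted weight names."""
    names = []
    for attr, child in module_tree.items():
        path = f"{prefix}.{attr}" if prefix else attr
        if isinstance(child, dict):
            names.extend(qualify(child, path))  # recurse into sub-module
        else:
            names.append(path)  # leaf: a weight
    return names

tree = {"layers": {"0": {"fc1": {"weight": None}, "fc2": {"weight": None}}}}
print(qualify(tree))
# ['layers.0.fc1.weight', 'layers.0.fc2.weight']
```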

Turn a module into a graph

The final step is to construct a graph. When the Graph constructor receives a callable as its forward argument (its second positional parameter), it traces that callable into a static computation graph. Because Module defines __call__(), you can pass a module instance directly to the constructor.

from max.graph import Graph, TensorType

graph = Graph(
    "my_model",
    model,
    input_types=[TensorType(DType.float32, shape=[1, 512], device=DeviceRef.GPU())],
)

Here's what happens when you instantiate a Graph:

  1. The constructor opens a Graph context, or a scope that establishes an active graph.

  2. The constructor creates symbolic TensorValue objects from input_types that define the shape, dtype, and device of the inputs.

  3. The constructor passes the TensorValue objects to the callable object. In this case, the callable is the model module.

  4. Each operation inside __call__() records itself as a node in the active graph. You can only call a module from inside a Graph context, because there is no active graph to record into otherwise.

  5. The constructor records the return value of __call__() as the graph output.

In other words, the __call__() method does not perform a forward pass on tensors; it only records the operations into a graph. To perform a forward pass, you must first load the Graph instance into an inference session, which compiles the graph with hardware-specific optimizations.

Run the model locally

Before you connect your model to the broader MAX serving framework, you may want to inspect the results. Compile and run the model graph locally to check its outputs.

from max.engine import InferenceSession
from max.driver import GPU

session = InferenceSession(devices=[GPU()])
compiled_model = session.load(graph, weights_registry=model.state_dict())

# input_data must match the graph's input type: here, a float32
# tensor of shape [1, 512] on the GPU
result = compiled_model(input_data)

print(result[0].to_numpy())

InferenceSession.load() compiles a graph into an executable Model. It requires a graph and a weights_registry, or a dictionary that maps weight names to tensor data.

You should have already passed this dictionary to load_state_dict(), which stored the dictionary inside the Module instance. Reuse it by passing model.state_dict() as the weights_registry argument.

Calling the compiled Model runs execute() and returns a list of output buffers. Convert each buffer to a NumPy array with to_numpy() to inspect the results. Output post-processing such as decoding happens elsewhere in the inference pipeline.

Verify layer correctness

Some MAX ops use different conventions than other frameworks, such as tensor layout ordering. To verify that individual layers in your MAX implementation produce the same output as your original trained model, call print() on a TensorValue inside your Module. This op records a print node into the graph that fires when your compiled model runs.

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim, DType.float32, DeviceRef.GPU())
        self.fc2 = nn.Linear(hidden_dim, dim, DType.float32, DeviceRef.GPU())

    def __call__(self, x):
        hidden = ops.silu(self.fc1(x))
        hidden.print("after_silu")       # .print() calls ops.print()
        return self.fc2(hidden)

If you instead call Python's print(), you'll see an internal graph representation rather than tensor data.

Next steps

At this point, you understand how to build a small model graph from scratch. To learn how to customize an existing MAX model graph, see the Model bring-up workflow.

Once you have a model graph, you still need to connect it to an inference pipeline that handles tokenization, batching, and request routing. Learn more in Model pipelines.
