
Build a model with MAX Graph API

If you want to take full advantage of MAX Engine for high-performance inference, you can use the MAX Graph API to build your model in Mojo. The Graph API is not a machine learning framework: you can't train models written with it, and it doesn't offer high-level abstractions such as neural network layers. Instead, it lets you build, import, and tweak the data flow graphs that execute in MAX Engine. You must load the model weights yourself.

We built the MAX Graph API because, although MAX Engine can execute models from frameworks such as PyTorch and TensorFlow faster than the default runtimes, sometimes the high-level abstractions leave performance on the table. With the MAX Graph API, we want to give you full control over the graphs that the engine executes.

The MAX Graph API lets you manipulate the low-level graph directly in Mojo, a programming language that offers the low-level systems programming needed for high-performance code together with the usability of Python. Building your graph in Mojo lets you write highly performant graphs with less code, and more readable code, than other performance-oriented graph libraries written in C or C++.

To get started with the Graph API, let’s look at the typical "hello world" model: a feed-forward network that classifies the handwritten digits in the MNIST dataset.


For a more exciting model built with the MAX Graph API, check out our Llama2 Mojo example.

Create a Graph

To get started, create a Graph. A Graph is like a function: it has a name, takes arguments, runs computations and returns some values.

A Graph is created inside a Module, which is conceptually just a package that holds related Graph objects together. A Module is the direct equivalent of MLIR’s ModuleOp, and is similar to the ONNX ModelProto.

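from max import graph

fn main() raises:
    var m = graph.Module()

    var g = m.graph(
        "mnist_helloworld",  # the graph's name: an arbitrary, illustrative string
        in_types=graph.MOTensor(DType.float32, 1, 28, 28, 1),  # one 1x28x28x1 input
        out_types=graph.MOTensor(DType.float32, 1, 10),  # one 1x10 output
    )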


The Module.graph() function has three notable arguments:

  • The graph’s name, which is just an arbitrary string.
  • The input types, which denote the types for the graph input arguments.
  • The output types, which denote the types for the graph return values.

To describe types, the Graph API uses a set of specialized “type” objects, like MOTensor and MOList. These represent exactly what their names suggest, and are built specifically to describe the data types that you can use in a MAX graph.

In the example above, the graph has a single argument of type “tensor”: MOTensor(DType.float32, 1, 28, 28, 1). This is a float32 tensor of shape 1x28x28x1.

Similarly, the graph has a single return value, which is a 1x10 float32 tensor.

To specify multiple arguments or return types, or none at all, pass a max.TypeTuple instead.

Add ops to the Graph

After creating a Graph, you can add ops to it. All Graph operations flow from either graph inputs or constant values, and you can create new values by passing existing values through operation functions.

For example, here’s how to add a constant to the graph with the Graph.constant() helper:

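from tensor import Tensor, TensorShape

# These tensors only fix the weight shapes; their contents are placeholders.
# Real, pre-trained values are loaded separately (see "Use external weights").
var constant_value = Tensor[DType.float32](128, 10)
var cst = g.constant(constant_value)
var cst_0 = g.constant(
    Tensor[DType.float32](TensorShape(1, 10))  # 1x10 output bias (placeholder values)
)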
The cst variable above has type Symbol, and represents the symbolic result value of the g.constant op. Because we are building a graph, computations don’t return actual values; they return symbolic handles. These let you identify the output or input of an op, but they have no concrete value, because values are known only at runtime. Just like a regular value, a Symbol can be passed as input to other ops to build increasingly complex expressions in the Graph.

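Ultimately, there are several ways to create ops, and this example uses all of them:

  • Graph methods, such as Graph.constant().
  • Symbol methods, such as Symbol.reshape().
  • Operators on Symbol values, such as @ (Symbol.__matmul__) and +.
  • Functions in the max.graph.ops module, such as ops.add() and ops.elementwise.relu().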

Add more ops to the Graph

Let’s add a couple more constants to the Graph to complete our MNIST classification model:

var cst_1_data = Tensor[DType.float32](784, 128)
var cst_1 = g.constant(cst_1_data)

var cst_2_data = Tensor[DType.float32](1, 128)
var cst_2 = g.constant(cst_2_data)

Then let’s add a reshape op:

var p1 = g[0].reshape(1, 784)

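Notice that we referenced the graph's input using g[0], which denotes the Graph's first argument. Arguments are also Symbol objects, so you can use them in the same way as the output of a regular op. This particular reshape is created using Symbol.reshape(), and it flattens the 1x28x28x1 input image into a 1x784 row vector (28 × 28 = 784) for the matrix multiplication that follows.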

Add one more op, this time using Symbol.__matmul__:

var p2 = p1 @ cst_1

Add a few more ops to complete our MNIST model, using a mix of operators and op helpers:

from max.graph import ops

var p3 = ops.add(p2, cst_2)
var p4 = ops.elementwise.relu(p3)
var p5 = p4 @ cst
var p6 = p5 + cst_0

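Taken together, the ops above compute a small two-layer feed-forward network. Writing the computation out makes it easy to check that each matmul's inner dimensions agree and that the result has the 1x10 shape declared in out_types. The names W_1, b_1, W_2 and b_2 are labels introduced here for cst_1, cst_2, cst and cst_0; they are not part of the API.

```latex
\begin{aligned}
x' &= \operatorname{reshape}(x) \in \mathbb{R}^{1 \times 784},
  && x \in \mathbb{R}^{1 \times 28 \times 28 \times 1} \\
h &= \operatorname{ReLU}(x' W_1 + b_1) \in \mathbb{R}^{1 \times 128},
  && W_1 \in \mathbb{R}^{784 \times 128},\ b_1 \in \mathbb{R}^{1 \times 128} \\
y &= h W_2 + b_2 \in \mathbb{R}^{1 \times 10},
  && W_2 \in \mathbb{R}^{128 \times 10},\ b_2 \in \mathbb{R}^{1 \times 10}
\end{aligned}
```

Here x' corresponds to p1, h to p4, and y to p6; the 10 entries of y are the model's scores for the digits 0 through 9.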
Return results from the Graph

Every Graph has a special op that denotes the return values. Create it using Graph.output():

g.output(p6)

That’s it!

We now have a complete Graph and can see it with the str() function:

print(str(m))

This prints the MLIR that represents the MAX graph.

Use external weights

Normally, you run the graph with pre-trained weights loaded from a source. The most practical way to do that is to use standard Mojo APIs, like Tensor.load(), or to implement your own data loading code in Mojo and use one of the other Tensor constructors, such as the one that receives a pointer to a data buffer.

For an example of the latter strategy, check out the Llama 2 MAX Graph example.

Execute the Graph

Once the Graph is built, you can execute it using the MAX Engine Mojo API:

from max.engine import InferenceSession, Model, NamedTensor

var session = InferenceSession()
var compiled_model = session.load_model(m)

# Create a placeholder test input image with the graph's 1x28x28x1 input shape
var x = Tensor[DType.float32](1, 28, 28, 1)

# The ^ operator transfers ownership of x into the named input tensor
var result_map = compiled_model.execute(NamedTensor("input0", x^))
var predictions = result_map.get[DType.float32]("output0")

That's all for now!

This has been just a basic introduction to the MAX Graph API, and there is still much more to come! To stay up to date on what's coming, sign up for our newsletter. Also, talk to other MAX developers, ask questions, and share feedback on Discord and GitHub.