MAX Engine is a next-generation compiler and runtime library for running AI inference. With support for PyTorch (TorchScript), ONNX, and native Mojo models, it delivers low-latency, high-throughput inference on a wide range of hardware to accelerate your entire AI workload. As highlighted in the recent MAX version 24.3 release, the MAX platform enables users to fully leverage the capabilities of the MAX Engine by creating bespoke inference models using the MAX Graph APIs. The Graph API offers a low-level programming interface for constructing high-performance symbolic computation graphs in Mojo. This interface provides a uniform representation of symbolic values and a suite of operators that process these symbols to construct the entire graph.

In this tutorial, we guide you step by step through the MAX Graph API. In a nutshell, working with the MAX Graph API involves three main steps:

- Building and verifying the graph.
- Creating an inference session and compiling the graph.
- Executing the graph with input(s) and retrieving the output(s).

We begin by creating two straightforward graphs in Mojo, one for addition and one for matrix multiplication, and show how to compile and execute them. Then we implement a two-layer feedforward neural network with ReLU activation for inference on MNIST data and compare its accuracy to a PyTorch implementation. Finally, we implement ReLU6 as a custom operator, use the MAX Graph Custom Operator API to substitute it for ReLU, and confirm that the accuracy matches the corresponding PyTorch model.

The code for this tutorial is on GitHub. For now, the tutorial targets the MAX **nightly** build (tested version *max 2024.7.1505 (ba28802f)*).

If you experience any issues in this tutorial, please let us know on GitHub.

## Set up MAX

If you don't have MAX yet, follow the MAX install guide.

If you're new to the Mojo language, you can learn the basics in the Introduction to Mojo.

If you have any questions along the way, ask them on our Discord channel or in the GitHub discussions on the Mojo repo and MAX repo. Should you encounter any issues, we recommend checking the roadmap and known issues first.

## Build a "Hello, world!" graph

To begin familiarizing ourselves with the Graph API, we start by constructing a simple addition graph. We will verify and compile this graph, and then proceed to execute it.

Our first graph takes two inputs, *input0* and *input1*, adds them together, and produces *output0* as the output.

### 1. Build the graph

To construct the addition graph, we start by importing the necessary modules. We then instantiate the Graph by specifying two input types, each with a fixed static dimension of *1* (we will later see other supported kinds of dimensions, such as symbolic dimensions). Next, we create a symbolic representation of the addition with the expression *out = graph[0] + graph[1]*. Here *graph[0]* refers to the first input, *input0*, and *graph[1]* to the second input, *input1*. Finally, we designate *out* as the output of the graph by calling *graph.output(out)*.

```mojo
from max.graph import Graph, TensorType, Type

graph = Graph(in_types=List[Type](TensorType(DType.float32, 1), TensorType(DType.float32, 1)))
out = graph[0] + graph[1]
graph.output(out)
print(graph)
```

We can print the graph to visually confirm its structure. The output should show the following representation, where *rmo* and *mo* are **Modular's internal intermediate representations**:

```
%0 = rmo.add(%arg0, %arg1) : !mo.tensor<[1], f32>, !mo.tensor<[1], f32>
```

This line corresponds to the symbolic addition operation *out = graph[0] + graph[1]*.

The subsequent line

```
mo.output %0 : !mo.tensor<[1], f32>
```

indicates that *%0* has been set as the output of the graph, matching the *graph.output(out)* call in our code.

The complete graph representation looks like this:

```
graph: module {
  mo.graph @graph(%arg0: !mo.tensor<[1], f32>, %arg1: !mo.tensor<[1], f32>) -> !mo.tensor<[1], f32> no_inline {
    %0 = rmo.add(%arg0, %arg1) : !mo.tensor<[1], f32>, !mo.tensor<[1], f32>
    mo.output %0 : !mo.tensor<[1], f32>
  }
}
```

To programmatically verify the complete graph construction, we use the *graph.verify()* method. It checks various structural integrity criteria, such as ensuring there are no cycles within the graph (acyclicity). A cycle would indicate recursion or a feedback loop, which cannot be part of a dataflow graph. For more details, check out the official documentation on the verify method.
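To make the acyclicity criterion concrete, here is a small Python sketch of one way to check a dataflow graph for cycles, using Kahn's topological sort. This is a conceptual illustration only and is not how *graph.verify()* is implemented. The `has_cycle` helper and the node names are hypothetical.

```python
from collections import deque


def has_cycle(edges: dict[str, list[str]]) -> bool:
    """Return True if the directed graph {node: [successors]} contains a cycle.

    Kahn's algorithm repeatedly removes nodes with no incoming edges.
    Nodes that can never be removed lie on, or downstream of, a cycle.
    """
    # Collect every node, including ones that only appear as successors.
    nodes = set(edges)
    for succs in edges.values():
        nodes.update(succs)

    # Count incoming edges for each node.
    indegree = {n: 0 for n in nodes}
    for succs in edges.values():
        for s in succs:
            indegree[s] += 1

    # Start from nodes with no dependencies (e.g. graph inputs).
    ready = deque(n for n in nodes if indegree[n] == 0)
    visited = 0
    while ready:
        node = ready.popleft()
        visited += 1
        for s in edges.get(node, []):
            indegree[s] -= 1
            if indegree[s] == 0:
                ready.append(s)

    # All nodes get visited exactly when the graph is acyclic.
    return visited != len(nodes)


# The addition graph: two inputs feed "add", which feeds the output.
add_graph = {"input0": ["add"], "input1": ["add"], "add": ["output0"]}
print(has_cycle(add_graph))  # False

# A feedback loop: "add" consumes its own result through "output0".
loop_graph = {"input0": ["add"], "add": ["output0"], "output0": ["add"]}
print(has_cycle(loop_graph))  # True
```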

### 2. Create inference session, load and compile the graph

With our graph verified and ready, the next step is to create an inference session instance, load the graph into this *session*, and compile it into a model instance. We also print the input names, which we need when executing the model.

```mojo
from max import engine

session = engine.InferenceSession()
model = session.load(graph)

print("input names are:")
for input_name in model.get_model_input_names():
    # Mojo lesson: `[]` dereferences in Mojo as `input_name` is of `Reference` type
    print(input_name[])
```

which outputs

```
input names are:
input0
input1
```

Verifying the input names *input0* and *input1* is crucial, because we pass inputs to the model by name in the next section.

### 3. Execute the graph/model with inputs

To execute the graph, we first create two input tensors in Mojo and pass them, together with their names, to the execute method. The results of the execution are returned as a TensorMap, from which we can retrieve the value of *output0* via the get method as follows:

```mojo
from tensor import Tensor

print("set some input values:")
input0 = Tensor[DType.float32](List[Float32](1.0))
print("input0:", input0)
input1 = Tensor[DType.float32](List[Float32](1.0))
print("input1:", input1)

print("obtain the result using `get`:")
# Mojo lesson: here the `^` in `input0^` passes the ownership and ends the lifetime of `input0`
ret = model.execute("input0", input0^, "input1", input1^)
print("result:", ret.get[DType.float32]("output0"))
```

The outputs are printed as follows:

```
set some input values:
input0: Tensor([[1.0]], dtype=float32, shape=1)
input1: Tensor([[1.0]], dtype=float32, shape=1)
obtain the result using `get`:
result: Tensor([[2.0]], dtype=float32, shape=1)
```

Now, let’s explore our second example.

## Build a matmul graph

In this example, we create a graph that performs matrix multiplication (*matmul*) with a constant symbol. This type of graph is particularly important because it demonstrates how constant symbols, which represent trained and fixed weights in a neural network, can be used. We build on this concept in the next section.

The setup for this *matmul* graph follows the same foundational steps as our first example but includes two critical additions:

- We introduce a symbolic dimension *m*, so the input has shape *m x 2* and its number of rows can vary at runtime.
- We use *graph.constant* to create a constant symbol, which holds static values such as fixed weights.
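Conceptually, a symbolic dimension such as *m* is a named placeholder that gets bound to a concrete size when a tensor arrives, while static dimensions must match exactly. The Python sketch below illustrates that idea. The `bind_shape` helper is hypothetical and does not reflect how MAX checks shapes internally.

```python
def bind_shape(declared, actual, bindings=None):
    """Match a concrete shape against a declared shape.

    `declared` mixes ints (static dims) and strings (symbolic dims).
    Returns a dict mapping each symbolic name to its concrete size,
    or raises ValueError on a mismatch.
    """
    bindings = dict(bindings or {})
    if len(declared) != len(actual):
        raise ValueError(f"rank mismatch: expected {len(declared)}, got {len(actual)}")
    for dim, size in zip(declared, actual):
        if isinstance(dim, str):
            # Symbolic dim: bind on first use, then require consistency.
            if bindings.setdefault(dim, size) != size:
                raise ValueError(f"{dim} already bound to {bindings[dim]}, got {size}")
        elif dim != size:
            # Static dim: must match exactly.
            raise ValueError(f"static dim {dim} != {size}")
    return bindings


print(bind_shape(("m", 2), (2, 2)))  # {'m': 2}
print(bind_shape(("m", 2), (3, 2)))  # {'m': 3}
```

The same declared shape accepts both a *2 x 2* and a *3 x 2* input. A shape such as *(3, 3)* is rejected because the static dimension *2* does not match.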

Here's how we compile and execute the graph to accommodate varying input tensor sizes at runtime:

```mojo
from max.graph import Graph, TensorType
from max.tensor import Tensor, TensorShape
from random import seed
from max.engine import InferenceSession

graph = Graph(TensorType(DType.float32, "m", 2))
# create a constant tensor value to later create a graph constant symbol
constant_value = Tensor[DType.float32](TensorShape(2, 2), 42.0)
print("constant value:", constant_value)
# create a constant symbol
constant_symbol = graph.constant(constant_value)
# create a matmul node
mm = graph[0] @ constant_symbol
graph.output(mm)
# verify
graph.verify()

# create session, load and compile the graph
session = InferenceSession()
model = session.load(graph)

# generate random input
seed(42)
input0 = Tensor[DType.float32].randn((2, 2))
print("random 2x2 input0:", input0)
ret = model.execute("input0", input0^)
print("matmul 2x2 result:", ret.get[DType.float32]("output0"))

# with 3 x 2 matrix input
input0 = Tensor[DType.float32].randn((3, 2))
print("random 3x2 input0:", input0)
ret = model.execute("input0", input0^)
print("matmul 3x2 result:", ret.get[DType.float32]("output0"))
```

Here are the results of the *matmul* graph using a *2 x 2* constant symbol and random input tensors of shapes *2 x 2* and *3 x 2*:

```
constant value: Tensor([[42.0, 42.0],
[42.0, 42.0]], dtype=float32, shape=2x2)
random 2x2 input0: Tensor([[-1.7141127586364746, 0.057178866118192673],
[0.75628399848937988, -1.6024507284164429]], dtype=float32, shape=2x2)
matmul 2x2 result: Tensor([[-69.591224670410156, -69.591224670410156],
[-35.53900146484375, -35.53900146484375]], dtype=float32, shape=2x2)
random 3x2 input0: Tensor([[1.0167152881622314, -0.10449378937482834],
[-0.27936717867851257, -0.69003057479858398],
[0.80745488405227661, -0.48231619596481323]], dtype=float32, shape=3x2)
matmul 3x2 result: Tensor([[38.313301086425781, 38.313301086425781],
[-40.714706420898438, -40.714706420898438],
[13.655824661254883, 13.655824661254883]], dtype=float32, shape=3x2)
```
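Because every entry of the constant is *42*, each output element is simply *42* times the sum of the corresponding input row. That is why both columns of each result row are identical. A quick NumPy check, which copies the printed *2 x 2* input values by hand rather than re-running Mojo's random generator, reproduces the MAX result up to float32 rounding:

```python
import numpy as np

# Constant symbol from the example: a 2x2 tensor filled with 42.0.
constant = np.full((2, 2), 42.0, dtype=np.float32)

# The random 2x2 input printed by the Mojo example, copied by hand.
input0 = np.array(
    [[-1.7141127586364746, 0.057178866118192673],
     [0.75628399848937988, -1.6024507284164429]],
    dtype=np.float32,
)

result = input0 @ constant
print(result)

# Each row of the result is 42 * (row sum), repeated in both columns.
expected = np.repeat(42.0 * input0.sum(axis=1, keepdims=True), 2, axis=1)
assert np.allclose(result, expected)
```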

With this foundation, we are ready to explore more advanced applications in the next section of the tutorial.

## Build an MNIST classifier graph

In this section, we demonstrate how to build a two-layer neural network with ReLU activation in PyTorch and train it on the well-known MNIST dataset. MNIST consists of *28 x 28* pixel grayscale images of handwritten digits (*0* to *9*, i.e. *10* classes in total). We then test the network's accuracy.

Subsequently, we implement the same model for inference using the MAX Graph API and confirm that the accuracy stays the same.

### 1. Build and train the model in PyTorch

First, to set up, let’s define our neural network in PyTorch:

```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```

We can train and test the network as follows (run via *python mnist.py*):

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# `model`, `learning_rate`, `num_epochs`, `train_loader` and `test_loader`
# are defined elsewhere in the tutorial's mnist.py (not shown here).
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
total_steps = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # flatten each 28x28 image into a 784-element vector
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}")

# test
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)
        predicted = torch.argmax(probs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy of the network on the 10000 test images: {100 * correct / total} %")

# save weights in numpy binary format
weights = {}
for name, param in model.named_parameters():
    weights[name] = param.detach().cpu().numpy()

np.save("model_weights.npy", weights)
```

After training and testing the network, we found the model achieves an accuracy
of *97.31%* on the test dataset.

```
Epoch [1/5], Step [100/469], Loss: 0.5384
Epoch [1/5], Step [200/469], Loss: 0.2288
Epoch [1/5], Step [300/469], Loss: 0.3225
Epoch [1/5], Step [400/469], Loss: 0.2614
Epoch [2/5], Step [100/469], Loss: 0.1049
Epoch [2/5], Step [200/469], Loss: 0.2166
Epoch [2/5], Step [300/469], Loss: 0.2362
Epoch [2/5], Step [400/469], Loss: 0.1472
Epoch [3/5], Step [100/469], Loss: 0.1200
Epoch [3/5], Step [200/469], Loss: 0.1284
Epoch [3/5], Step [300/469], Loss: 0.0726
Epoch [3/5], Step [400/469], Loss: 0.1111
Epoch [4/5], Step [100/469], Loss: 0.0702
Epoch [4/5], Step [200/469], Loss: 0.0650
Epoch [4/5], Step [300/469], Loss: 0.1297
Epoch [4/5], Step [400/469], Loss: 0.1334
Epoch [5/5], Step [100/469], Loss: 0.0265
Epoch [5/5], Step [200/469], Loss: 0.0301
Epoch [5/5], Step [300/469], Loss: 0.1179
Epoch [5/5], Step [400/469], Loss: 0.0424
Accuracy of the network on the 10000 test images: 97.31 %
```
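Calling `np.save` on a Python dict stores it as a pickled 0-d object array. Reading the file back in Python therefore requires `allow_pickle=True` plus `.item()` to recover the dict. The tutorial's Mojo helper `load_model_weights()` is not shown here. The round trip below is only a Python sketch that uses a tiny stand-in dict with the same keys instead of the real trained weights.

```python
import os
import tempfile

import numpy as np

# Stand-in weights with the same keys and (out, in) layout as nn.Linear,
# but made-up tiny shapes instead of the trained MNIST values.
weights = {
    "fc1.weight": np.zeros((4, 28 * 28), dtype=np.float32),
    "fc1.bias": np.zeros(4, dtype=np.float32),
    "fc2.weight": np.zeros((10, 4), dtype=np.float32),
    "fc2.bias": np.zeros(10, dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "model_weights.npy")
np.save(path, weights)  # the dict is pickled into a 0-d object array

# allow_pickle=True is required for object arrays; .item() unwraps the dict.
loaded = np.load(path, allow_pickle=True).item()
for name, array in loaded.items():
    print(name, array.shape, array.dtype)
```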

Next, we implement the PyTorch model in MAX Graph API for inference.

### 2. Build the inference graph with MAX Graph

After training our model and saving its weights, we construct an inference graph and load the weights as constant symbols. Our graph takes inputs with a symbolic *"batch"* dimension and a static *28 * 28 = 784* feature dimension, representing flattened and preprocessed images. We also include a softmax operation via ops.softmax to compute probabilities directly within the inference graph.

```mojo
from max.graph import Graph, TensorType, ops
from max.tensor import Tensor
from max import engine

def build_mnist_graph(
    fc1w: Tensor[DType.float32],
    fc1b: Tensor[DType.float32],
    fc2w: Tensor[DType.float32],
    fc2b: Tensor[DType.float32],
) -> Graph:
    # Note: "batch" is a symbolic dim which is known ahead of time vs dynamic dim
    graph = Graph(TensorType(DType.float32, "batch", 28 * 28))
    # PyTorch linear is defined as: x W^T + b so we need to transpose the weights
    fc1 = (graph[0] @ ops.transpose(graph.constant(fc1w), 1, 0)) + graph.constant(fc1b)
    relu = ops.relu(fc1)
    fc2 = (relu @ ops.transpose(graph.constant(fc2w), 1, 0)) + graph.constant(fc2b)
    out = ops.softmax(fc2)  # adding explicit softmax for inference prob
    graph.output(out)
    graph.verify()
    return graph
```
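PyTorch stores each `nn.Linear` weight with shape *(out_features, in_features)*, which is why the graph transposes `fc1w` and `fc2w` before the matmul. The NumPy sketch below mirrors the same computation, *softmax(relu(x W1^T + b1) W2^T + b2)*, followed by the `argmax` used for prediction. The `mnist_forward` name, the hidden size, and the random weights are placeholders, not the trained model.

```python
import numpy as np


def mnist_forward(x, fc1w, fc1b, fc2w, fc2b):
    """NumPy mirror of build_mnist_graph: returns class probabilities."""
    # nn.Linear computes x @ W.T + b, with W of shape (out, in).
    fc1 = x @ fc1w.T + fc1b
    relu = np.maximum(fc1, 0.0)
    fc2 = relu @ fc2w.T + fc2b
    # Numerically stable softmax over the class axis.
    exp = np.exp(fc2 - fc2.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
hidden = 32  # placeholder hidden size
fc1w = rng.standard_normal((hidden, 28 * 28)).astype(np.float32)
fc1b = np.zeros(hidden, dtype=np.float32)
fc2w = rng.standard_normal((10, hidden)).astype(np.float32)
fc2b = np.zeros(10, dtype=np.float32)

# A batch of 3 flattened images; any batch size works, like the symbolic dim.
x = rng.random((3, 28 * 28), dtype=np.float32)
probs = mnist_forward(x, fc1w, fc1b, fc2w, fc2b)
predicted = probs.argmax(axis=1)
print(probs.shape, predicted)
```

Since softmax is monotonic, `argmax` over the probabilities picks the same class as `argmax` over the raw logits. The explicit softmax only matters when you want the probabilities themselves.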

With the inference graph defined, we can now execute it with test images.

### 3. Run inference and check accuracy

To execute the graph, we first convert the model weights from NumPy format to Mojo tensor format, then build the graph, compile it, and run inference. Finally, to check the accuracy, we iterate over the test images, preprocess them, and obtain the result. We then call argmax to find the predicted class among the *10* classes and count how many predictions match the ground-truth label.

```mojo
from max.tensor import Tensor, TensorShape

# `load_model_weights`, `numpy_to_tensor`, `load_mnist_test_data` and `preprocess`
# are helper functions from the tutorial's mnist.mojo (not shown here).
weights_dict = load_model_weights()
fc1w = numpy_to_tensor[DType.float32](weights_dict["fc1.weight"])
fc1b = numpy_to_tensor[DType.float32](weights_dict["fc1.bias"])
fc2w = numpy_to_tensor[DType.float32](weights_dict["fc2.weight"])
fc2b = numpy_to_tensor[DType.float32](weights_dict["fc2.bias"])

mnist_graph = build_mnist_graph(fc1w^, fc1b^, fc2w^, fc2b^)
session = engine.InferenceSession()
model = session.load(mnist_graph)

correct = 0
total = 0
# use batch size of 1 in this example
test_dataset = load_mnist_test_data()
for i in range(len(test_dataset)):
    item = test_dataset[i]
    image = item[0]
    label = item[1]
    preprocessed_image = preprocess(image)

    output = model.execute("input0", preprocessed_image)
    probs = output.get[DType.float32]("output0")
    predicted = probs.argmax(axis=1)

    label_ = Tensor[DType.index](TensorShape(1), int(label))
    correct += int(predicted == label_)
    total += 1

print("Accuracy of the network on the 10000 test images:", 100 * correct / total, "%")
```

The output of *mojo mnist.mojo* is

```
Accuracy of the network on the 10000 test images: 97.310000000000002 %
```

This matches the accuracy we observed from the PyTorch test, confirming that our MAX Graph API implementation performs equivalently.
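The extra digits in the Mojo output reflect formatting, not a numerical difference. Assume *9731* of the *10000* test images are classified correctly, which is consistent with the *97.31%* above. Then `100 * correct / total` is the float64 closest to *97.31*. Printing it with 17 significant digits, which the Mojo output appears to do, yields exactly the string shown. Python's default `repr` instead prints the shortest form that round-trips:

```python
correct, total = 9731, 10000
accuracy = 100 * correct / total

print(accuracy)                  # 97.31 (shortest round-tripping repr)
print(format(accuracy, ".17g"))  # 97.310000000000002 (17 significant digits)
```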

## Create a custom operator with MAX Graph

In this final section of our tutorial, we demonstrate how to create and register a custom operator to use inside a MAX graph. Starting from our previous two-layer neural network, we first train the model with *ReLU6* activation via *python mnist.py --use-relu6*, which replaces *ReLU* with *ReLU6*. As before, the script checks the test accuracy and saves the model weights.

### 1. Implement the custom op

To create a custom operator in Mojo, we follow these steps:

- Create a dedicated subdirectory named *custom_ops*.
- Inside it, create an *__init__.mojo* file containing the import *from .relu6 import relu6*.

Next, create the custom op Mojo file, *relu6.mojo*, in the same directory with the following code:

```mojo
from max.extensibility import Tensor, empty_tensor
from max import register

@register.op("relu6")
fn relu6[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank]:
    var output = empty_tensor[type](x.shape)

    @always_inline
    @parameter
    fn _relu6[width: Int](i: StaticIntTuple[rank]) -> SIMD[type, width]:
        var val = x.simd_load[width](i)
        return val.max(0).min(6)

    output.for_each[_relu6]()
    return output^
```

The code above uses the *@register.op("relu6")* decorator to register the wrapped *relu6* function as a custom operator named *"relu6"*. The wrapped function can only take max.extensibility tensors as inputs. It must have exactly one output of the same type and **cannot** *raise* an *Error*. We create an empty_tensor to store the output.

To compute the output, we define a function decorated with *@parameter* and apply it across the input tensor via for_each. This function (*_relu6*) loads a SIMD vector of *width* elements starting at index *i*. It then applies the *ReLU6* formula *val.max(0).min(6)*, which clamps each value to the range [0, 6]. Finally, we return the output via *output^* to transfer ownership of the result tensor.
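For reference, the same clamp to [0, 6] can be written in NumPy and checked against the ReLU6 definition *min(max(x, 0), 6)*. The chunked loop is only a loose, hypothetical analogue of `for_each` handing the elementwise function `width` elements at a time. The `relu6_chunked` helper and its `width=4` default are illustrative, and the actual scheduling inside MAX is not modeled.

```python
import numpy as np


def relu6_reference(x):
    """ReLU6 as min(max(x, 0), 6), mirroring val.max(0).min(6)."""
    return np.minimum(np.maximum(x, 0), 6)


def relu6_chunked(x, width=4):
    """Apply ReLU6 to a flat copy of x, `width` elements per step."""
    flat = np.ravel(x).copy()
    for start in range(0, flat.size, width):
        # The final chunk may be shorter than `width`.
        flat[start:start + width] = relu6_reference(flat[start:start + width])
    return flat.reshape(np.shape(x))


x = np.array([[-3.0, -0.5, 0.0, 2.5, 6.0], [6.5, 10.0, -1.0, 5.9, 3.0]], dtype=np.float32)
print(relu6_chunked(x))
```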

### 2. Add the custom op to the graph

Once we have the custom operator defined, we package it as a *.mojopkg* file via *mojo package custom_ops*.

In our graph definition, we are now ready to replace *ops.relu* with our custom operator, changing

```mojo
relu = ops.relu(fc1)
```

to

```mojo
relu = ops.custom["relu6"](fc1, fc1.type())
```

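Here we use ops.custom, which takes the custom operator name *"relu6"* as a parameter, *fc1* as the input, and *fc1.type()* as the output type. The output type matches the input because ReLU6 is elementwise and preserves the input's shape and dtype. The updated function also takes a *use_relu6* flag; the rest of the code stays the same.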

```mojo
def build_mnist_graph(
    fc1w: Tensor[DType.float32],
    fc1b: Tensor[DType.float32],
    fc2w: Tensor[DType.float32],
    fc2b: Tensor[DType.float32],
    use_relu6: Bool,
) -> Graph:
    # Note: "batch" is a symbolic dim which is known ahead of time vs dynamic dim
    graph = Graph(TensorType(DType.float32, "batch", 28 * 28))
    # PyTorch linear is defined as: x W^T + b so we need to transpose the weights
    fc1 = (graph[0] @ ops.transpose(graph.constant(fc1w), 1, 0)) + graph.constant(fc1b)
    if use_relu6:
        # custom op registered as "relu6"
        relu = ops.custom["relu6"](fc1, fc1.type())
    else:
        relu = ops.relu(fc1)
    fc2 = (relu @ ops.transpose(graph.constant(fc2w), 1, 0)) + graph.constant(fc2b)
    out = ops.softmax(fc2)  # adding explicit softmax for inference prob
    graph.output(out)
    graph.verify()
    return graph
```

The last change is to let the inference *session* know about the custom operator at runtime:

```mojo
from pathlib import Path

model = session.load(mnist_graph, custom_ops_paths=Path("custom_ops.mojopkg"))
```

### 3. Verify the results

As the final check, we train and test the model that uses *ReLU6* via *python mnist.py --use-relu6*, which outputs

```
Accuracy of the network on the 10000 test images: 95.91 %
```

Then we run the inference code via *mojo mnist.mojo --use-relu6*, which shows

```
Accuracy of the network on the 10000 test images: 95.909999999999997 %
```

The matching accuracy between the PyTorch version and the Mojo implementation confirms the effective integration of the custom operator.

### 4. Deploy as a binary

For deployment, we can build the *mnist* binary via *mojo build mnist.mojo* and execute it as follows:

```shell
./mnist
# or
./mnist --use-relu6
```

## Next steps

In this tutorial, we demonstrated step by step how to use the MAX Graph API to create symbolic graphs and to compile and execute them. We also showed how to replicate a two-layer neural network trained in PyTorch using the MAX Graph API, and saw that the test accuracy remained intact. We concluded by showing how to create and register a custom operator for inference, and verified that the test accuracy also remained intact when using it. We hope that by the end of this tutorial, you have gained a better understanding of the inner workings of the MAX Graph API.

Here are a few potential next steps for you:

- Explore other neural network architectures beyond a simple two-layer feedforward network and implement them using the MAX Graph API
- Experiment with other custom operators
- Test and assess correctness, and contribute to the community 🚀

Report feedback, including issues, on our Mojo and MAX GitHub trackers.