# Build a model with MAX Graph API

If you want to take full advantage of MAX Engine for high-performance inference, you can use the MAX Graph APIs to build your model in Mojo. It’s not a machine learning framework (you can’t train models written with the Graph APIs, and they don’t offer high-level abstractions like neural network layers)—the Graph APIs allow you to build, import, and tweak the data flow graphs that execute in MAX Engine (you must load the model weights yourself).

We built the MAX Graph API because, although MAX Engine can execute models from frameworks such as PyTorch and TensorFlow faster than the default runtimes, sometimes the high-level abstractions leave performance on the table. With the MAX Graph API, we want to give you full control over the graphs that the engine executes.

The MAX Graph API allows you to directly manipulate the low-level graph in Mojo, a language that offers the performance of low-level systems programming with the usability of Python. By building your graph in Mojo, you can create highly performant graphs with far less, and more readable, code than other performance-oriented graph libraries written in C or C++.

To get started with the Graph API, let’s look at the typical "hello world" model: a feed-forward network to classify numbers in the MNIST dataset.

For a more exciting model built with the MAX Graph API, check out our Llama2 Mojo example.

## Create a Graph

To get started, create a `Graph`. A `Graph` is like a function: it has a name, takes arguments, runs computations, and returns some values.

A `Graph` is created inside a `Module`, which is conceptually just a package that holds together related `Graph` objects. The `Module` is a direct equivalent of MLIR’s `ModuleOp`, similar to the ONNX `ModelProto`.

```mojo
from max import graph

fn main() raises:
    var m = graph.Module()
    var g = m.graph(
        "my_model",
        in_types=graph.MOTensor(DType.float32, 1, 28, 28, 1),
        out_types=graph.MOTensor(DType.float32, 1, 10),
    )
    print(str(m))
```

The `Module.graph()` function has three notable arguments:

- The graph’s name, which is just an arbitrary string.
- The input types, which denote the types for the graph input arguments.
- The output types, which denote the types for the graph return values.

To describe types, the Graph API uses a set of specialized “type” objects, like `MOTensor` and `MOList`. These represent exactly what their names suggest, and are built specifically to describe the data types that you can use in a MAX graph.

In the example above, the graph has a single argument of type “tensor”: `MOTensor(DType.float32, 1, 28, 28, 1)`. The tensor has shape `1x28x28x1` and `float32` data type.

Similarly, the graph has a single return value, which is a `1x10`, `float32` tensor.

To specify multiple arguments or return types, or none at all, pass a `max.TypeTuple` instead.
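For instance, a graph that takes two tensor arguments might be declared along these lines. This is a sketch only: the `TypeTuple` constructor call shown here is an assumption about the API, so check the API reference for the exact signature.

```mojo
# Sketch: grouping two input types in a TypeTuple.
# Assumption: TypeTuple can be built directly from MOTensor values.
var g2 = m.graph(
    "two_input_model",
    in_types=graph.TypeTuple(
        graph.MOTensor(DType.float32, 1, 28, 28, 1),
        graph.MOTensor(DType.float32, 1, 10),
    ),
    out_types=graph.MOTensor(DType.float32, 1, 10),
)
```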

## Add ops to the Graph

After creating a `Graph`, you can add ops to it. All `Graph` operations flow from either graph inputs or constant values, and you can create new values by passing existing values through operation functions.

For example, here’s how to add a constant to the graph with the `Graph.constant()` helper:

```mojo
from tensor import Tensor, TensorShape

var constant_value = Tensor[DType.float32](128, 10)
constant_value._to_buffer().fill(0.5)
var cst = g.constant(constant_value)

var cst_0 = g.constant(
    Tensor[DType.float32](
        TensorShape(1, 10),
        -0.0675942451,
        0.0063267909,
        7.43086217e-4,
        -0.0126994187,
        0.0148473661,
        0.108896509,
        -0.0398316309,
        0.0461452715,
        -0.0281771384,
        -0.0431172103,
    )
)
```

The `cst` variable above has type `Symbol`, and represents the symbolic result value of the `g.constant` op. Because we are building a graph, computations don’t return actual values, but rather symbolic handles. These let you identify the output or input of an op, but they have no concrete value, because values are known only at runtime. Much like a value, a `Symbol` is passed as input to another op to form increasingly complicated expressions in the `Graph`.
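To see how symbols compose, here is a small sketch that chains two ops on the `cst` constant created above, using only the `Symbol` operations shown in this guide:

```mojo
# Each expression builds a new op in the graph and returns a new Symbol;
# no numbers are computed until the graph is executed.
var doubled = cst + cst              # Symbol.__add__() adds an elementwise-add op
var flat = doubled.reshape(1, 1280)  # still symbolic: a 128x10 tensor reshaped to 1x1280
```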

Ultimately, there are several ways to create ops:

- Using helper methods of the `Graph` struct, like `Graph.constant()`, `Graph.range()`, and `Graph.output()`, as shown above. These are typically specialized ops, or ops that only have constant inputs.
- Using various helper functions from the `max.graph.ops` module, like `ops.squeeze()`, `ops.list()`, `ops.concat()`, and so on. (See an example below.)
- Using operators and methods of the `Symbol` struct, like `Symbol.reshape()`, `Symbol.__add__()`, and others. These are shorthand helpers that allow more compact expressions; most have `max.graph.ops` counterparts.
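To make the last two styles concrete, here is the same addition expressed both ways, as a sketch using the constants created above:

```mojo
from max.graph import ops

# The ops helper and the Symbol operator add the same op to the graph:
var sum_a = ops.add(cst, cst_0)  # explicit helper from max.graph.ops
var sum_b = cst + cst_0          # Symbol.__add__() shorthand
```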

## Add more ops to the Graph

Let’s add a couple more constants to the `Graph` to complete our MNIST classification model:

```mojo
var cst_1_data = Tensor[DType.float32](784, 128)
cst_1_data._to_buffer().fill(0.5)
var cst_1 = g.constant(cst_1_data)

var cst_2_data = Tensor[DType.float32](1, 128)
cst_2_data._to_buffer().fill(0.5)
var cst_2 = g.constant(cst_2_data)
```

Then let’s add a reshape op:

```mojo
var p1 = g[0].reshape(1, 784)
```

Notice that we referenced the op using `g[0]`, which denotes the `Graph`'s first argument. Arguments are also `Symbol` objects, so you can use them in the same way as the output of a regular op. This particular reshape is created using `Symbol.reshape()`.

Add one more op, this time using `Symbol.__matmul__()`:

```mojo
var p2 = p1 @ cst_1
```

Add a few more ops to complete our MNIST model, using a mix of operators and op helpers:

```mojo
from max.graph import ops

var p3 = ops.add(p2, cst_2)
var p4 = ops.elementwise.relu(p3)
var p5 = p4 @ cst
var p6 = p5 + cst_0
```

## Return results from the Graph

Every `Graph` has a special op that denotes the return values. Create it using `Graph.output()`:

```mojo
g.output(p6)
```

That’s it!

We now have a complete `Graph` and can see it with the `str()` function:

```mojo
print(str(m))
```

This prints the MLIR that represents the MAX graph.

## Use external weights

Normally, you run the graph with pre-trained weights loaded from a source. The most practical way to do that is to use standard Mojo APIs, like `Tensor.load()`, or implement your own data loading code in Mojo and use one of the other `Tensor` constructors, such as the one that receives a pointer to a data buffer.
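As a minimal sketch of the first approach, assuming a weights file previously written with `Tensor.save()` (the file path here is hypothetical, and the exact `Tensor.load()` signature should be checked against the API reference):

```mojo
from tensor import Tensor
from pathlib import Path

# Hypothetical file assumed to hold a 784x128 float32 tensor
# saved earlier with Tensor.save().
var weights = Tensor[DType.float32].load(Path("weights/fc1.maxtensor"))
var fc1 = g.constant(weights)
```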

For an example of the latter strategy, check out the Llama 2 MAX Graph example.

## Execute the Graph

Once the `Graph` is built, you can execute it using the MAX Engine Mojo API:

```mojo
from max.engine import InferenceSession, Model, NamedTensor

var session = InferenceSession()
var compiled_model = session.load_model(m)

# Create a test input image, filled with flat values
var x = Tensor[DType.float32](1, 28, 28, 1)
x._to_buffer().fill(0.5)

var result_map = compiled_model.execute(NamedTensor("input0", x^))
var predictions = result_map.get[DType.float32]("output0")
print(predictions)
```

That's all for now!

This has been just a basic introduction to the MAX Graph API, and there is still much more to come! To stay up to date on what's coming, sign up for our newsletter. Also, talk to other MAX developers, ask questions, and share feedback on Discord and GitHub.