Build a model with MAX Graph API
If you want to take full advantage of MAX Engine for high-performance inference, you can use the MAX Graph APIs to build your model in Mojo. It’s not a machine learning framework (you can’t train models written with the Graph APIs, and they don’t offer high-level abstractions like neural network layers)—the Graph APIs allow you to build, import, and tweak the data flow graphs that execute in MAX Engine (you must load the model weights yourself).
We built the MAX Graph API because, although MAX Engine can execute models from frameworks such as PyTorch and TensorFlow faster than the default runtimes, sometimes the high-level abstractions leave performance on the table. With the MAX Graph API, we want to give you full control over the graphs that the engine executes.
The MAX Graph API allows you to directly manipulate the low-level graph in Mojo, a programming language that provides low-level systems programming for high-performance code with the usability of Python. By building your graph in Mojo, you can create highly performant graphs with far less, and more readable, code than performance-oriented graph libraries written in C or C++.
To get started with the Graph API, let’s look at the typical "hello world" model: a feed-forward network to classify numbers in the MNIST dataset.
For a more exciting model built with the MAX Graph API, check out our Llama2 Mojo example.
Create a Graph
To get started, create a `Graph`. A `Graph` is like a function: it has a name, takes arguments, runs computations, and returns values. A `Graph` is created inside a `Module`, which is conceptually just a package that holds together related `Graph` objects. The `Module` is the direct equivalent of MLIR's `ModuleOp`, similar to the ONNX `ModelProto`.
```mojo
from max import graph

fn main() raises:
    var m = graph.Module()
    var g = m.graph(
        "my_model",
        in_types=graph.MOTensor(DType.float32, 1, 28, 28, 1),
        out_types=graph.MOTensor(DType.float32, 1, 10),
    )
    print(str(m))
```
The `Module.graph()` function has three notable arguments:

- The graph's name, which is just an arbitrary string.
- The input types, which denote the types of the graph's input arguments.
- The output types, which denote the types of the graph's return values.
To describe types, the Graph API uses a set of specialized "type" objects, like `MOTensor` and `MOList`. These represent exactly what their names suggest, and are built specifically to describe the data types that you can use in a MAX graph.

In the example above, the graph has a single argument of type tensor: `MOTensor(DType.float32, 1, 28, 28, 1)`. The tensor has shape 1x28x28x1 and `float32` data type. Similarly, the graph has a single return value, which is a 1x10, `float32` tensor.
To specify multiple arguments or return types, or none at all, pass a `max.TypeTuple` instead.
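For instance, a graph with two tensor inputs might be declared as follows. This is a hypothetical sketch, not a verified snippet: it assumes `TypeTuple` can be constructed directly from `MOTensor` values, so check the current `max.graph` API reference for the exact constructor.

```mojo
# Hypothetical sketch: a graph that takes two inputs and returns one output.
# Assumes TypeTuple is constructible from MOTensor values; the exact
# constructor may differ in your version of MAX.
var g2 = m.graph(
    "two_input_model",
    in_types=graph.TypeTuple(
        graph.MOTensor(DType.float32, 1, 28, 28, 1),
        graph.MOTensor(DType.float32, 1, 10),
    ),
    out_types=graph.MOTensor(DType.float32, 1, 10),
)
```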
Add ops to the Graph
After creating a `Graph`, you can add ops to it. All `Graph` operations flow from either graph inputs or constant values, and you create new values by passing existing values through operation functions.
For example, here's how to add a constant to the graph with the `Graph.constant()` helper:
```mojo
from tensor import Tensor, TensorShape

var constant_value = Tensor[DType.float32](128, 10)
constant_value._to_buffer().fill(0.5)
var cst = g.constant(constant_value)

var cst_0 = g.constant(
    Tensor[DType.float32](
        TensorShape(1, 10),
        -0.0675942451,
        0.0063267909,
        7.43086217e-4,
        -0.0126994187,
        0.0148473661,
        0.108896509,
        -0.0398316309,
        0.0461452715,
        -0.0281771384,
        -0.0431172103,
    )
)
```
The `cst` variable above has type `Symbol`, and represents the symbolic result value of the `g.constant()` op. Because we are building a graph, computations don't return actual values; they return symbolic handles. These let you identify the output or input of an op, but they have no concrete value, because values are known only at runtime. Much like a value, a `Symbol` is passed as input to another op to form increasingly complicated expressions in the `Graph`.
Ultimately, there are several ways to create ops:

- Using helper methods of the `Graph` struct, like `Graph.constant()`, `Graph.range()`, and `Graph.output()`, as shown above. These are typically specialized ops, or ops that only have constant inputs.
- Using helper functions from the `max.graph.ops` module, like `ops.squeeze()`, `ops.list()`, `ops.concat()`, and so on. (See an example below.)
- Using operators and methods of the `Symbol` struct, like `Symbol.reshape()`, `Symbol.__add__()`, and others. These are shorthand helpers that allow more compact expressions; most have `max.graph.ops` counterparts.
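To illustrate the last two styles, the sketch below (reusing `g` and the `cst` symbol from above) builds the same add op both ways; the resulting graph ops are equivalent:

```mojo
from max.graph import ops

# These two lines create equivalent add ops in the graph:
# `+` is shorthand for Symbol.__add__(), which mirrors ops.add().
var sum_a = ops.add(cst, cst)
var sum_b = cst + cst
```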
Add more ops to the Graph
Let's add a couple more constants to the `Graph` to complete our MNIST classification model:
```mojo
var cst_1_data = Tensor[DType.float32](784, 128)
cst_1_data._to_buffer().fill(0.5)
var cst_1 = g.constant(cst_1_data)

var cst_2_data = Tensor[DType.float32](1, 128)
cst_2_data._to_buffer().fill(0.5)
var cst_2 = g.constant(cst_2_data)
```
Then let’s add a reshape op:
```mojo
var p1 = g[0].reshape(1, 784)
```
Notice that we referenced the op using `g[0]`, which denotes the `Graph`'s first argument. Arguments are also `Symbol` objects, so you can use them in the same way as the output of a regular op. This particular reshape is created using `Symbol.reshape()`.
Add one more op, this time using `Symbol.__matmul__()`:
```mojo
var p2 = p1 @ cst_1
```
Add a few more ops to complete our MNIST model, using a mix of operators and op helpers:
```mojo
from max.graph import ops

var p3 = ops.add(p2, cst_2)
var p4 = ops.elementwise.relu(p3)
var p5 = p4 @ cst
var p6 = p5 + cst_0
```
Return results from the Graph
Every `Graph` has a special op that denotes the return values. Create it using `Graph.output()`:
```mojo
g.output(p6)
```
That’s it!
We now have a complete `Graph` and can see it with the `str()` function:
```mojo
print(str(m))
```
This prints the MLIR that represents the MAX graph.
Use external weights
Normally, you run the graph with pre-trained weights loaded from a source. The most practical way to do that is to use standard Mojo APIs, like `Tensor.load()`, or implement your own data loading code in Mojo and use one of the other `Tensor` constructors, such as the one that receives a pointer to a data buffer.
For an example of the latter strategy, check out the Llama 2 MAX Graph example.
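As a minimal sketch of the first approach, you could replace the `fill(0.5)` placeholder weights with a tensor loaded from disk. This assumes the file was previously written with `Tensor.save()` and matches the expected shape; the hypothetical path below is an example, and the exact `Tensor.load()` signature may vary by Mojo version:

```mojo
# Hypothetical sketch: load pre-trained weights instead of fill(0.5).
# "mnist_layer1.maxtensor" is a made-up path; the file must contain a
# 784x128 float32 tensor saved with Tensor.save().
var weights = Tensor[DType.float32].load("mnist_layer1.maxtensor")
var cst_1 = g.constant(weights)
```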
Execute the Graph
Once the `Graph` is built, you can execute it using the MAX Engine Mojo API:
```mojo
from max.engine import InferenceSession, Model, NamedTensor

var session = InferenceSession()
var compiled_model = session.load_model(m)

# Create a test input image, filled with flat values.
var x = Tensor[DType.float32](1, 28, 28, 1)
x._to_buffer().fill(0.5)

var result_map = compiled_model.execute(NamedTensor("input0", x ^))
var predictions = result_map.get[DType.float32]("output0")
print(predictions)
```
That's all for now!
This has been just a basic introduction to the MAX Graph API, and there is still much more to come! To stay up to date on what's coming, sign up for our newsletter. Also, talk to other MAX developers, ask questions, and share feedback on Discord and GitHub.