Model formats
We designed MAX to simplify AI deployment for everybody, which means it can accelerate all kinds of models on all kinds of hardware. Depending on your task, that might mean running a model from PyTorch or ONNX, or running the latest generative AI (GenAI) models. In all cases, MAX gives you a way to deploy with hardware flexibility and state-of-the-art inference performance.
This page explains each of the model formats that MAX supports, and how you can get them. Put simply, you can use MAX with the following formats:
- TorchScript (PyTorch)
- ONNX
- MAX Graph
TorchScript
If you're familiar with PyTorch, you might be used to saving your model with torch.save(). But this creates a Python pickle object, which can only be used from Python. That won't work with MAX because, under the hood, MAX doesn't execute models with Python. Instead, you need to create a TorchScript file that MAX can compile into an executable format that MAX understands.
As we'll describe in the following sections, PyTorch provides two ways to convert a PyTorch model (a torch.nn.Module object) to TorchScript: you can either "script" or "trace" the model.
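As a quick preview of both options, here's a minimal sketch using a small hypothetical model (TinyMLP is not from this page). Whether you script or trace it, the result is a TorchScript module rather than a pickled Python object, and its .code attribute shows the TorchScript program that gets serialized.

```python
import torch

class TinyMLP(torch.nn.Module):
    # Hypothetical model, used only for illustration.
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.fc2 = torch.nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyMLP().eval()
scripted = torch.jit.script(model)                  # option 1: script
traced = torch.jit.trace(model, torch.randn(1, 4))  # option 2: trace

# Either way you get TorchScript; .code prints the compiled forward() program.
print(scripted.code)
```

Both scripting and tracing are covered in detail in the sections that follow.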
Script a PyTorch model
When you "script" a PyTorch model, torch.jit.script() parses the Python code and converts it into the TorchScript representation.
Although it's accurate to say TorchScript is a language, you still write everything in Python (your code lives in a .py file). However, not all Python code can successfully convert to TorchScript.
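As a minimal sketch of code that won't convert, the hypothetical helper below uses variadic arguments (*args), which TorchScript doesn't support, so torch.jit.script() fails at compile time. The exact exception class is an internal detail of PyTorch, so the sketch catches Exception broadly.

```python
import torch

def sum_all(*tensors):
    # Hypothetical helper: variadic arguments (*args) aren't supported by TorchScript.
    total = tensors[0]
    for t in tensors[1:]:
        total = total + t
    return total

try:
    torch.jit.script(sum_all)
    converted = True
except Exception as err:  # TorchScript rejects the unsupported signature
    converted = False
    print(f"Could not script sum_all: {type(err).__name__}")
```

The function still runs fine in eager Python; it just can't be compiled to TorchScript as written.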
In the best-case scenario, all the Python code in your PyTorch model is already compatible with TorchScript, and calling torch.jit.script() just works. In other cases, you might need to modify the Python code so it uses only the Python features that are available in the TorchScript language. Most notably, TorchScript enforces static types.
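Static typing is the rule you're most likely to hit. In this sketch (both function names are hypothetical), TorchScript infers every unannotated argument to be a Tensor, so the untyped version rejects a Python float at call time. The annotated version declares the real types and accepts a list of tensors plus a float.

```python
from typing import List

import torch

def scale_untyped(x, factor):
    # No annotations: TorchScript assumes both arguments are Tensors.
    return x * factor

def scale_all(xs: List[torch.Tensor], factor: float) -> List[torch.Tensor]:
    # Explicit annotations tell the TorchScript compiler the actual types.
    return [x * factor for x in xs]

untyped = torch.jit.script(scale_untyped)
typed = torch.jit.script(scale_all)

try:
    untyped(torch.ones(2), 2.0)  # factor was inferred as Tensor, so a float is rejected
    untyped_accepts_float = True
except RuntimeError:
    untyped_accepts_float = False

scaled = typed([torch.ones(2), torch.ones(3)], 2.0)
```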
When your PyTorch model is compatible with TorchScript, calling torch.jit.script() returns either a ScriptModule or ScriptFunction, which you can save as a file with torch.jit.save().
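To see which of the two objects you get back, here's a small check: scripting a plain function returns a ScriptFunction, while scripting an nn.Module instance returns a ScriptModule (in practice, a subclass of it). The double function is a hypothetical example.

```python
import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# Scripting a plain function gives a ScriptFunction...
scripted_fn = torch.jit.script(double)
# ...while scripting an nn.Module instance gives a ScriptModule.
scripted_mod = torch.jit.script(torch.nn.Linear(3, 1))

print(type(scripted_fn).__name__, type(scripted_mod).__name__)
```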
Fortunately, a lot of PyTorch models are already compatible with TorchScript, so you can simply instantiate them, convert them, and save them as a TorchScript file like this:
```python
import torch
import torchvision.models as models

# Load a pretrained ResNet-50, compile it to TorchScript, and save it.
r50 = models.resnet50(pretrained=True)
r50_scripted = torch.jit.script(r50)
torch.jit.save(r50_scripted, 'resnet50.torchscript')
```
This resnet50.torchscript file is now ready to load and execute with the MAX Engine API.
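Before handing a TorchScript file to MAX, you can sanity-check it with PyTorch's own torch.jit.load(). To avoid downloading weights, this sketch uses a small hypothetical nn.Sequential model instead of ResNet-50 and saves to an in-memory buffer. A file path works the same way.

```python
import io

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
).eval()
scripted = torch.jit.script(model)

# Save to an in-memory buffer; a path such as 'tiny.torchscript' also works.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Load it back; TorchScript doesn't need the original Python class definitions.
reloaded = torch.jit.load(buffer)

x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.allclose(model(x), reloaded(x))
print("outputs match:", same)
```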
Trace a PyTorch model
In some cases, scripting a model as shown above might require significant code rewrites. In that case, you can instead trace the graph with torch.jit.trace().
Tracing the model means PyTorch executes the model with sample inputs and records all operations that are invoked. PyTorch adds the recorded operations to a TorchScript representation of the graph (either a ScriptModule or ScriptFunction) that you can save as a TorchScript file with torch.jit.save().
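Because tracing only records the operations that actually run, it doesn't capture data-dependent control flow: only the branch taken by the sample input ends up in the graph. A minimal sketch with a hypothetical module shows the difference from scripting, which compiles both branches.

```python
import torch

class DoubleOrNegate(torch.nn.Module):
    # Hypothetical module with data-dependent control flow.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x * 2
        return -x

module = DoubleOrNegate()
# The sample input takes the first branch, so only `x * 2` is recorded.
# (PyTorch emits a TracerWarning about converting a tensor to a Python bool.)
traced = torch.jit.trace(module, torch.ones(3))
# Scripting compiles both branches instead.
scripted = torch.jit.script(module)

negative = -torch.ones(3)
print(module(negative))    # eager: takes the second branch
print(traced(negative))    # traced: still applies the recorded `x * 2`
print(scripted(negative))  # scripted: matches eager
```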
For example, the following code shows how you can trace a model from Hugging Face and save it as a TorchScript file.
```python
import torch
from transformers import RobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = RobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)

# Sample inputs for tracing: only their shapes and dtypes matter.
batch = 1
seqlen = 128
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}

with torch.no_grad():
    # strict=False lets the tracer record the model's dict-like output.
    traced_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )

torch.jit.save(traced_model, "roberta.torchscript")
```
Notice that tracing a model requires sample input data so that torch.jit.trace() can execute the model. The data can be random, as long as it matches the expected input shape and type.
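To illustrate the shape-and-type requirement without downloading a model, this sketch traces a small hypothetical classifier (TinyClassifier, not a Hugging Face model) using random token IDs as the sample input, then runs the traced model on different data with the same shapes and dtypes. Like the RoBERTa example, it passes inputs by name with example_kwarg_inputs. Because this model returns a plain tensor rather than a dict-like output, it doesn't need strict=False, which the PyTorch documentation describes as letting the tracer record mutable container types such as lists and dicts.

```python
import torch

class TinyClassifier(torch.nn.Module):
    # Hypothetical stand-in for a text classifier; not a Hugging Face model.
    def __init__(self, vocab_size: int = 100, hidden: int = 16, num_labels: int = 4) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Mean-pool the token embeddings, ignoring masked-out positions.
        h = self.embed(input_ids) * attention_mask.unsqueeze(-1)
        pooled = h.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        return self.head(pooled)

batch, seqlen, vocab = 1, 128, 100
# Random token IDs are fine: what matters is the shape (batch, seqlen) and dtype.
sample = {
    "input_ids": torch.randint(0, vocab, (batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
}

model = TinyClassifier(vocab_size=vocab).eval()
with torch.no_grad():
    traced = torch.jit.trace(model, example_kwarg_inputs=sample)

# Run the traced model on new data with the same shapes and dtypes.
new_ids = torch.randint(0, vocab, (batch, seqlen), dtype=torch.int64)
with torch.no_grad():
    logits = traced(new_ids, sample["attention_mask"])
print(logits.shape)
```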
For a complete code example of how to trace a model and run it with MAX, see our tutorial to run a TorchScript model with Python.
ONNX
ONNX (Open Neural Network Exchange) is a model format you can export from a variety of machine learning tools, including PyTorch, TensorFlow, Keras, Scikit-Learn, and more.
If you don't have an ONNX model yet, refer to the appropriate ML framework documentation about how to convert your model to ONNX. For example, if using TensorFlow, see the ONNX guide about how to convert TensorFlow to ONNX.
It's also easy to export a Hugging Face model to ONNX, using either a CLI tool or a Python API.
For a complete code example of how to export a Hugging Face model to ONNX with Python and then run inference with MAX, see our tutorial to run an ONNX model with Python.
MAX Graph
MAX Graph is our solution to build high-performance GenAI models such as large language models (LLMs) in Python. When paired with the MAX compiler and runtime, a MAX Graph model delivers the state-of-the-art performance you'd expect from point-solution AI libraries written in C or C++, but with less code that's also more readable.
We built MAX Graph because, although MAX can execute off-the-shelf models from PyTorch and ONNX faster than the default runtimes, these ML frameworks can't do everything. Specifically, they have not kept up with the performance demands of GenAI models.
It all starts with the Graph object for creating acyclic computation graphs:
- Instantiate a graph, specifying the input shape as a TensorType.
- Build the graph by chaining ops functions. Each function takes and returns a Value object.
- Add the final Value to the graph using the output() method.
This structure allows you to define complex computations as a series of interconnected operations, forming an AI model graph.
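The three steps above can be illustrated with a toy, pure-Python graph builder. This is NOT the MAX Graph API: ToyGraph, ToyValue, add, and mul are hypothetical names chosen only to show the pattern of declaring inputs, chaining op functions that take and return values, and then registering the output. Because each op can only reference values that already exist, the resulting graph is acyclic by construction.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

# Toy illustration of the build pattern only; these classes are NOT the MAX API.
@dataclass(eq=False)  # eq=False keeps identity hashing, so values can key a dict
class ToyValue:
    op: Optional[Callable[..., float]] = None  # None means "graph input"
    args: Tuple["ToyValue", ...] = ()

class ToyGraph:
    def __init__(self, num_inputs: int) -> None:
        # Step 1: declare the graph inputs up front (MAX uses TensorType for this).
        self.inputs = [ToyValue() for _ in range(num_inputs)]
        self.outputs: List[ToyValue] = []

    def output(self, *values: ToyValue) -> None:
        # Step 3: register the final value(s) the graph returns.
        self.outputs.extend(values)

    def run(self, *feeds: float) -> List[float]:
        cache: Dict[ToyValue, float] = dict(zip(self.inputs, feeds))

        def evaluate(v: ToyValue) -> float:
            # Compute each node once, after its arguments.
            if v not in cache:
                cache[v] = v.op(*(evaluate(a) for a in v.args))
            return cache[v]

        return [evaluate(v) for v in self.outputs]

# Step 2: ops take values and return a new value that points only at
# existing values, so no cycle can ever be formed.
def add(a: ToyValue, b: ToyValue) -> ToyValue:
    return ToyValue(lambda x, y: x + y, (a, b))

def mul(a: ToyValue, b: ToyValue) -> ToyValue:
    return ToyValue(lambda x, y: x * y, (a, b))

graph = ToyGraph(num_inputs=2)
x, y = graph.inputs
graph.output(mul(add(x, y), y))  # (x + y) * y
print(graph.run(2.0, 3.0))  # [15.0]
```

The real MAX Graph API adds tensor types, shapes, and compilation on top of this idea; see the MAX Graph tutorial for how it actually works.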
For a step-by-step guide to building a model with MAX Graph, see our tutorial to get started with MAX Graph in Python. Or, check out our implementations of GenAI models such as Llama3 and Mistral on GitHub.
Get started
Run an ONNX model with Python
Learn how to run inference with an ONNX model, using our Python API.
Run a TorchScript model with Python
Learn how to run inference with a PyTorch TorchScript model, using our Python API.
Get started with MAX Graph in Python
Learn how to build a model graph with our Python API for inference with MAX Engine.