Supported model formats
MAX Engine supports model formats provided by PyTorch, TensorFlow, and ONNX. However, we don't support every format from PyTorch and TensorFlow, and all models must be loaded from a file.
Currently, we support PyTorch's TorchScript format, TensorFlow's SavedModel format, and ONNX format.
If you've never heard of TorchScript, SavedModel, or ONNX, don't worry: this page explains what they are and how to create them.
We're also adding support for PyTorch 2.0 torch.compile() (TorchDynamo).
Export a PyTorch model to TorchScript
You might be used to saving PyTorch models as a file with torch.save(), but this creates a Python pickle object, which can only be used with Python. This won't work because MAX Engine doesn't execute models with Python. You instead need to create a TorchScript file.
As the name implies, TorchScript is actually a language: a statically typed subset of Python that PyTorch compiles to an intermediate representation. That representation also serves as a serialization format for PyTorch models, so they can run in non-Python environments. However, you can't load TorchScript code directly into MAX Engine; you must save it as a TorchScript file.
As we'll describe in the following sections, PyTorch provides two ways to convert a PyTorch model (a torch.nn.Module object) to TorchScript: you can either "script" or "trace" the model.
Script a PyTorch model
When you "script" a PyTorch model, it means that you parse the Python code and convert it into the TorchScript representation using torch.jit.script().
So, although it's accurate to say TorchScript is a language, you still write everything in Python (your code lives in a .py file). However, not all Python code can be successfully converted into TorchScript.
In the best-case scenario, all the Python code in your PyTorch model is already compatible with TorchScript and calling torch.jit.script() just works. In other cases, it might require that you modify the Python code so it uses only the Python features that are available in the TorchScript language. Most notably, TorchScript enforces static types.
When your PyTorch model is compatible with TorchScript, calling torch.jit.script() returns either a ScriptModule or ScriptFunction, which you can save as a file with torch.jit.save().
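As a minimal sketch of this workflow, the following scripts a toy module, confirms that the result is a ScriptModule, and round-trips it through a file. The Scale class and the scale.torchscript file name are hypothetical and exist only for illustration; the sketch assumes a recent PyTorch release and that the code runs from a .py file, because scripting needs access to the Python source.

```python
import os
import tempfile

import torch


class Scale(torch.nn.Module):
    """Hypothetical toy module with a branch that scripting preserves."""

    def forward(self, x: torch.Tensor, factor: float) -> torch.Tensor:
        # Non-Tensor arguments need explicit annotations because
        # TorchScript is statically typed.
        if factor > 1.0:
            return x * factor
        return x


scripted = torch.jit.script(Scale())
is_script_module = isinstance(scripted, torch.jit.ScriptModule)

# Save to a temporary TorchScript file and load it back.
path = os.path.join(tempfile.mkdtemp(), "scale.torchscript")
torch.jit.save(scripted, path)
reloaded = torch.jit.load(path)

scaled = reloaded(torch.ones(2), 3.0)     # takes the `if` branch
unchanged = reloaded(torch.ones(2), 0.5)  # falls through to `return x`
```

Because the module was scripted rather than traced, both branches survive the save and reload.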
Fortunately, many PyTorch models are already compatible with TorchScript, so you can simply instantiate them, convert them, and save them as a TorchScript file like this:
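import torch
import torchvision.models as models

# Instantiate a pretrained ResNet-50 (newer torchvision releases prefer the `weights` argument).
r50 = models.resnet50(pretrained=True)
# Convert the model to TorchScript and save it as a file.
r50_scripted = torch.jit.script(r50)
torch.jit.save(r50_scripted, 'resnet50.torchscript')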
This resnet50.torchscript file is now ready to load into MAX Engine.
If you're writing your own PyTorch model and want to make it compatible with TorchScript, see the PyTorch docs about TorchScript for more details.
Trace a PyTorch model
In some cases, scripting a model might not work so easily, and making it work could require significant code rewrites. In this case, you can instead "trace" the graph with torch.jit.trace().
Tracing the model means PyTorch actually executes the model with sample inputs and records all operations that are invoked. PyTorch adds the recorded operations to a TorchScript representation of the graph (either a ScriptModule or ScriptFunction) that you can save as a TorchScript file with torch.jit.save().
For example, the following code shows how you can trace a model from 🤗 Transformers and save it as a TorchScript file. Notice that tracing a model requires that you provide sample input data so that torch.jit.trace() can actually execute the model (it can be random data as long as it matches the input shape and type).
import torch
from transformers import RobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = RobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)

# Sample inputs: the values are arbitrary, but shapes and dtypes must match.
batch = 1
seqlen = 128
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}

# Trace without tracking gradients, then save the result as a TorchScript file.
with torch.no_grad():
    traced_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )
torch.jit.save(traced_model, "roberta.torchscript")
This roberta.torchscript file is now ready to load into MAX Engine.
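Before handing a traced file to another runtime, it can be worth confirming that it reloads and reproduces the eager model's outputs. The sketch below does this with a single linear layer as a hypothetical stand-in for a real model (so nothing needs to be downloaded); the linear.torchscript file name is also just an illustration.

```python
import os
import tempfile

import torch

# Hypothetical stand-in for a real model: one linear layer in eval mode.
model = torch.nn.Linear(4, 2).eval()

# Random sample input; only its shape and dtype matter to the tracer.
example = torch.rand(1, 4)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

path = os.path.join(tempfile.mkdtemp(), "linear.torchscript")
torch.jit.save(traced, path)

# Reload the file and compare against the eager model on fresh data.
reloaded = torch.jit.load(path)
fresh = torch.rand(1, 4)
with torch.no_grad():
    outputs_match = torch.allclose(reloaded(fresh), model(fresh))
```

If outputs_match is False for your own model, the trace probably captured input-dependent behavior, which is the limitation described next.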
When tracing a model, only the executed path of the graph is recorded into the resulting ScriptModule or ScriptFunction. That means any control flow logic (such as if/else conditions) is lost. If your model includes any branching or other dynamism, you should instead script your PyTorch model.
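The difference is easy to see with a small function whose result depends on a data-dependent branch. In this sketch (the shift function is a hypothetical example, and the code is assumed to run from a .py file so that scripting can read the source), the function is traced with a positive input, so the traced version replays the "add" path even when given a negative input, while the scripted version still evaluates the condition.

```python
import warnings

import torch


def shift(x):
    # Data-dependent branch: which path runs depends on the input values.
    if x.sum() > 0:
        return x + 10
    else:
        return x - 10


positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns that the tensor-to-bool conversion gets baked in;
    # the warning is silenced here only to keep the output clean.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(shift, positive)

scripted = torch.jit.script(shift)

traced_out = traced(negative)      # replays the recorded `x + 10` path
scripted_out = scripted(negative)  # evaluates the branch, so `x - 10`
```

Here traced_out holds 9.0 in every element, whereas scripted_out holds -11.0, which is what the original Python function returns for the negative input.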
For more information, see the PyTorch docs about TorchScript.
Export a TensorFlow model to SavedModel
Due to a lack of customer interest, we've removed TensorFlow support in the MAX SDK. This significantly reduces our package size and accelerates development of other customer-requested features. If you have a production use-case with a TensorFlow model, please contact us.
The original instructions remain below for reference.
A TensorFlow SavedModel contains everything required for MAX Engine to compile and execute the model. There are several ways you can create a SavedModel, as described in the TensorFlow docs about using SavedModel.
If you've built your model with the Keras API, just call Model.save() with the save_format="tf" argument.
For example, because all 🤗 Transformers models such as TFRobertaForSequenceClassification are subclasses of tf.keras.Model, you can save them as a SavedModel like this:
from transformers import TFRobertaForSequenceClassification
HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = TFRobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)
model.save('roberta', save_format="tf")
The roberta SavedModel is now ready to load into MAX Engine.
However, if you're using Keras 3, the above model.save() function is no longer supported, and you need to instead use tf.saved_model.save(). For example:
import keras
import numpy as np
import tensorflow as tf

sequential_model = keras.Sequential([keras.layers.Dense(2)])
# Call the model once on sample data so its weights are built before saving.
sequential_model(np.random.rand(3, 5))
tf.saved_model.save(sequential_model, "keras_saved_model")
A SavedModel is a directory of files, and the path to that directory is what you provide to MAX Engine when you load the model.
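Because the path must point at the directory rather than at a file inside it, a quick check before loading can catch a common mistake. The helper below is a hypothetical sketch (it is not part of TensorFlow or MAX Engine) that relies only on the fact that a SavedModel directory stores its graph in saved_model.pb, or saved_model.pbtxt in text form.

```python
import os
import tempfile


def looks_like_saved_model(path: str) -> bool:
    """Return True if `path` is a directory holding a SavedModel graph file."""
    if not os.path.isdir(path):
        return False
    # A SavedModel directory stores its graph as saved_model.pb
    # (or saved_model.pbtxt in the text format).
    return any(
        os.path.isfile(os.path.join(path, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )


# Demonstrate with an empty stand-in directory rather than a real model.
demo_dir = tempfile.mkdtemp()
before = looks_like_saved_model(demo_dir)

# Create an empty placeholder graph file, then check again.
open(os.path.join(demo_dir, "saved_model.pb"), "wb").close()
after = looks_like_saved_model(demo_dir)
```

The check only inspects the directory layout; it does not validate the model contents.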
Convert a model to ONNX
MAX Engine supports models in the ONNX format, which you can create from either PyTorch or TensorFlow. If you already have an ONNX model, you can directly load it into MAX Engine.
If you don't already have an ONNX model, then we recommend that you instead create a TorchScript file or create a SavedModel, as described above. This will save you a bit of time and confusion.
That's not to say ONNX isn't any good; it provides plenty of value for a wide range of production use cases. However, if your intent is to use MAX Engine, then using ONNX doesn't really help you because you can get the format you need straight from PyTorch or TensorFlow, as shown above.
Load a model into MAX Engine
Once you have your model as a TorchScript, SavedModel, or ONNX file, you can load it for execution in MAX Engine.
If you're using a TorchScript file, you need to first specify the input shape.
If you're using a SavedModel or ONNX file, just pass the file path to MAX Engine, as shown in the guides to run inference with Python, with C, and with Mojo.
If you want to load your model with the max benchmark command, and your model's inputs have a fixed range of allowed values, then you need to specify the allowed range with an input data schema file. Otherwise, the benchmark tool might crash at runtime if it generates inputs that are outside the acceptable range.
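To illustrate why the range matters, consider a model whose input is a token ID that indexes an embedding table. The sketch below uses plain PyTorch on CPU with a hypothetical vocabulary size of 10; it shows only the failure mode, not the schema file format.

```python
import torch

VOCAB_SIZE = 10  # hypothetical vocabulary size for this sketch
embedding = torch.nn.Embedding(VOCAB_SIZE, 4)

# IDs drawn from [0, VOCAB_SIZE) are valid lookups.
valid_ids = torch.randint(0, VOCAB_SIZE, (1, 8))
valid_shape = tuple(embedding(valid_ids).shape)

# An ID one past the last valid index cannot be looked up.
out_of_range_ids = torch.full((1, 8), VOCAB_SIZE, dtype=torch.int64)
try:
    embedding(out_of_range_ids)
    lookup_failed = False
except IndexError:
    lookup_failed = True
```

Randomly generated inputs with the right shape and dtype can still fail in this way, which is why the allowed range has to be declared.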
Specify TorchScript input specs
Loading a TorchScript model requires an extra step because the MAX Engine compiler must know the model's input shape, rank, and data type, which are absent in a TorchScript model. Thus, when loading a TorchScript model, you must provide the shape, rank, and data type as "input specs." If the model supports inputs with dynamic shapes, you can specify those in the input specs and MAX Engine will optimize the model for any inputs that match the shape.
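Conceptually, an input spec is a small record of shape and data type, with the rank implied by the shape's length. The following sketch models that idea in plain Python so you can see what "match the shape" means. The InputSpec class is a hypothetical stand-in, not the MAX Engine class, and using None to mark a dynamic dimension is an assumption of this sketch; the input names and dtypes mirror the RoBERTa tracing example earlier on this page.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class InputSpec:
    """Hypothetical stand-in for an input spec; not the MAX Engine class."""

    shape: Tuple[Optional[int], ...]  # None marks a dynamic dimension (assumption)
    dtype: str

    @property
    def rank(self) -> int:
        # Rank is implied by the number of dimensions in the shape.
        return len(self.shape)

    def accepts(self, shape: Tuple[int, ...], dtype: str) -> bool:
        """Check that a concrete input has the same dtype, rank, and static dims."""
        if dtype != self.dtype or len(shape) != self.rank:
            return False
        return all(
            want is None or want == got for want, got in zip(self.shape, shape)
        )


# One spec per model input, with a dynamic sequence length.
specs = {
    "input_ids": InputSpec((1, None), "int64"),
    "attention_mask": InputSpec((1, None), "float32"),
    "token_type_ids": InputSpec((1, None), "int64"),
}

fits_128 = specs["input_ids"].accepts((1, 128), "int64")    # True
fits_batch_2 = specs["input_ids"].accepts((2, 128), "int64")  # False
```

With a dynamic second dimension, sequence lengths of 64 and 128 both match, while a different batch size, rank, or dtype does not.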
The exact syntax to specify the input specs is different for each API:
- In Mojo and Python, you need to specify the input_specs keyword argument to InferenceSession.load (Mojo doc, Python doc). For details, see the Mojo inferencing guide or Python inferencing guide.
- In C, you need to call M_setTorchInputSpecs(). For details, see the C inferencing guide.
If you're instead loading a model into the max benchmark or max visualize command, then you need to specify the input shapes using an input data schema file.