
Supported model formats

MAX Engine supports model formats provided by PyTorch, TensorFlow, and ONNX. However, we don't support every format from PyTorch and TensorFlow, and all models must be loaded from a file.

Currently, we support PyTorch's TorchScript format, TensorFlow's SavedModel format, and ONNX format.

If you've never heard of TorchScript, SavedModel, or ONNX, don't worry; this page explains each one.

coming soon

We're also adding support for PyTorch 2.0 torch.compile() (TorchDynamo).

Export a PyTorch model to TorchScript

You might be used to saving PyTorch models as a file with torch.save(), but this creates a Python pickle object, which can only be used with Python. This won't work because MAX Engine doesn't execute models with Python. You instead need to create a TorchScript file.

As the name implies, TorchScript is actually a language—it's a subset of the Python language (technically, it's an intermediate representation) that provides a serialization format for PyTorch models so they can run in non-Python environments. However, you can't load TorchScript code directly into MAX Engine—you must save it as a TorchScript file.
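To see what "subset of Python" and "intermediate representation" mean in practice, here is a minimal sketch using only standard torch.jit calls. The function name `add_one` is purely illustrative. The `.code` attribute shows the TorchScript source that PyTorch recovered from the Python function, and `.graph` shows the lower-level IR built from it.

```python
import torch

# A plain Python function that uses only TorchScript-compatible features
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

scripted = torch.jit.script(add_one)

# TorchScript source recovered from the Python code
print(scripted.code)
# The lower-level intermediate representation (IR); ops appear as aten::... nodes
print(scripted.graph)
```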

As we'll describe in the following sections, PyTorch provides two ways to convert a PyTorch model (a torch.nn.Module object) to TorchScript: you can either "script" or "trace" the model.

Script a PyTorch model

When you "script" a PyTorch model, it means that you parse the Python code and convert it into the TorchScript representation using torch.jit.script(). So, although it's accurate to say TorchScript is a language, you still write everything in Python (your code lives in a .py file). However, not all Python code can be successfully converted into TorchScript.

In the best-case scenario, all the Python code in your PyTorch model is already compatible with TorchScript and calling torch.jit.script() just works. In other cases, it might require that you modify the Python code so it uses only the Python features that are available in the TorchScript language—most notably, TorchScript enforces static types.
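As an illustration of the static typing requirement, the following sketch (the module name `OptionalScale` is hypothetical) annotates an optional argument explicitly. TorchScript assumes unannotated arguments are `Tensor`, so an argument that can be `None` or a `float` needs a type hint. The `is None` check then narrows `Optional[float]` to `float` in the `else` branch.

```python
from typing import Optional

import torch

class OptionalScale(torch.nn.Module):
    # Without the annotations, TorchScript would treat `scale` as a Tensor
    def forward(self, x: torch.Tensor, scale: Optional[float] = None) -> torch.Tensor:
        if scale is None:
            return x
        else:
            # Here TorchScript knows `scale` is a float
            return x * scale

scripted = torch.jit.script(OptionalScale())
print(scripted(torch.ones(2)))       # input returned unchanged
print(scripted(torch.ones(2), 3.0))  # input multiplied by 3
```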

When your PyTorch model is compatible with TorchScript, calling torch.jit.script() returns either a ScriptModule or ScriptFunction, which you can save as a file with torch.jit.save().
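The following minimal sketch shows the two result types side by side. `double` and `Double` are illustrative names, and the file is written to a temporary directory only to keep the example self-contained. It saves only the `ScriptModule`, which is the form a full `torch.nn.Module` model takes.

```python
import os
import tempfile

import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

class Double(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# Scripting a function gives a ScriptFunction; scripting a module gives a ScriptModule
scripted_fn = torch.jit.script(double)
scripted_mod = torch.jit.script(Double())
print(type(scripted_fn).__name__)
print(isinstance(scripted_mod, torch.jit.ScriptModule))

# Save the ScriptModule as a TorchScript file
path = os.path.join(tempfile.mkdtemp(), "double.torchscript")
torch.jit.save(scripted_mod, path)
```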

Fortunately, many PyTorch models are already compatible with TorchScript, so you can simply instantiate them, convert them, and save them as a TorchScript file like this:

import torch
import torchvision.models as models

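# Load pretrained weights (torchvision 0.13+ weights API) and switch to inference mode
r50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
r50.eval()
r50_scripted = torch.jit.script(r50)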
torch.jit.save(r50_scripted, 'resnet50.torchscript')

This resnet50.torchscript file is now ready to load into MAX Engine.
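Before handing a file to MAX Engine, you might want to confirm that it loads and produces the same results as the original model. This sketch does that with `torch.jit.load()`. It uses a small stand-in `nn.Sequential` model (an assumption, to avoid downloading ResNet-50 weights) and an in-memory buffer in place of a file path; the same check applies to `resnet50.torchscript`.

```python
import io

import torch

# Small stand-in model; the same round-trip check applies to the scripted ResNet-50
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
)
model.eval()
scripted = torch.jit.script(model)

# Save to a buffer (a file path works the same way), then load it back
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    print(torch.allclose(model(x), reloaded(x)))
```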

If you're writing your own PyTorch model and want to make it compatible with TorchScript, see the PyTorch docs about TorchScript for more details.

Trace a PyTorch model

In some cases, scripting a model might not work so easily, and making it work could require significant code rewrites. In this case, you can instead "trace" the graph with torch.jit.trace().

Tracing the model means PyTorch actually executes the model with sample inputs and records all operations that are invoked. PyTorch adds the recorded operations to a TorchScript representation of the graph—either a ScriptModule or ScriptFunction—that you can save as a TorchScript file with torch.jit.save().
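Here is a minimal sketch of tracing on a tiny, illustrative model. The sample input is random; only its shape `(1, 4)` and dtype (`float32`) matter. `inlined_graph` prints the recorded operations with the submodule calls inlined, so the individual `aten::` ops are visible.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
model.eval()

# Random sample input: its values don't matter, only shape and dtype
example = torch.rand(1, 4)
traced = torch.jit.trace(model, example)

print(isinstance(traced, torch.jit.ScriptModule))
print(traced.inlined_graph)  # the operations recorded during the sample run
```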

For example, the following code shows how you can trace a model from 🤗 Transformers and save it as a TorchScript file. Notice that tracing a model requires that you provide sample input data so that torch.jit.trace() can actually execute the model (it can be random data as long as it matches the input shape and type).

import torch
from transformers import RobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = RobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)

batch = 1
seqlen = 128
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}
with torch.no_grad():
    traced_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )

torch.jit.save(traced_model, "roberta.torchscript")

This roberta.torchscript file is now ready to load into MAX Engine.

note

When tracing a model, only the path that actually executes for the sample inputs is recorded in the resulting ScriptModule or ScriptFunction. That means any control flow logic (such as if/else conditions) is lost: the traced model always follows the branch taken during tracing. If your model includes any branching or other dynamism, you should instead script your PyTorch model.
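The following sketch makes this concrete with an illustrative module, `DecisionGate`, whose output depends on a data-dependent `if`. It is traced with a positive input, so only the `x * 2` branch is recorded. Given a negative input, the traced model still doubles it, while the scripted model correctly takes the other branch. PyTorch emits a `TracerWarning` during tracing in this situation; the example suppresses it only to keep the output clean.

```python
import warnings

import torch

class DecisionGate(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return -x

gate = DecisionGate()

with warnings.catch_warnings():
    # Tracing warns that converting a tensor to a Python bool may make the trace incorrect
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(gate, torch.ones(3))
scripted = torch.jit.script(gate)

negative = -torch.ones(3)
print(gate(negative))      # eager: takes the "else" branch
print(traced(negative))    # traced: still applies "x * 2", the only recorded path
print(scripted(negative))  # scripted: takes the "else" branch, like eager
```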

For more information, see the PyTorch docs about TorchScript.

Export a TensorFlow model to SavedModel

A TensorFlow SavedModel contains everything required for MAX Engine to compile and execute the model. There are several ways you can create a SavedModel, as described in the TensorFlow docs about using SavedModel.

If you've built your model with the Keras 2 API (tf.keras), just call Model.save() with the save_format="tf" argument.

For example, because all 🤗 Transformers models such as TFRobertaForSequenceClassification are subclasses of tf.keras.Model, you can save them as a SavedModel like this:

from transformers import TFRobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = TFRobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)
model.save('roberta', save_format="tf")

The roberta SavedModel is now ready to load into MAX Engine.

However, if you're using Keras 3, model.save() no longer supports the SavedModel format (save_format="tf"), so you need to use tf.saved_model.save() instead. For example:

import keras
import numpy as np
import tensorflow as tf

sequential_model = keras.Sequential([keras.layers.Dense(2)])
sequential_model(np.random.rand(3, 5))
tf.saved_model.save(sequential_model, "keras_saved_model")

note

A SavedModel is a directory of files, and that directory is the path you provide to MAX Engine when you load the model.
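To see that a SavedModel is a directory rather than a single file, this sketch exports a small, illustrative `tf.Module` with `tf.saved_model.save()`, lists the export directory, and reloads it with `tf.saved_model.load()`. The module name `Scale` and the input signature (batch size left as `None`, i.e. dynamic) are assumptions for the example. The exact set of files can vary by TensorFlow version, but `saved_model.pb` and a `variables` subdirectory are expected.

```python
import os
import tempfile

import tensorflow as tf

class Scale(tf.Module):
    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    # The signature fixes dtype and rank but leaves the batch dimension dynamic
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor

export_dir = os.path.join(tempfile.mkdtemp(), "scale_saved_model")
tf.saved_model.save(Scale(), export_dir)

# The SavedModel is a directory; this directory is the path you'd load
print(sorted(os.listdir(export_dir)))

reloaded = tf.saved_model.load(export_dir)
print(reloaded(tf.ones([2, 3])))
```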

Convert a model to ONNX

MAX Engine supports models in the ONNX format, which you can create from either PyTorch or TensorFlow. If you already have an ONNX model, you can directly load it into MAX Engine.

If you don't already have an ONNX model, then we recommend that you instead create a TorchScript file or create a SavedModel, as described above. This will save you a bit of time and confusion.

That's not to say ONNX isn't useful; it provides plenty of value for a wide range of production use cases. However, if your goal is to use MAX Engine, converting to ONNX adds a step you don't need, because you can get a supported format straight from PyTorch or TensorFlow, as shown above.

Load a model into MAX Engine

Once you have your model as a TorchScript, SavedModel, or ONNX file, you can load it for execution in MAX Engine.

If you're using a TorchScript file, you first need to specify the model's input shapes and data types, as described in the next section.

If you're using a SavedModel or ONNX file, just pass the file path to MAX Engine, as shown in the guides to run inference with Python, with C, and with Mojo.

note

If you want to load your model with the max benchmark command and your model's inputs accept only a fixed range of values, you need to specify that range in an input data schema file. Otherwise, the benchmark tool might crash at runtime if it generates inputs outside the acceptable range.

Specify TorchScript input specs

Loading a TorchScript model requires an extra step because the MAX Engine compiler must know the model's input shape, rank, and data type, which a TorchScript model doesn't include. Thus, when loading a TorchScript model, you must provide the shape, rank, and data type as "input specs." If the model supports inputs with dynamic shapes, you can specify those in the input specs, and MAX Engine will optimize the model for any inputs that match the shape.
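Because the exact input-spec syntax differs per API, this sketch does not call MAX Engine at all. It only gathers the values you'd need, reading shape, rank, and data type from the same example tensors you used for tracing. The helper `describe_input_specs` and its `dynamic_dims` parameter are hypothetical. Dimensions listed in `dynamic_dims` are reported as `None` to indicate they're dynamic (here, the batch dimension).

```python
import torch

def describe_input_specs(example_inputs, dynamic_dims=()):
    """Hypothetical helper: summarize shape, rank, and dtype of each example input.

    Dimensions whose index is in dynamic_dims are reported as None (dynamic).
    """
    specs = []
    for name, tensor in example_inputs.items():
        shape = tuple(
            None if dim in dynamic_dims else size
            for dim, size in enumerate(tensor.shape)
        )
        specs.append({
            "name": name,
            "shape": shape,
            "rank": tensor.dim(),
            "dtype": str(tensor.dtype).removeprefix("torch."),
        })
    return specs

inputs = {
    "input_ids": torch.zeros((1, 128), dtype=torch.int64),
    "attention_mask": torch.ones((1, 128), dtype=torch.float32),
}
specs = describe_input_specs(inputs, dynamic_dims={0})
for spec in specs:
    print(spec)
```

You would then translate each entry into the input-spec syntax of the API you're using.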

The exact syntax to specify the input specs is different for each API:

If you're instead loading a model with the max benchmark or max visualize command, then you need to specify the input shapes using an input data schema file.