Create a custom op for any model
MAX is designed to be fully extensible, so you can get the behavior you want from your model, without any compromises in performance. That means you can extend MAX Engine's capabilities by defining high-performance custom operations ("ops") for your model. On this page, we’ll show you how to implement a new op for MAX Engine and use it with an ONNX model, step by step.
When you write a custom op in Mojo, the MAX Engine compiler treats it the same as all the other ops ("kernels") that we've built into MAX Engine: the compiler will analyze and optimize your op to achieve the best performance. You can learn more about how it works in the intro to MAX extensibility.
The procedure is basically two steps:
- Write your op as a Mojo function that operates on Tensor values, and package it.
- When you load your model, also load your op package (for example, with the Python load() function).
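To preview step 2: loading the op package is just one extra argument on the load() call, as shown in full later in this tutorial. (The file names in this sketch are placeholders.)

from max import engine

session = engine.InferenceSession()
# Point custom_ops_path at your compiled Mojo package when loading the model.
model = session.load("model.onnx", custom_ops_path="custom_ops.mojopkg")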
Let's do it!
Currently, custom ops are compatible with ONNX and MAX Graph models only. Support for TorchScript models is coming soon. (If you're using MAX Graph, see Create a custom op in MAX Graph instead.)
Also, the example code below currently fails on Ubuntu 20.04.
Overview
When MAX compiles a TorchScript or ONNX model, it translates each op in the graph into "MO" dialect operations. The MO dialect is an intermediate representation we use in our MLIR-based graph compiler ("MO" is short for "Modular"). This means we've defined a lot of graph operations in MO that map to other framework ops, such as add, matmul, and softmax. But there are a lot of ops out there (PyTorch has thousands), most of which are rarely used and currently not implemented in MAX.
If you're using a TorchScript model and the MAX compiler encounters one of these ops that we haven't implemented, it falls back to using the op implementation from PyTorch. This is a good thing because it means we can provide full compatibility with nearly all PyTorch models, without any need for custom ops. (Currently, you can't implement custom ops for a PyTorch model anyway, but that's in development.)
On the other hand, if your ONNX model uses an op we haven’t implemented in MAX, the compiler fails because ONNX models cannot fall back to the op implementation from ONNX. In this case, you can fix it by implementing the op in Mojo. That’s exactly what you’ll learn to do below.
We're going to implement the Det op for MAX Engine. Currently, MAX Engine doesn't have an implementation for this op, so it will fail when trying to compile an ONNX model that uses it. (Whereas you can compile a TorchScript model that uses Det, because MAX Engine falls back to the PyTorch implementation.)
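If you're not familiar with Det, it computes the determinant of a square matrix, or of a batch of square matrices (the innermost two dimensions). Here's a quick NumPy illustration of the math we'll be delegating to later:

import numpy as np

# A batch of three 3x3 matrices; the determinant reduces the last two
# dimensions, so the result has shape (3,): one determinant per matrix.
batch = np.random.rand(3, 3, 3).astype(np.float32)
print(np.linalg.det(batch))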
Set up
First, make sure you have the latest version of MAX.
For this tutorial, we're going to use some example code in our max GitHub repo. If you haven't already, clone the repo:
git clone https://github.com/modularml/max.git && cd max/examples/extensibility
And create a virtual environment with the required dependencies:
python3.11 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
pip install --find-links $(modular config max.path)/wheels max-engine
Build the ONNX model
Initially, the ONNX model doesn't work. Here's how to build it and see for yourself that it fails to compile with MAX Engine:
- Build the model with this script (a hypothetical sketch of a similar model appears after this list):
  python3 onnx-model.py
- Try running it with the benchmark tool:
  max benchmark onnx_det.onnx
  This will fail with an error:
  loc("onnx_det.onnx":0:0): error: failed to legalize operation 'monnx.det_v11' that was explicitly marked illegal
  As the error indicates, MAX Engine cannot compile the det_v11 op.
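For reference, here's a hypothetical sketch of how a model like this could be built with the onnx Python package. It isn't necessarily the same graph as the repo's onnx-model.py, but it matches the input and output shapes used later in this tutorial (X: [3, 3, 5], A: [5, 3], B: [3], Z: [3]) and ends in a Det node:

import onnx
from onnx import TensorProto, helper

# Hypothetical graph: Z = Det(MatMul(X, A) + B). The repo's script may differ.
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 3, 5])
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [5, 3])
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3])
Z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [3])

matmul = helper.make_node("MatMul", ["X", "A"], ["XA"])  # -> [3, 3, 3]
add = helper.make_node("Add", ["XA", "B"], ["Y"])        # broadcast add -> [3, 3, 3]
det = helper.make_node("Det", ["Y"], ["Z"])              # batched determinant -> [3]

graph = helper.make_graph([matmul, add, det], "det_example", [X, A, B], [Z])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
onnx.checker.check_model(model)
onnx.save(model, "onnx_det.onnx")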
Now let’s implement that op for MAX Engine so the model works.
The custom op function we need already exists in the GitHub repo, but if you want to build it yourself as practice, create a new directory and copy the ONNX model you just built:
mkdir my_project && mv onnx_det.onnx my_project && cd my_project
Implement a custom op
To create a custom op for MAX Engine, you need to create a Mojo package that includes the code for your op.
Start by creating a directory to use as the Mojo package, and create a file named det.mojo inside:
mkdir custom_ops && touch custom_ops/det.mojo
Then write the function, following these rules:
- The function must operate on Tensor values: it must take a Tensor argument for each op input and return a Tensor as the op output.
- You must register your function as an op by adding the register.op() decorator with the name of your op as an argument.

For example, see the following code, which you can copy and paste into your det.mojo file. Notice that the name we give to register.op() is the name from the compiler error above ("monnx.det_v11"):
from python import Python
from .python_utils import tensor_to_numpy, numpy_to_tensor
from max import register
from max.extensibility import Tensor, empty_tensor


@register.op("monnx.det_v11")
fn det[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank - 2]:
    try:
        print("Hello, custom DET!")
        # Convert the input Tensor to a NumPy array, compute the determinant
        # with NumPy, then convert the result back to a Tensor.
        var np = Python.import_module("numpy")
        var np_array = tensor_to_numpy(x, np)
        var np_out = np.linalg.det(np_array)
        return numpy_to_tensor[type, rank - 2](np_out)
    except e:
        print(e)
        return empty_tensor[type, rank - 2](0)
Although we could have written the determinant ("det") function entirely in Mojo, we don't have to, because Mojo lets us simply call into the NumPy implementation.

This requires that we convert NumPy arrays to/from Tensor values. Our code does that using tensor_to_numpy() and numpy_to_tensor(), which come from the python_utils module in our GitHub code example.
For the sake of this tutorial, create a file named python_utils.mojo right next to your det.mojo file, and copy in the contents of python_utils.mojo from the GitHub example.
Now the custom op is ready to go.
Package the custom op
To package the custom op, add an empty __init__.mojo file in the custom_ops directory. Then, pass that directory name to the mojo package command.

For example, here's a look inside our custom_ops directory:
custom_ops
├── __init__.mojo
├── det.mojo
└── python_utils.mojo
Now package it with this command:
mojo package custom_ops
This creates a file called custom_ops.mojopkg in the current directory.
Next, you simply load the ONNX model with this Mojo package.
Benchmark with the custom op
To quickly verify that the custom op works, pass the model and the Mojo package to the max benchmark command:
max benchmark onnx_det.onnx --custom-ops-path=custom_ops.mojopkg
Execute the model with the custom op
Finally, here's how to use the Python API to load the model and the custom op. All you need to do is add the custom_ops_path argument when you call load().

So, create a file named onnx-inference.py in the my_project directory, and paste in this code:
from max import engine
import numpy as np

session = engine.InferenceSession()
model = session.load("onnx_det.onnx", custom_ops_path="custom_ops.mojopkg")

for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

input_x = np.random.rand(3, 3, 5).astype(np.float32)
input_a = np.random.rand(5, 3).astype(np.float32)
input_b = np.random.rand(3).astype(np.float32)

result = model.execute(X=input_x, A=input_a, B=input_b)
print(result)
That’s it!
Now run it:
python3 onnx-inference.py
Compiling model...
Done!
name: X, shape: [3, 3, 5], dtype: DType.float32
name: A, shape: [5, 3], dtype: DType.float32
name: B, shape: [3], dtype: DType.float32
Hello, custom DET!
{'Z': array([-0.04415698, -0.00949615, 0.07051321], dtype=float32)}
You can get all the example code from this and previous sections from our GitHub repo.
Add a custom op to Triton (optional)
If you're using NVIDIA's Triton Inference Server to deploy your model, you can make your custom op available by appending the following to your model configuration file:
parameters: [{
  key: "custom-ops-path"
  value: {
    string_value: "./path/to/custom_ops.mojopkg"
  }
}]
You must include the key-value pair inside the parameters configuration, as shown above. The only thing you need to change is the string_value so that it specifies the path to your custom ops Mojo package.