Create a custom op in MAX Graph
The MAX Graph API includes a growing list of operations in the graph.ops package that you can add to your graph. However, this library might be missing an op that you need, or perhaps you want to write your own implementation of an existing op. As we'll show you on this page, that's no problem: you can write your own ops in Mojo and easily add them to a graph you built with MAX Graph.
When you write a custom op in Mojo, the MAX Engine compiler treats it the same as all the other ops ("kernels") built into MAX Engine: the compiler analyzes and optimizes the op to achieve the best performance. Your custom op is not treated as an external library.
Before you begin, be sure you install the latest version of MAX.
If you're loading an external model, instead see how to create a custom op for any model.
Implement the custom op
To create a custom op, write a Mojo function that operates on Tensor values: it must take a Tensor argument for each op input and return a Tensor as the op output. You must also register your function as an op by adding the register.op() decorator with the name of your op. (This name must match the custom op name you’ll later add to the Graph.)
For example, here’s a custom implementation of the GELU op:
```mojo
from max.extensibility import Tensor, empty_tensor
from max import register
from math import erf, sqrt


@register.op("my_gelu")
fn gelu[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank]:
    var output = empty_tensor[type](x.shape)

    @always_inline
    @parameter
    fn func[width: Int](i: StaticIntTuple[rank]) -> SIMD[type, width]:
        var tmp = x.simd_load[width](i)
        return tmp / 2 * (1 + erf(tmp / sqrt(2)))

    print("Hello, custom GELU!")
    output.for_each[func]()
    return output^
```
The custom op must be in a module separate from where you implement your Graph.
Package the custom op
To package your custom op, you need a directory that includes the above Mojo code, plus an empty __init__.mojo file. Then, pass the directory name to the mojo package command.
For example, let’s say you create a directory named custom_ops with the two Mojo files:

```
custom_ops
├── __init__.mojo
└── gelu.mojo
```
Then, you can package it like this:

```sh
mojo package custom_ops
```
This creates a file called custom_ops.mojopkg in the current directory.
Next, you simply add the op to your graph.
Build a graph with the custom op
Now you can add your custom op to a Graph using ops.custom(). The ops.custom() function needs your custom op name as a parameter (the name you declared in the @register.op() decorator) and two arguments: the op’s input (the same as any other op you add to a graph) and the op’s output type. The output type (out_type arg) will almost always be TensorType, but it might also be ListType. You must specify this type so the compiler knows what kind of data and shape to expect as output from your op.
For example, the following code builds a graph using the custom GELU op, taking
the result from a matmul operation. To specify the output type, we know that
GELU will output the same type and shape as its input, so we just use the
matmul type as the out_type
:
```mojo
from max.graph import Graph, TensorType, ops
from tensor import Tensor, TensorShape


def construct_graph() -> Graph:
    graph = Graph(TensorType(DType.float32, 2, 6))
    matmul_constant_value = Tensor[DType.float32](TensorShape(6, 1), 0.15)
    matmul_constant = graph.constant(matmul_constant_value)
    matmul = graph[0] @ matmul_constant
    gelu = ops.custom["my_gelu"](matmul, matmul.type())
    softmax = ops.softmax(gelu)
    graph.output(softmax)
    return graph
```
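If you prefer not to reuse matmul.type(), you could instead construct the output type explicitly. This is a minimal sketch based on the shapes in the example above (a (2, 6) input multiplied by a (6, 1) constant yields a (2, 1) result, and GELU is elementwise, so it preserves its input's type and shape):

```mojo
# Hypothetical alternative to matmul.type(): spell out the output type.
# GELU preserves its input's dtype and shape: float32, shape (2, 1).
gelu = ops.custom["my_gelu"](matmul, TensorType(DType.float32, 2, 1))
```

Reusing matmul.type() is less error-prone, because the output type then tracks any change to the input shapes automatically.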
That's it for adding a custom op. Now let's load the graph and run it.
Execute the graph with your custom op
Notice that when we call load(), we now include the custom_ops_paths argument:
```mojo
from max.engine import InferenceSession
from max.graph import Graph, TensorType, ops
from tensor import Tensor, TensorShape, randn
from pathlib import Path


# This is the same function from above
def construct_graph() -> Graph:
    graph = Graph(TensorType(DType.float32, 2, 6))
    matmul_constant_value = Tensor[DType.float32](TensorShape(6, 1), 0.15)
    matmul_constant = graph.constant(matmul_constant_value)
    matmul = graph[0] @ matmul_constant
    gelu = ops.custom["my_gelu"](matmul, matmul.type())
    softmax = ops.softmax(gelu)
    graph.output(softmax)
    return graph


def main():
    # Load the graph with the custom ops package
    session = InferenceSession()
    model = session.load(
        construct_graph(),
        custom_ops_paths=Path("custom_ops.mojopkg"),
    )

    # Create some sample input to run through the model:
    input = randn[DType.float32]((2, 6))
    results = model.execute("input0", input)
    output = results.get[DType.float32]("output0")
    print(output)
```
You can then execute the program like this:

```sh
mojo max-graph.mojo
```

And you should see the following output:

```
Hello, custom GELU!
Tensor([[1.0],
        [1.0]], dtype=float32, shape=2x1)
```
You can find the related code in our GitHub repo, under examples/extensibility.