Build custom ops for GPUs
Mojo is our not-so-secret weapon for achieving architecture-independent performance for all types of AI workloads. Previously, only Modular engineers were able to write high-performance parallel processing operations for a MAX Graph using Mojo.
In this tutorial, you'll learn how to use Mojo to write custom operations (custom ops) for MAX graphs that execute efficiently on both CPUs and GPUs. You'll build an elementwise operation that adds one to each element of a matrix, then execute a graph that uses it.
To help you get started, we provide several example custom ops that you can run with the nightly version of MAX.
Create a virtual environment
Using a virtual environment ensures that you have versions of MAX and Mojo that are compatible with this project. We'll use the Magic CLI to create the environment and install the required packages.
- If you don't have the magic CLI yet, you can install it on macOS and Ubuntu Linux with this command:
  curl -ssL https://magic.modular.com/ | bash
  Then run the source command that's printed in your terminal.
- Clone the MAX GitHub repository and change into the custom ops examples directory:
  git clone -b nightly https://github.com/modular/max && \
  cd max/examples/custom_ops
You can run the custom addition operation example like this:
magic run addition
And the following is the expected output:
Graph result:
[[1.9635754 1.6106832 1.3523386 1.4241157 1.602606 1.2439846 1.2835392
1.1805022 1.1934654 1.3315184]
# ... shortened for brevity
Expected result:
[[1.9635754 1.6106832 1.3523386 1.4241157 1.602606 1.2439846 1.2835392
1.1805022 1.1934654 1.3315184]
# ... shortened for brevity
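The example prints the graph's output next to an expected result so you can compare them by eye. The plain-Python sketch below shows that comparison programmatically. It assumes the expected result is simply each input element plus one (the printed values all fall between 1 and 2, consistent with uniform [0, 1) inputs) and that the graph computes in float32. The helpers to_float32 and results_match are hypothetical and are not part of the example.

```python
import math
import random
import struct

# Stand-in for the example's input: a 5x10 matrix of uniform [0, 1)
# samples, playing the role of np.random.uniform(size=(rows, columns)).
rows, columns = 5, 10
random.seed(0)
x = [[random.random() for _ in range(columns)] for _ in range(rows)]

# Reference ("expected") result: add one to every element.
expected = [[v + 1.0 for v in row] for row in x]


def to_float32(v):
    """Round a Python float to the nearest float32 value (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", v))[0]


def results_match(graph_result, expected, rel_tol=1e-6):
    """True if both matrices have the same shape and element-wise close values."""
    if len(graph_result) != len(expected):
        return False
    for got_row, exp_row in zip(graph_result, expected):
        if len(got_row) != len(exp_row):
            return False
        if not all(math.isclose(g, e, rel_tol=rel_tol) for g, e in zip(got_row, exp_row)):
            return False
    return True


# Simulate a float32 graph output. It differs from the float64 reference
# only by rounding, so a relative tolerance is used instead of ==.
graph_result = [[to_float32(v) for v in row] for row in expected]
print(results_match(graph_result, expected))  # True
```

A tolerance-based comparison is the safer choice here. Results computed in float32 on a GPU rarely match a float64 reference bit for bit.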
Now that you've seen the code in action, let's dive into the implementation details to understand how this custom addition operation works under the hood.
Define a Mojo custom operation
The MAX Graph API represents models as computational graphs, where each operation describes parallel computations that the MAX Engine optimizes for hardware performance. Within these graphs, nodes can process any number of input tensors, perform computations on the target hardware, and generate one or more output tensors as results.
To illustrate this, open the add_one_custom.mojo file in the kernels directory. Here, a custom operation called AddOneCustom takes an input tensor, adds one to every element, and returns the result of that computation as a new tensor.
This custom compute node is defined as a Mojo struct:
import compiler
from utils.index import IndexList
from tensor_utils import ManagedTensorSlice, foreach
from runtime.asyncrt import MojoCallContextPtr

@compiler.register("add_one_custom", num_dps_outputs=1)
struct AddOneCustom:
    @staticmethod
    fn execute[
        synchronous: Bool,
        target: StringLiteral,
    ](
        out: ManagedTensorSlice,
        x: ManagedTensorSlice[out.type, out.rank],
        ctx: MojoCallContextPtr,
    ):
        @parameter
        @always_inline
        fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
            return x.load[width](idx) + 1

        foreach[add_one, synchronous, target](out, ctx)
Mojo's Single Instruction Multiple Data (SIMD) types and compile-time parameters enable hardware-agnostic parallel processing.
Inputs and outputs take the form of ManagedTensorSlice: tensors of a specific rank and data type whose memory is managed outside of the operation. The operation reads elements from the input tensors and writes results directly into the output tensors. By convention, output tensors come first in the operation signature.
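The plain-Python sketch below shows the same calling convention. The caller allocates the output and passes it first, and the operation only fills it in. It uses flat 1-D lists, and add_one_custom_dps is a hypothetical name. It illustrates the argument order and ownership only, not how MAX allocates or manages tensor memory.

```python
def add_one_custom_dps(out, x):
    """Hypothetical Python analogue of AddOneCustom.execute().

    The caller allocates `out` and passes it first; the operation
    writes x + 1 into it and returns nothing.
    """
    if len(out) != len(x):
        raise ValueError("out and x must have the same number of elements")
    for i, value in enumerate(x):
        out[i] = value + 1


x = [0.5, 1.5, 2.5]
out = [0.0] * len(x)          # storage owned by the caller, not the operation
add_one_custom_dps(out, x)
print(out)                    # [1.5, 2.5, 3.5]
```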
The core computation, adding one to each element in the tensor, happens in the add_one() function:
@parameter
@always_inline
fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
    return x.load[width](idx) + 1

foreach[add_one, synchronous, target](out, ctx)
The foreach() function applies add_one() across all elements of the output tensor and distributes this elementwise computation in parallel. Its implementation is specialized for each hardware target, so the same kernel code makes efficient use of the compute resources on both CPUs and GPUs.
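The sketch below is a sequential plain-Python model of how foreach() and add_one() divide the work. It assumes a flat 1-D tensor and a fixed vector width of 4, and handles leftover elements one at a time. The names mirror the Mojo code, but this is not the MAX implementation, which is specialized per target and runs the chunks in parallel.

```python
def foreach(func, out, width=4):
    """Sequential model of foreach over a flat 1-D output.

    Calls func(w, idx) for each chunk starting at idx, expecting a list of
    w values back, and stores them in out[idx:idx + w].
    """
    n = len(out)
    i = 0
    # Full-width chunks: the vectorized ("SIMD") path.
    while i + width <= n:
        out[i:i + width] = func(width, i)
        i += width
    # Leftover tail elements, handled one at a time.
    while i < n:
        out[i:i + 1] = func(1, i)
        i += 1


x = [float(v) for v in range(10)]


def add_one(width, idx):
    # Analogue of x.load[width](idx) + 1: read `width` elements starting at idx.
    return [v + 1 for v in x[idx:idx + width]]


out = [0.0] * len(x)
foreach(add_one, out, width=4)   # two chunks of 4, then 2 tail elements
print(out)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
```

The element function only describes what to compute for one vector of elements. Deciding how many vectors exist and where they run is left entirely to foreach(), which is what lets the same add_one() serve different hardware.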
A library of these custom operations can be defined in Mojo and compiled into a reusable package with the extension .mojopkg. This compiled library of custom ops can then be used by the graph compiler when defining a MAX Graph.
Add the custom operation to a graph
The MAX Graph API includes a set of predefined operations written by Modular that have highly optimized implementations. In addition to these, the custom() function lets you call your own user-defined Mojo operations from a graph.
To use a Mojo custom operation with GPU acceleration, add it to your MAX graph. The addition.py example demonstrates building such a computational graph in Python:
import os
from pathlib import Path

import numpy as np
from max.driver import CPU, Accelerator, Tensor, accelerator_count
from max.dtype import DType
from max.engine import InferenceSession
from max.graph import Graph, TensorType, ops

rows = 5
columns = 10
dtype = DType.float32

graph = Graph(
    "addition",
    forward=lambda x: ops.custom(
        name="add_one_custom",
        values=[x],
        out_types=[TensorType(dtype=x.dtype, shape=x.tensor.shape)],
    )[0].tensor,
    input_types=[
        TensorType(dtype, shape=[rows, columns]),
    ],
)
This graph takes an input tensor with five rows and ten columns, runs the add_one_custom operation on it, and returns the result. The name passed to ops.custom() is the same name registered with @compiler.register in add_one_custom.mojo.
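The graph refers to the Mojo kernel only by a string. "add_one_custom" in ops.custom() is the same string passed to @compiler.register in add_one_custom.mojo. The toy Python registry below shows why the two names must agree. register() and custom() are hypothetical helpers that stand in for the decorator and the graph call. This is an analogy for name matching only and does not show how MAX resolves names inside a .mojopkg.

```python
# Toy model of name-based custom-op lookup. register() stands in for
# @compiler.register(...) and custom() for ops.custom(name=...); both are
# hypothetical helpers, not MAX APIs.
_REGISTRY = {}


def register(name):
    def decorator(fn):
        if name in _REGISTRY:
            raise ValueError(f"op {name!r} is already registered")
        _REGISTRY[name] = fn
        return fn
    return decorator


def custom(name, values):
    if name not in _REGISTRY:
        raise KeyError(f"no custom op named {name!r}")
    return _REGISTRY[name](*values)


@register("add_one_custom")
def add_one_custom(x):
    return [v + 1 for v in x]


print(custom("add_one_custom", [[1, 2, 3]]))  # [2, 3, 4]
```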
Because MAX works across a range of hardware architectures, the same code runs on a GPU if one is available and falls back to the local CPU otherwise. For example:
device = CPU() if accelerator_count() == 0 else Accelerator()
Using the InferenceSession() class, the graph is placed on whichever device was selected, and the compiled custom ops are loaded through the custom_extensions argument:
# path points to the compiled .mojopkg package containing the custom ops.
session = InferenceSession(
    devices=[device],
    custom_extensions=path,
)
This configures the inference session to run on the selected device. MAX Engine can then compile the graph and optimize it for the target hardware:
model = session.load(graph)
Memory management between host CPUs and accelerator devices is handled through the MAX Driver API. This interface gives you precise control over memory transfers, so you can optimize performance by explicitly managing these potentially expensive operations. The API's Tensor class integrates with common Python frameworks: it offers zero-copy interoperability with both NumPy arrays and PyTorch tensors. Here's how to use this to create a MAX Tensor from random data:
x_array = np.random.uniform(size=(rows, columns)).astype(np.float32)
x = Tensor.from_numpy(x_array)
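Zero-copy means two objects can refer to the same underlying buffer rather than duplicating it. The standard-library sketch below illustrates that idea with array and memoryview as an analogy. It makes no claims about how Tensor.from_numpy() is implemented; check the documentation for your MAX version for its exact copy semantics.

```python
from array import array

# Plain-Python analogy for zero-copy sharing (not MAX code): a memoryview
# exposes the buffer of the object it wraps instead of copying it.
host_data = array("f", [0.25, 0.5, 0.75])   # float32 storage, like np.float32
view = memoryview(host_data)                # shares host_data's buffer

host_data[0] = 1.0                          # a write through the original...
print(view[0])                              # ...is visible via the view: 1.0

snapshot = bytes(view)                      # an explicit copy, by contrast,
host_data[1] = 2.0                          # does not see later writes
print(memoryview(snapshot).cast("f")[1])    # 0.5
```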
This Tensor is resident on the host and needs to be moved to the accelerator to be ready for use with the MAX Graph on that device. Note that if the device is the host CPU, this is a no-op:
x = x.to(device)
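The rule that moving a tensor to the device it already lives on does nothing fits in a few lines of Python. The sketch below is a toy model with a hypothetical ToyTensor class and plain device strings, not the MAX Driver Tensor. It only illustrates why the same script works whether device ends up being the CPU or an accelerator.

```python
# Toy model of explicit host/device transfers. ToyTensor and the device
# strings are hypothetical; only the "same device means no-op" rule
# comes from the tutorial text.
class ToyTensor:
    def __init__(self, data, device="cpu"):
        self.data = list(data)
        self.device = device

    def to(self, device):
        if device == self.device:
            return self                       # already resident: no-op, no copy
        # A real transfer would copy bytes between memories; here we copy the list.
        return ToyTensor(self.data, device)


x = ToyTensor([1.0, 2.0, 3.0])                # starts on the host
assert x.to("cpu") is x                       # CPU target: nothing moves
x_gpu = x.to("gpu")
assert x_gpu is not x and x_gpu.device == "gpu"
result = x_gpu.to("cpu")                      # move results back to inspect them
print(result.device, result.data)             # cpu [1.0, 2.0, 3.0]
```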
This Tensor can now be run through our compiled graph, and a device-resident tensor is the result:
result = model.execute(x)[0]
To examine the results, this Tensor can be moved back to the host:
result = result.to(CPU())
Then you can convert it back to a NumPy array:
print(result.to_numpy())
For a more advanced example, check out how we compute the Mandelbrot set using the ComplexSIMD data type and a vectorized implementation of the fractal computation.
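The computation behind that example is the escape-time iteration z = z*z + c. The sketch below is a scalar, plain-Python version over a small grid. It uses Python's built-in complex type rather than ComplexSIMD and processes one point at a time, whereas the Mojo example is vectorized. The grid bounds, size, and iteration limit are arbitrary choices for illustration.

```python
def escape_time(c, max_iter=50):
    """Number of steps of z -> z*z + c (starting at z = 0) before |z| > 2.

    Points that never escape within max_iter steps are treated as members
    of the Mandelbrot set and return max_iter.
    """
    z = 0j
    for i in range(max_iter):
        if abs(z) > 2.0:
            return i
        z = z * z + c
    return max_iter


def render(width=60, height=24, max_iter=50):
    """ASCII rendering of the region [-2.0, 0.6] x [-1.2, 1.2]."""
    lines = []
    for row in range(height):
        y = 1.2 - 2.4 * row / (height - 1)
        line = ""
        for col in range(width):
            x = -2.0 + 2.6 * col / (width - 1)
            line += "#" if escape_time(complex(x, y), max_iter) == max_iter else "."
        lines.append(line)
    return "\n".join(lines)


print(render())
```

Every grid point is independent of the others, which is what makes the computation a good fit for SIMD vectors and GPU threads.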
As a final note, the programming interface described above is being provided as a preview, and some elements will change as we continue to improve the GPU programming experience in MAX.
More to come
Mojo is an incredible language for programming accelerators: Python-like high-level syntax, systems language performance, and unique language features designed for modern heterogeneous computation. We're tremendously excited to be able to show off how it enables MAX to drive forward the state-of-the-art when running AI workloads and more on GPUs. Adding custom ops to a graph is our first introduction to how you can program GPUs with Mojo. These are early examples, and we will be rolling out more API documentation and examples. To stay up to date with new releases, sign up for our newsletter, check out the community, and join our forum.
The nightly branch of the open-source MAX repository contains everything needed to run the examples above on an Ampere- or Lovelace-class NVIDIA GPU (more to come!), as well as on a local CPU. Give them a try today to start experimenting with programming GPUs in Mojo!