Python module

torch

CustomOpLibrary

class max.torch.CustomOpLibrary(kernel_library: Path | KernelLibrary)

A PyTorch interface to custom operations implemented in Mojo.

This API lets you pass PyTorch data to a custom op as torch.Tensor values. CustomOpLibrary compiles the Mojo custom ops and marshals data between PyTorch and the executable Mojo code.

For example, consider a grayscale operation implemented in Mojo:

my_library/grayscale.mojo
@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        # The kind of device this is running on: "cpu" or "gpu"
        target: StaticString,
    ](
        img_out: OutputTensor[type = DType.uint8, rank=2],
        img_in: InputTensor[type = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        ...

You can then use CustomOpLibrary to invoke the Mojo operation like so:

import torch
from max.torch import CustomOpLibrary

op_library = CustomOpLibrary("my_library")
grayscale_op = op_library.grayscale

def grayscale(pic: torch.Tensor) -> torch.Tensor:
    result = pic.new_empty(pic.shape[:-1])
    grayscale_op(result, pic)
    return result

img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)

The custom operation returned by op_library.<opname> has the same interface as the backing Mojo operation. Each InputTensor or OutputTensor argument corresponds to a torch.Tensor value in Python. Each argument that corresponds to an OutputTensor in the Mojo operation is modified in place, so the caller allocates it before the call, as with result in the example above.
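
To make the output-argument convention concrete, here is a minimal pure-PyTorch sketch that follows the same calling pattern: the caller allocates the output and the function fills it in place. It is not part of max.torch. The name grayscale_reference and the BT.601 luma weights are assumptions for illustration, since the body of the Mojo kernel is elided above and may compute grayscale differently.

```python
import torch

# Assumed luma weights (ITU-R BT.601). The Mojo kernel body is elided in the
# documentation, so the real kernel may use different weights or rounding.
LUMA_WEIGHTS = (0.299, 0.587, 0.114)


def grayscale_reference(img_out: torch.Tensor, img_in: torch.Tensor) -> None:
    """Hypothetical stand-in that mirrors the (output, input) argument order."""
    if img_in.dtype != torch.uint8 or img_out.dtype != torch.uint8:
        raise TypeError("both tensors must be uint8")
    if img_in.ndim != 3 or img_in.shape[-1] != 3:
        raise ValueError("img_in must have shape (height, width, 3)")
    if img_out.shape != img_in.shape[:-1]:
        raise ValueError("img_out must have shape img_in.shape[:-1]")

    weights = torch.tensor(LUMA_WEIGHTS, dtype=torch.float32, device=img_in.device)
    # Weighted sum over the channel dimension, computed in float32.
    luma = (img_in.to(torch.float32) * weights).sum(dim=-1)
    # Write into the caller's buffer instead of returning a new tensor.
    img_out.copy_(luma.round().clamp(0, 255).to(torch.uint8))


img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
expected = img.new_empty(img.shape[:-1])  # same dtype and device as img
grayscale_reference(expected, img)
```

A stand-in like this could serve as a baseline when checking a custom op's output. If the Mojo kernel happens to use the same weights and rounding, torch.equal(result, expected) would be one way to compare them. Otherwise a small tolerance is the safer choice.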

  • Parameters:

    kernel_library – The path to a .mojo file or a .mojopkg with your custom op kernels, or the corresponding library object.