graph_op()
max.experimental.torch.graph_op(fn: Callable[[...], Value[Any] | None], name: str | None = None, kernel_library: Path | KernelLibrary | None = None, input_types: Sequence[TensorType] | None = None, output_types: Sequence[TensorType] | None = None, num_outputs: int | None = None) → CustomOpDef
max.experimental.torch.graph_op(fn: None = None, name: str | None = None, kernel_library: Path | KernelLibrary | None = None, input_types: Sequence[TensorType] | None = None, output_types: Sequence[TensorType] | None = None, num_outputs: int | None = None) → Callable[[Callable[[...], Iterable[Value[Any]] | Value[Any] | None]], CustomOpDef]
A decorator to create PyTorch custom operations using MAX graph operations.
This decorator lets you define larger graphs using the functions in the
max.graph.ops or max.nn modules and
call them with PyTorch tensors, or integrate them into PyTorch modules.
These custom ops can be called eagerly, and support compilation with
torch.compile and the Inductor backend.
The resulting custom operation uses destination-passing style, where output tensors are passed as the first arguments and modified in-place. This allows PyTorch to manage the memory and streams of the output tensors. Tensors internal to the computation are managed via MAX’s graph compiler and memory planning.
The default behavior is to JIT-compile for the specific input and output
shapes needed. If you are passing variable-sized inputs, for instance a
batch size or sequence length that may take many different values
between calls, specify that dimension as a symbolic dimension
through input_types and output_types. Otherwise you will
end up compiling a specialized graph for each distinct input shape,
which may use a lot of memory.
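As a sketch, a symbolic dimension can be declared by name when building the input and output types. This example is illustrative and requires the MAX runtime; the dimension name "batch", the op body, and the exact TensorType constructor arguments (which may vary across MAX versions) are assumptions, not part of this API's contract:

```python
# Hypothetical sketch, assuming the MAX runtime is installed.
import max.experimental.torch
from max.dtype import DType
from max.graph import DeviceRef, TensorType, ops

# "batch" is a symbolic dimension: one graph is compiled and reused for
# any batch size, instead of one specialized graph per concrete shape.
io_type = TensorType(DType.float32, ("batch", 128), device=DeviceRef.CPU())

@max.experimental.torch.graph_op(
    input_types=[io_type],
    output_types=[io_type],
)
def max_relu(x):
    # Elementwise op preserves the symbolic shape ("batch", 128).
    return ops.relu(x)
```

Calls with batch sizes 1, 8, or 512 would then all reuse the same compiled graph.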
If neither output_types nor num_outputs is specified, the op defaults to a single output.
For example, to create a functional-style PyTorch op backed by MAX:
```python
import torch
import numpy as np
import max.experimental.torch
from max.dtype import DType
from max.graph import ops

@max.experimental.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):
    scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
    grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
    # max reductions don't remove the dimension, need to squeeze
    return ops.squeeze(grayscaled, axis=-1)

@torch.compile
def grayscale(pic: torch.Tensor):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    max_grayscale(output, pic)  # Call as destination-passing style
    return output

device = "cuda" if torch.cuda.is_available() else "cpu"
img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)
print(f"Input shape: {img.shape}")
print(f"Output shape: {result.shape}")
print("Grayscale conversion completed successfully!")
```
Parameters:
- fn – The function to decorate. If None, returns a decorator.
- name – Optional name for the custom operation. Defaults to the function name.
- kernel_library – Optional kernel library to use for compilation. Useful for creating graphs with custom Mojo ops.
- input_types – Optional sequence of input tensor types for compilation. If None, types are inferred from runtime arguments.
- output_types – Optional sequence of output tensor types for compilation. If None, types are inferred from runtime arguments.
- num_outputs – The number of outputs of the graph. This must be known ahead of time so the op can be registered with PyTorch before the final kernels are compiled.
Returns:
A PyTorch custom operation that can be called with torch.Tensor arguments.
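To illustrate num_outputs together with the destination-passing calling convention described above, here is a hedged sketch of a two-output op. It requires the MAX runtime, and the op name and body are hypothetical, not part of this API:

```python
# Hypothetical sketch, assuming the MAX runtime is installed.
import torch
import max.experimental.torch

# num_outputs=2 must be declared up front so the op can be registered
# with PyTorch before the final kernels are compiled.
@max.experimental.torch.graph_op(num_outputs=2)
def max_scale_shift(x):
    # Return one value per declared output.
    return x * 2.0, x + 1.0

x = torch.rand(4, 8)
scaled = torch.empty_like(x)
shifted = torch.empty_like(x)
# Destination-passing style: output tensors first, then inputs.
max_scale_shift(scaled, shifted, x)
```

PyTorch allocates and owns the two output tensors; MAX writes the results into them in place.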