Python class

Graph

class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], context=None, kernel_library=None, module=None, **kwargs)

Represents a single MAX graph.

A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.
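The difference between building a dataflow graph and running imperative code can be illustrated with a small pure-Python sketch. It does not use MAX at all: `Node`, `traced`, and `evaluate` are hypothetical names invented for this illustration, and the real engine schedules and parallelizes work very differently. The sketch only shows that constructing nodes records work without running it, and that evaluation touches only the ops the requested output depends on.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Node:
    """One op in a toy dataflow graph (hypothetical, not a MAX type)."""

    fn: Callable[..., float]
    deps: tuple["Node", ...] = ()

    def evaluate(self, cache: dict[int, float]) -> float:
        # Run each node at most once, and only when a consumer asks for it.
        if id(self) not in cache:
            args = [dep.evaluate(cache) for dep in self.deps]
            cache[id(self)] = self.fn(*args)
        return cache[id(self)]


calls: list[str] = []


def traced(name: str, fn: Callable[..., float]) -> Callable[..., float]:
    """Wrap fn so that every execution is recorded in `calls`."""

    def wrapper(*args: float) -> float:
        calls.append(name)
        return fn(*args)

    return wrapper


# Building the graph only records ops; nothing executes yet.
x = Node(traced("x", lambda: 3.0))
y = Node(traced("y", lambda: 4.0))
total = Node(traced("add", lambda a, b: a + b), (x, y))
unused = Node(traced("unused", lambda a: a * 100), (x,))
calls_after_build = list(calls)

# Evaluation pulls in only the ops that `total` depends on.
result = total.evaluate({})
```

Because `unused` is never requested, its function never runs, which is the lazy behavior the paragraph above describes.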

When you instantiate a graph, you must specify the input shapes as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:

from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        # Embed the NumPy arrays in the graph as float32 constants on the CPU.
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor

# Build a graph named "linear" that takes one float32 input of shape (2,).
linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))]
)

You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.

When a graph is created, a global sequence of chains is initialized and stored in Graph._current_chain. Every side-effecting op (for example buffer_load, store_buffer, load_slice_buffer, or store_slice_buffer) uses the current chain to perform the op and then updates Graph._current_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once. This design prevents data races.
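The chain mechanism can be pictured with a toy model in plain Python. This is only a sketch of the idea, not MAX's implementation: `Chain`, `ToyChainGraph`, and `side_effect` are hypothetical names, and the op names are used purely as labels. Each side-effecting op consumes the current chain and produces a new one, and an attempt to reuse a chain that has already been consumed is rejected.

```python
from dataclasses import dataclass
from itertools import count


@dataclass(frozen=True)
class Chain:
    """Hypothetical ordering token; not the MAX chain type."""

    id: int


class ToyChainGraph:
    """Toy model of threading a chain through side-effecting ops."""

    def __init__(self) -> None:
        self._ids = count()
        self._current_chain = Chain(next(self._ids))
        self._consumed: set[int] = set()
        # Each entry is (op name, input chain id, output chain id).
        self.trace: list[tuple[str, int, int]] = []

    def side_effect(self, op_name: str, chain: "Chain | None" = None) -> Chain:
        in_chain = self._current_chain if chain is None else chain
        # A chain may feed at most one op, so two mutations can never race.
        if in_chain.id in self._consumed:
            raise ValueError(f"chain {in_chain.id} was already used")
        self._consumed.add(in_chain.id)
        out_chain = Chain(next(self._ids))
        self.trace.append((op_name, in_chain.id, out_chain.id))
        self._current_chain = out_chain
        return out_chain


g = ToyChainGraph()
stale = g._current_chain
g.side_effect("buffer_load")
g.side_effect("store_buffer")

# Reusing the chain that buffer_load already consumed is an error.
try:
    g.side_effect("store_buffer", chain=stale)
    reuse_rejected = False
except ValueError:
    reuse_rejected = True
```

The recorded trace forms a single line of dependencies (chain 0 to 1, then 1 to 2), so the two ops have a fixed order even though nothing else connects them.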

Parameters:

  • name (str ) – A name for the graph.
  • forward (Optional [ Callable ] ) – The sequence of graph ops for the forward pass (inference).
  • input_types (Iterable [ Type ] ) – The data type(s) for the input tensor(s).
  • path (Optional [ Path ] ) – The path to a saved graph (internal use only).
  • custom_extensions (list [ Path ] ) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
  • context (Optional [ mlir.Context ] )
  • kernel_library (Optional [ KernelLibrary ] )
  • module (Optional [ mlir.Module ] )

add_subgraph()

add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[])

Creates and adds a subgraph to the current graph.

Creates a new Graph instance configured as a subgraph of the current graph. The subgraph inherits the parent graph’s MLIR context, module, and symbolic parameters. A chain type is automatically appended to the input types to enable proper operation sequencing within the subgraph.

The created subgraph is marked with special MLIR attributes to identify it as a subgraph and is registered in the parent graph’s subgraph registry.

Parameters:

  • name (str ) – The name identifier for the subgraph.
  • forward (Callable | None ) – The optional callable that defines the sequence of operations for the subgraph’s forward pass. If provided, the subgraph will be built immediately using this callable.
  • input_types (Iterable [ Type ] ) – The data types for the subgraph’s input tensors. A chain type will be automatically added to these input types.
  • path (Path | None ) – The optional path to a saved subgraph definition to load from disk instead of creating a new one.
  • custom_extensions (list [ Path ] ) – The list of paths to custom operation libraries to load for the subgraph. Supports .mojopkg files and Mojo source directories.

Return type:

Graph

add_weight()

add_weight(weight, force_initial_weight_on_host=True)

Adds a weight to the graph.

If the weight is in the graph already, return the existing value.

Parameters:

  • weight (Weight ) – The weight to add to the graph.
  • force_initial_weight_on_host (bool ) – If true, forces the weight to be allocated on the host first and then moved to the indicated device. This is a stopgap until there is a more complete ownership model for external constants.

Returns:

A TensorValue that contains this weight.

Raises:

ValueError – If a weight with the same name already exists in the graph.

Return type:

TensorValue
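The description says an existing weight is returned, while the Raises entry says a duplicate name is an error. One reading that reconciles the two, offered here as an assumption rather than a statement of the library's behavior, is that re-adding the same weight object returns the stored value, while a different weight that reuses the name raises ValueError. The toy registry below sketches that reading in plain Python; `ToyWeight` and `ToyRegistry` are hypothetical and do not use MAX.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class ToyWeight:
    """Hypothetical stand-in for a weight; compared by identity only."""

    name: str


class ToyRegistry:
    """Toy weight registry keyed by name (assumed semantics, not MAX code)."""

    def __init__(self) -> None:
        self._weights: dict[str, tuple[ToyWeight, str]] = {}

    def add_weight(self, weight: ToyWeight) -> str:
        existing = self._weights.get(weight.name)
        if existing is not None:
            registered, value = existing
            if registered is weight:
                # Same object added again: hand back the stored value.
                return value
            # A different object reusing the name is treated as a conflict.
            raise ValueError(f"weight {weight.name!r} already exists")
        value = f"value:{weight.name}"
        self._weights[weight.name] = (weight, value)
        return value


registry = ToyRegistry()
w = ToyWeight("w")
first = registry.add_weight(w)
second = registry.add_weight(w)

try:
    registry.add_weight(ToyWeight("w"))
    conflict_raised = False
except ValueError:
    conflict_raised = True
```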

current

current

inputs

property inputs: Sequence[Value]

The input values of the graph.

kernel_libraries_paths

property kernel_libraries_paths: list[Path]

Returns the list of extra kernel library paths for custom ops.

local_weights_and_chain()

local_weights_and_chain()

Creates a local scope for weights and chain state modifications.

Provides a context manager that creates an isolated scope where the graph’s weights dictionary and current chain state can be modified without affecting the parent scope. Upon entering the context, the current weights and chain state are saved. Any modifications made within the context are automatically reverted when exiting the context, restoring the original state.

This is particularly useful for operations that need to temporarily modify graph state, such as building subgraphs or executing operations within isolated blocks where state changes should not persist.
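The save-and-restore behavior described above follows a common context-manager pattern, sketched below in plain Python. This is an illustration under assumptions, not the MAX implementation: `ScopedToyGraph` and its `weights` and `current_chain` attributes are hypothetical stand-ins for the graph's real state.

```python
from contextlib import contextmanager


class ScopedToyGraph:
    """Toy graph state with a save/restore scope (hypothetical names)."""

    def __init__(self) -> None:
        self.weights: dict[str, float] = {}
        self.current_chain = "chain0"

    @contextmanager
    def local_weights_and_chain(self):
        # Snapshot the state on entry.
        saved_weights = dict(self.weights)
        saved_chain = self.current_chain
        try:
            yield
        finally:
            # Restore on exit, even if the body raised.
            self.weights = saved_weights
            self.current_chain = saved_chain


g = ScopedToyGraph()
g.weights["outer"] = 1.0

with g.local_weights_and_chain():
    g.weights["inner"] = 2.0
    g.current_chain = "chain1"
    inside = (sorted(g.weights), g.current_chain)

after = (sorted(g.weights), g.current_chain)
```

Inside the block the temporary weight and chain are visible; after the block both are gone and the outer state is back, which is the property that makes the scope safe for building isolated blocks.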

output()

output(*outputs)

Sets the output nodes of the Graph.

Parameters:

outputs (Value )

Return type:

None

output_types

property output_types: list[Type]

View of the types of the graph output terminator.