Graph
class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], context=None, kernel_library=None, module=None, **kwargs)
Represents a single MAX graph.
A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.
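The dataflow model can be illustrated with a toy, pure-Python sketch. It is not the MAX API: Node, constant, add, mul, and evaluate are hypothetical names. Building ops only records nodes and their inputs, and nothing runs until evaluate walks the dependencies, computing each node once. This sketch evaluates sequentially. Because x and y do not depend on each other, a real runtime would be free to compute them in parallel.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass(eq=False)
class Node:
    """Hypothetical dataflow node: an operation plus the nodes it reads."""
    op: Callable[..., float]
    inputs: tuple[Node, ...] = field(default_factory=tuple)
    name: str = ""


def constant(value: float) -> Node:
    # A source node with no inputs.
    return Node(op=lambda: value, name=f"const({value})")


def add(a: Node, b: Node) -> Node:
    # Only records the op; nothing is computed here.
    return Node(op=lambda x, y: x + y, inputs=(a, b), name="add")


def mul(a: Node, b: Node) -> Node:
    return Node(op=lambda x, y: x * y, inputs=(a, b), name="mul")


def evaluate(output: Node, trace: list[str] | None = None) -> float:
    """Lazily evaluates `output`, computing each node once in dependency order."""
    cache: dict[int, float] = {}

    def visit(node: Node) -> float:
        key = id(node)
        if key not in cache:
            args = [visit(n) for n in node.inputs]  # dependencies first
            cache[key] = node.op(*args)
            if trace is not None:
                trace.append(node.name)
        return cache[key]

    return visit(output)


# Building the graph only records structure.
x = constant(2.0)
y = constant(3.0)
z = mul(add(x, y), y)  # (x + y) * y

trace: list[str] = []
result = evaluate(z, trace)
print(result)  # 15.0
```

Note that y feeds two ops but is computed only once, because each node's result is cached by identity.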
When you instantiate a graph, you must specify the input types as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:
from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops


@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        # Embed the NumPy arrays in the graph as float32 constants on the CPU.
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        # The forward pass: x @ W + b.
        return ops.matmul(x, weight_tensor) + bias_tensor


# Passing Linear(...) as the forward callable builds the graph body.
linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,), device=DeviceRef.CPU())],
)
You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.
When a graph is created, a global sequence of chains is initialized and stored in Graph._current_chain. Every side-effecting op (for example, buffer_load, store_buffer, load_slice_buffer, and store_slice_buffer) uses the current chain to perform the op and then updates Graph._current_chain with a new chain. Currently, the input and output chains of mutable ops can each be used at most once. The goal of this design choice is to prevent data races.
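The chain mechanism can be modeled with a toy sketch. This is not how MAX implements chains: Chain, ChainReuseError, ToyGraph, and its methods are hypothetical. Each side-effecting op consumes the current chain and installs a fresh one, so the ops form a strict sequence, and reusing a consumed chain is rejected. For simplicity the toy performs each buffer operation immediately; in a real graph the ops are lazy and the chain is what orders them.

```python
from __future__ import annotations

from dataclasses import dataclass


class ChainReuseError(RuntimeError):
    """Raised when a chain token is consumed twice (hypothetical)."""


@dataclass(eq=False)
class Chain:
    # A token that orders side-effecting ops; it may be consumed at most once.
    index: int
    consumed: bool = False


class ToyGraph:
    """Hypothetical model of how a graph threads a chain through mutable ops."""

    def __init__(self) -> None:
        self._next_index = 0
        self._current_chain = self._new_chain()
        self.log: list[tuple[str, int, int]] = []  # (op, in_chain, out_chain)

    def _new_chain(self) -> Chain:
        chain = Chain(self._next_index)
        self._next_index += 1
        return chain

    def _side_effect(self, op: str, chain: Chain) -> Chain:
        # Consume the incoming chain exactly once and return a new one.
        if chain.consumed:
            raise ChainReuseError(f"chain {chain.index} was already used")
        chain.consumed = True
        out = self._new_chain()
        self.log.append((op, chain.index, out.index))
        return out

    def store_buffer(self, buffer: list[float], value: float) -> None:
        self._current_chain = self._side_effect("store_buffer", self._current_chain)
        buffer.append(value)

    def buffer_load(self, buffer: list[float]) -> float:
        self._current_chain = self._side_effect("buffer_load", self._current_chain)
        return buffer[-1]


g = ToyGraph()
buf: list[float] = []
g.store_buffer(buf, 1.0)
g.store_buffer(buf, 2.0)
loaded = g.buffer_load(buf)
print(loaded)  # 2.0
print(g.log)   # [('store_buffer', 0, 1), ('store_buffer', 1, 2), ('buffer_load', 2, 3)]
```

Because every op must hand back a new chain before the next one can start, two writes can never race on the same chain.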
Parameters:
- name (str) – A name for the graph.
- forward (Optional[Callable]) – The sequence of graph ops for the forward pass (inference).
- input_types (Iterable[Type]) – The data type(s) for the input tensor(s).
- path (Optional[Path]) – The path to a saved graph (internal use only).
- custom_extensions (list[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
- context (Optional[mlir.Context])
- kernel_library (Optional[KernelLibrary])
- module (Optional[mlir.Module])
add_subgraph()
add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[])
Creates and adds a subgraph to the current graph.
Creates a new Graph instance configured as a subgraph of the current graph. The subgraph inherits the parent graph’s MLIR context, module, and symbolic parameters. A chain type is automatically appended to the input types to enable proper operation sequencing within the subgraph.
The created subgraph is marked with special MLIR attributes to identify it as a subgraph and is registered in the parent graph’s subgraph registry.
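The bookkeeping described above can be sketched as a toy model. It is not the MAX implementation: ToyGraph, CHAIN, and the string type names are hypothetical, and only the shared context is modeled (not the module or symbolic parameters). The subgraph receives the parent's context, gets a chain type appended to its inputs, is flagged as a subgraph, and is registered with the parent.

```python
from __future__ import annotations

from dataclasses import dataclass, field

CHAIN = "!chain"  # hypothetical stand-in for the chain type


@dataclass
class ToyGraph:
    name: str
    input_types: tuple[str, ...] = ()
    context: object = field(default_factory=object)
    is_subgraph: bool = False
    subgraphs: dict[str, ToyGraph] = field(default_factory=dict)

    def add_subgraph(self, name: str, input_types: tuple[str, ...] = ()) -> ToyGraph:
        sub = ToyGraph(
            name=name,
            input_types=(*input_types, CHAIN),  # chain type appended automatically
            context=self.context,               # shares the parent's context
            is_subgraph=True,                   # marks it as a subgraph
        )
        self.subgraphs[name] = sub  # registered in the parent's registry
        return sub


parent = ToyGraph("main", ("tensor<f32,[2]>",))
body = parent.add_subgraph("body", ("tensor<f32,[2]>",))
print(body.input_types)  # ('tensor<f32,[2]>', '!chain')
```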
Parameters:
- name (str) – The name identifier for the subgraph.
- forward (Callable | None) – The optional callable that defines the sequence of operations for the subgraph’s forward pass. If provided, the subgraph is built immediately using this callable.
- input_types (Iterable[Type]) – The data types for the subgraph’s input tensors. A chain type is automatically appended to these input types.
- path (Path | None) – The optional path to a saved subgraph definition to load from disk instead of creating a new one.
- custom_extensions (list[Path]) – The list of paths to custom operation libraries to load for the subgraph. Supports .mojopkg files and Mojo source directories.
Return type:
Graph
add_weight()
add_weight(weight, force_initial_weight_on_host=True)
Adds a weight to the graph.
If this weight is already in the graph, returns the existing value.
Returns:
A TensorValue that contains this weight.
Raises:
ValueError – If a different weight with the same name already exists in the graph.
Return type:
TensorValue
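Returning the existing value and raising ValueError for a duplicate name are consistent if "already in the graph" means the same weight object, while the error covers a different weight that reuses a name. A toy sketch of that reading, assuming weights are compared by identity; Weight, WeightValue, and ToyGraph are hypothetical stand-ins, not the MAX classes:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(eq=False)
class Weight:
    # Hypothetical stand-in: only the name matters for registration.
    name: str


@dataclass(eq=False)
class WeightValue:
    # Hypothetical stand-in for the TensorValue that add_weight returns.
    weight: Weight


class ToyGraph:
    def __init__(self) -> None:
        self._weights: dict[str, WeightValue] = {}

    def add_weight(self, weight: Weight) -> WeightValue:
        existing = self._weights.get(weight.name)
        if existing is not None:
            if existing.weight is weight:
                return existing  # same object: reuse the existing value
            raise ValueError(f"Weight {weight.name!r} already exists in the graph")
        value = WeightValue(weight)
        self._weights[weight.name] = value
        return value


g = ToyGraph()
w = Weight("linear.weight")
v1 = g.add_weight(w)
v2 = g.add_weight(w)
print(v1 is v2)  # True
```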
current
inputs
The input values of the graph.
kernel_libraries_paths
Returns the list of extra kernel library paths for the custom ops.
local_weights_and_chain()
local_weights_and_chain()
Creates a local scope for weights and chain state modifications.
Provides a context manager that creates an isolated scope where the graph’s weights dictionary and current chain state can be modified without affecting the parent scope. Upon entering the context, the current weights and chain state are saved. Any modifications made within the context are automatically reverted when exiting the context, restoring the original state.
This is particularly useful for operations that need to temporarily modify graph state, such as building subgraphs or executing operations within isolated blocks where state changes should not persist.
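The save-and-restore behavior can be sketched with contextlib.contextmanager. This toy ToyGraph is hypothetical and models the chain as a plain integer; it snapshots the weights dictionary and chain on entry and restores them on exit. Restoring inside finally, so that the state also reverts if the block raises, is a design choice of this sketch rather than documented MAX behavior.

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator


class ToyGraph:
    def __init__(self) -> None:
        self._weights: dict[str, str] = {}
        self._current_chain = 0  # hypothetical: chain modeled as an int

    @contextmanager
    def local_weights_and_chain(self) -> Iterator[None]:
        saved_weights = self._weights.copy()  # snapshot on entry
        saved_chain = self._current_chain
        try:
            yield
        finally:
            # Restore the snapshot on exit, even if the body raised.
            self._weights = saved_weights
            self._current_chain = saved_chain


g = ToyGraph()
g._weights["w0"] = "outer"
with g.local_weights_and_chain():
    g._weights["w1"] = "inner"
    g._current_chain = 5
print(g._weights, g._current_chain)  # {'w0': 'outer'} 0
```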
output()
output(*outputs)
Sets the output nodes of the Graph.
Parameters:
outputs (Value) – The values to set as the graph outputs.
Return type:
None
output_types
View of the types of the graph output terminator.
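The relationship between output() and output_types can be sketched as a toy, assuming output_types is derived from the values passed to output(). ToyValue and ToyGraph are hypothetical, and the type strings only imitate MLIR syntax.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyValue:
    # Hypothetical stand-in for a graph Value; its type is a plain string here.
    type: str


class ToyGraph:
    def __init__(self) -> None:
        self._outputs: tuple[ToyValue, ...] = ()

    def output(self, *outputs: ToyValue) -> None:
        # Record the values returned by the graph's output terminator.
        self._outputs = outputs

    @property
    def output_types(self) -> list[str]:
        # A derived view: the types of whatever output() recorded.
        return [value.type for value in self._outputs]


g = ToyGraph()
g.output(ToyValue("tensor<f32, [2]>"), ToyValue("tensor<i64, []>"))
print(g.output_types)  # ['tensor<f32, [2]>', 'tensor<i64, []>']
```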