Graph
class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], kernel_library=None, module=None, strict_device_placement=DevicePlacementPolicy.Warn, **kwargs)
Bases: object
Represents a single MAX graph.
A Graph defines a model’s computation. You build a graph by
composing operations that describe how input tensors are transformed into
outputs. Unlike imperative code that executes operations, a Graph
captures the data flow between operations, which allows MAX to optimize and
parallelize execution at compile time. Operations execute only when you run the compiled model.
The following code examples show two different strategies for constructing graphs.
Use the context manager: Use Graph as a context manager to
define the active graph. Inside the with block, retrieve inputs from
inputs, call ops to build nodes, and set the graph output with
output(). Ops called inside the block find the active graph
automatically. Ops called outside the block fail because there is no active
graph.
```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, Weight

W = Weight("W", DType.float32, [3, 2], DeviceRef.CPU())
b = Weight("b", DType.float32, [2], DeviceRef.CPU())

with Graph(
    "linear_relu",
    input_types=[TensorType(DType.float32, ["batch", 3], device=DeviceRef.CPU())],
) as graph:
    x = graph.inputs[0].tensor
    y = x @ W + b
    graph.output(y)
```

Use the graph constructor: Pass a callable as the forward argument. The graph automatically passes the input TensorValue to the callable and records the return value as the graph output. Under the hood, this still opens and closes a graph context.
```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, Weight, ops

class Linear:
    def __init__(self, in_dim: int, out_dim: int):
        self.weight = Weight("W", DType.float32, [in_dim, out_dim], DeviceRef.CPU())
        self.bias = Weight("b", DType.float32, [out_dim], DeviceRef.CPU())

    def __call__(self, x: TensorValue) -> TensorValue:
        return ops.matmul(x, self.weight) + self.bias

linear_layer = Linear(2, 2)
graph = Graph(
    "linear",
    linear_layer,
    input_types=[TensorType(DType.float32, (2,), DeviceRef.CPU())],
)
```

These examples only use the max.graph package, but most models also use Module and other building blocks from max.nn. To learn more, see [Build a model graph with Module](/max/develop/get-started-with-max-graph-in-python).
Parameters:

- name (str) – A name for the graph.
- forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
- input_types (Iterable[Type[Any]]) – A sequence of Type instances that describe each graph input. These are typically TensorType instances. You can also include BufferType instances for mutable in-place inputs.
- path (Path | None) – The path to a saved graph (internal use only).
- custom_extensions (Iterable[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
- kernel_library (KernelLibrary | None) – Optional pre-built kernel library to use. Defaults to None (a new library is created from custom_extensions if needed).
- module (mlir.Module | None) – Optional existing MLIR module (internal use only). Defaults to None.
- strict_device_placement (DevicePlacementPolicy) – The device placement policy. Defaults to DevicePlacementPolicy.Warn.
add_subgraph()
add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[], devices=[])
Creates a reusable subgraph for the current graph.
A subgraph is the graph equivalent of a function: you define a block of ops once and call it from the parent graph as many times as you need. Use a subgraph when a block of computation repeats, for example, a transformer layer that appears 62 times in a model. Wrapping it in a subgraph lets the compiler process the definition once instead of once per repetition, which can cut compile time by 50x or more.
Trade-offs to keep in mind:
- Memory: Allocations inside a subgraph can’t be shared with allocations outside it, so peak memory may be slightly higher.
- Kernel fusion: The compiler can’t fuse ops across the subgraph boundary, which may reduce throughput marginally.
For models with a Module, prefer
build_subgraph(), which handles weight prefixes
automatically.
Examples:
Define a subgraph that adds 1 to every element, then call it on a graph input:
```python
from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [10], DeviceRef.CPU())

with Graph("main", input_types=[input_type]) as graph:
    with graph.add_subgraph(
        "add_one", input_types=[input_type]
    ) as sub:
        x = sub.inputs[0].tensor
        one = ops.constant(1, DType.float32, device=DeviceRef.CPU())
        sub.output(ops.elementwise.add(x, one))
    result = ops.call(sub, graph.inputs[0])
    graph.output(*result)
```
Parameters:

- name (str) – The name identifier for the subgraph. Must be unique within the parent graph. Use the same name when calling the subgraph with call().
- forward (Callable[[...], None | Value[Any] | Iterable[Value[Any]]] | None) – An optional callable that defines the subgraph’s forward pass. When provided, the subgraph is built immediately.
- input_types (Iterable[Type[Any]]) – The tensor types for the subgraph’s inputs. A chain type is added automatically for operation sequencing.
- path (Path | None) – An optional path to a saved subgraph definition to load from disk.
- custom_extensions (Iterable[Path]) – Paths to custom op libraries (.mojopkg files or Mojo source directories) to load for the subgraph.
- devices (Iterable[DeviceRef]) – Devices this subgraph targets.
Returns:

A Graph instance registered as a subgraph of this graph.

Return type:

Graph
add_weight()
add_weight(weight, force_initial_weight_on_host=True)
Adds a weight to the graph.
If the weight is already in the graph, the existing value is returned.
Parameters:

- weight (Weight) – The weight to add to the graph.
- force_initial_weight_on_host (bool) – If True, the weight’s initial value is placed on the host. Defaults to True.

Returns:

A TensorValue that contains this weight.

Raises:

ValueError – If a weight with the same name already exists in the graph.

Return type:

TensorValue
always_ready_chain
property always_ready_chain: _ChainValue
A graph-global, immutable chain that is always ready.
Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (for example, host→device transfers for staging).
current

The graph currently active as a context manager, if any. Ops called inside a with block are added to this graph.
device_chains
device_chains: _DeviceChainMap
empty_module()
static empty_module()
Create a new module to hold one or more graphs.

Return type:

Module
inputs
The input values of the graph.
Returns:

A sequence of Value objects corresponding to the input_types passed at construction, excluding internal chain values.
kernel_libraries_paths
property kernel_libraries_paths: list[Path]
Returns the list of extra kernel library paths for custom ops.
output()
output(*outputs)
Sets the output values of the graph and finalizes construction.
Call this once after building all ops. The graph can’t be executed
until output() has been called. Subsequent calls to
output_types read back the types of the values passed here.
Examples:
Build a graph that doubles its input and set the output:
```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
from max.graph.type import TensorType

input_type = TensorType(DType.float32, [4], DeviceRef.CPU())

with Graph("double", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor
    two = ops.constant(2.0, DType.float32, device=DeviceRef.CPU())
    graph.output(ops.elementwise.mul(x, two))
```
Parameters:

outputs (Value[Any] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The output values of the graph. Each value may be a Value or any TensorValueLike.

Return type:

None
output_types
The types of the graph output values.