
Graph

class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], kernel_library=None, module=None, strict_device_placement=DevicePlacementPolicy.Warn, **kwargs)

Bases: object

Represents a single MAX graph.

A Graph defines a model’s computation. You build a graph by composing operations that describe how input tensors are transformed into outputs. Unlike imperative code that executes operations immediately, a Graph captures the data flow between operations, which allows MAX to optimize and parallelize execution at compile time. The operations themselves run only after the graph is compiled.

The following code examples show two different strategies for constructing graphs.

Use the context manager: Use Graph as a context manager to define the active graph. Inside the with block, read the graph inputs from inputs, call ops functions to build nodes, and set the graph output with output(). Ops called inside the block find the active graph automatically; ops called outside the block fail because there is no active graph.

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, Weight

W = Weight("W", DType.float32, [3, 2], DeviceRef.CPU())
b = Weight("b", DType.float32, [2], DeviceRef.CPU())

with Graph(
    "linear_relu",
    input_types=[TensorType(DType.float32, ["batch", 3], device=DeviceRef.CPU())],
) as graph:
    x = graph.inputs[0].tensor
    y = x @ W + b
    graph.output(y)

Use the graph constructor: Pass a callable as the forward argument. The graph automatically passes its input values to the callable and records the return value as the graph output. Under the hood, this still opens and closes a graph context.

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, Weight, ops

class Linear:
    def __init__(self, in_dim: int, out_dim: int):
        self.weight = Weight("W", DType.float32, [in_dim, out_dim], DeviceRef.CPU())
        self.bias = Weight("b", DType.float32, [out_dim], DeviceRef.CPU())

    def __call__(self, x: TensorValue) -> TensorValue:
        return ops.matmul(x, self.weight) + self.bias

linear_layer = Linear(2, 2)

graph = Graph(
    "linear",
    linear_layer,
    input_types=[TensorType(DType.float32, (2,), DeviceRef.CPU())],
)

These examples use only the max.graph package, but most models also use Module and other building blocks from max.nn. To learn more, see Build a model graph with Module.

Parameters:

  • name (str) – A name for the graph.
  • forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
  • input_types (Iterable[Type[Any]]) – A sequence of Type instances that describe each graph input. These are typically TensorType instances. You can also include BufferType instances for mutable in-place inputs.
  • path (Path | None) – The path to a saved graph (internal use only).
  • custom_extensions (Iterable[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
  • kernel_library (KernelLibrary | None) – Optional pre-built kernel library to use. Defaults to None (a new library is created from custom_extensions if needed).
  • module (mlir.Module | None) – Optional existing MLIR module (internal use only). Defaults to None.
  • strict_device_placement (DevicePlacementPolicy)

add_subgraph()

add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[], devices=[])

Creates a reusable subgraph for the current graph.

A subgraph is the graph equivalent of a function: you define a block of ops once and call it from the parent graph as many times as you need. Use a subgraph when a block of computation repeats, for example, a transformer layer that appears 62 times in a model. Wrapping it in a subgraph lets the compiler process the definition once instead of once per repetition, which can cut compile time by 50x or more.

Trade-offs to keep in mind:

  • Memory: Allocations inside a subgraph can’t be shared with allocations outside it, so peak memory may be slightly higher.
  • Kernel fusion: The compiler can’t fuse ops across the subgraph boundary, which may reduce throughput marginally.

For models with a Module, prefer build_subgraph(), which handles weight prefixes automatically.

Examples:

Define a subgraph that adds 1 to every element, then call it on a graph input:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [10], DeviceRef.CPU())

with Graph("main", input_types=[input_type]) as graph:
    with graph.add_subgraph(
        "add_one", input_types=[input_type]
    ) as sub:
        x = sub.inputs[0].tensor
        one = ops.constant(1, DType.float32, device=DeviceRef.CPU())
        sub.output(ops.elementwise.add(x, one))

    result = ops.call(sub, graph.inputs[0])
    graph.output(*result)

Parameters:

  • name (str) – The name identifier for the subgraph. Must be unique within the parent graph. Use the same name when calling the subgraph with call().
  • forward (Callable[[...], None | Value[Any] | Iterable[Value[Any]]] | None) – An optional callable that defines the subgraph’s forward pass. When provided, the subgraph is built immediately.
  • input_types (Iterable[Type[Any]]) – The tensor types for the subgraph’s inputs. A chain type is added automatically for operation sequencing.
  • path (Path | None) – An optional path to a saved subgraph definition to load from disk.
  • custom_extensions (Iterable[Path]) – Paths to custom op libraries (.mojopkg files or Mojo source directories) to load for the subgraph.
  • devices (Iterable[DeviceRef]) – Devices this subgraph targets.

Returns:

A Graph instance registered as a subgraph of this graph.

Return type:

Graph

add_weight()

add_weight(weight, force_initial_weight_on_host=True)

Adds a weight to the graph.

If the weight is already in the graph, the existing value is returned.

Parameters:

  • weight (Weight) – The weight to add to the graph.
  • force_initial_weight_on_host (bool) – If true, forces the weight to be allocated on the host first and then moved to the indicated device. This is needed as a stopgap until external constants have a more complete ownership model.

Returns:

A TensorValue that contains this weight.

Raises:

ValueError – If a different weight with the same name already exists in the graph.

Return type:

TensorValue

always_ready_chain

property always_ready_chain: _ChainValue

A graph-global, immutable chain that is always ready.

Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (for example, host→device transfers for staging).

current

current

The currently active graph, set while inside a Graph context (with block).

device_chains

device_chains: _DeviceChainMap

empty_module()

static empty_module()

Create a new module to hold one or more graphs.

Return type:

Module

inputs

property inputs: Sequence[Value[Any]]

The input values of the graph.

Returns:

A sequence of Value objects corresponding to the input_types passed at construction, excluding internal chain values.

kernel_libraries_paths

property kernel_libraries_paths: list[Path]

Returns the list of extra kernel library paths for custom ops.

output()

output(*outputs)

Sets the output values of the graph and finalizes construction.

Call this once after building all ops. The graph can’t be executed until output() has been called. Subsequent calls to output_types read back the types of the values passed here.

Examples:

Build a graph that doubles its input and set the output:

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
from max.graph.type import TensorType

input_type = TensorType(DType.float32, [4], DeviceRef.CPU())

with Graph("double", input_types=[input_type]) as graph:
    x = graph.inputs[0].tensor
    two = ops.constant(2.0, DType.float32, device=DeviceRef.CPU())
    graph.output(ops.elementwise.mul(x, two))

Parameters:

outputs (Value[Any] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The output values of the graph. Each value may be a Value or any TensorValueLike.

Return type:

None

output_types

property output_types: list[Type[Any]]

The types of the graph output values.

Returns:

A list of Type objects corresponding to the values passed to output(), in the same order.

Raises:

TypeError – If the graph has not yet been terminated by a call to output().