
Graph

class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], context=None, kernel_library=None, module=None, **kwargs)

Represents a single MAX graph.

A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.

When you instantiate a graph, you must specify the input shapes as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:

from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor

linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))],
)

You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.
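For orientation, here is a minimal sketch of compiling and executing the graph above; the max.engine.InferenceSession names and method signatures shown are assumptions and may differ between MAX releases:

import numpy as np
from max.engine import InferenceSession

# Sketch only: compile the graph built above, then run it on a sample input.
session = InferenceSession()
model = session.load(linear_graph)  # assumed: load() compiles a Graph into an executable model
result = model.execute(np.ones((2,), dtype=np.float32))  # assumed execute() signature
print(result)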

When creating a graph, a global sequence of chains is initialized and stored in Graph._current_chain. Each side-effecting op (e.g., buffer_load, store_buffer, load_slice_buffer, store_slice_buffer) uses the current chain to perform the op and then updates Graph._current_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once. This design choice prevents data races.
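To illustrate the chaining, here is a minimal sketch of a graph with a mutable buffer input; the BufferType constructor, the ops.buffer_load/ops.buffer_store signatures, and the .buffer accessor are assumptions about the current API:

from max.dtype import DType
from max.graph import BufferType, DeviceRef, Graph, ops

# Sketch: increment a mutable buffer in place. Each buffer op implicitly
# consumes Graph._current_chain and replaces it with a fresh chain.
with Graph(
    "increment",
    input_types=[BufferType(DType.float32, (4,), DeviceRef.CPU())],  # assumed signature
) as graph:
    buffer = graph.inputs[0].buffer    # assumed accessor for a buffer input
    x = ops.buffer_load(buffer)        # side-effecting read: uses the current chain
    ops.buffer_store(buffer, x + 1.0)  # side-effecting write: advances the chain
    graph.output()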

Parameters:

  • name (str)
  • forward (Optional[Callable])
  • input_types (Iterable[Type])
  • path (Optional[Path])
  • custom_extensions (list[Path])
  • context (Optional[mlir.Context])
  • kernel_library (Optional[KernelLibrary])
  • module (Optional[mlir.Module])

add_subgraph()

add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[])

Parameters:

  • name (str)
  • forward (Optional[Callable])
  • input_types (Iterable[Type])
  • path (Optional[Path])
  • custom_extensions (list[Path])
Return type:

Graph
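A sketch of defining and invoking a subgraph follows; whether subgraphs are invoked with ops.call is an assumption, as is the return convention:

from max.dtype import DType
from max.graph import Graph, TensorType, ops

with Graph("outer", input_types=[TensorType(DType.float32, (2,))]) as graph:
    # Subgraph that doubles its input; arguments mirror the Graph constructor.
    double = graph.add_subgraph(
        "double",
        forward=lambda x: x * 2,
        input_types=[TensorType(DType.float32, (2,))],
    )
    # Assumption: ops.call invokes a subgraph and returns its results as a list.
    results = ops.call(double, graph.inputs[0])
    graph.output(*results)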

add_weight()

add_weight(weight, force_initial_weight_on_host=True)

Adds a weight to the graph.

If the weight is in the graph already, return the existing value.

Parameters:

  • weight (Weight) – The weight to add to the graph.
  • force_initial_weight_on_host (bool) – If true, forces the weight to initially be allocated on the host before being moved to the indicated device. This is needed as a stopgap until we have a more fleshed-out ownership model for external constants.

Returns:

A TensorValue that contains this weight.

Raises:

ValueError – If a weight with the same name already exists in the graph.

Return type:

TensorValue
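A short sketch of registering a weight; the Weight constructor parameters shown are assumptions, so check the Weight reference for the exact signature:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, Weight

with Graph("with_weight", input_types=[TensorType(DType.float32, (2, 2))]) as graph:
    w = Weight(  # assumed constructor parameters
        name="w",
        dtype=DType.float32,
        shape=(2, 2),
        device=DeviceRef.CPU(),
    )
    w_tensor = graph.add_weight(w)  # returns a TensorValue backed by the weight
    graph.output(graph.inputs[0].tensor @ w_tensor)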

current

current

inputs

property inputs: Sequence[Value]

The input values of the graph.
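For example, inputs is how a graph body reaches its arguments when ops are built directly; the .tensor accessor used here to narrow a Value to a TensorValue is an assumption:

from max.dtype import DType
from max.graph import Graph, TensorType

with Graph(
    "add",
    input_types=[
        TensorType(DType.float32, (2,)),
        TensorType(DType.float32, (2,)),
    ],
) as graph:
    # graph.inputs holds one Value per entry in input_types, in order.
    x, y = (v.tensor for v in graph.inputs)  # assumed .tensor accessor
    graph.output(x + y)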

kernel_libraries_paths

property kernel_libraries_paths: list[Path]

Returns the list of extra kernel libraries paths for the custom ops.

local_weights_and_chain()

local_weights_and_chain()

output()

output(*outputs)

Sets the output nodes of the Graph.

Parameters:

outputs (Value)

Return type:

None
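A minimal usage sketch, consistent with the examples above (ops.relu is assumed to be available in max.graph.ops):

from max.dtype import DType
from max.graph import Graph, TensorType, ops

with Graph("relu", input_types=[TensorType(DType.float32, (4,))]) as graph:
    x = graph.inputs[0].tensor  # assumed .tensor accessor, as above
    # output() sets the graph's results; it accepts zero or more Values.
    graph.output(ops.relu(x))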

output_types

property output_types: list[Type]

View of the types of the graph output terminator.