Python module

max.graph.ops

Implements operations used when staging a graph.

This module provides operations for building a Graph in MAX. Most operations return a TensorValue, which supports standard Python operators such as +, *, and @ (matrix multiplication), as well as convenience methods like reshape() and flatten(). Ops like constant() can also add constant values to your graph.

When an operation receives inputs with different data types (DType), MAX promotes the output to a common type by picking the higher-ranked category (bool < unsigned int < signed int < float) and the larger bit width. The result is always one of the input types. Plainly, the promotion rule for two values x and y is:

max(category(x), category(y)), max(bitwidth(x), bitwidth(y))

If any input can’t be safely represented in the chosen type, MAX raises an error. For example, MAX fails to promote uint8 and int8 to int8, since int8 can’t represent all uint8 values.
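
The following minimal sketch shows the rule in action (the graph name, shapes, and device here are illustrative, not part of the API):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

i32 = TensorType(DType.int32, [2], DeviceRef.CPU())
f32 = TensorType(DType.float32, [2], DeviceRef.CPU())

with Graph("promotion_example", input_types=(i32, f32)) as graph:
    x, y = graph.inputs
    # float outranks signed int, so the result dtype is float32
    out = ops.add(x, y)
    graph.output(out)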

abs()

max.graph.ops.abs(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

acos()

max.graph.ops.acos(x)

source

Computes the arccosine (inverse cosine) of the input tensor.

Returns values in the range [0, π] for inputs in [-1, 1].

Creates a new op node to compute the elementwise arccosine of a symbolic tensor and adds it to the graph, returning the symbolic result.

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

def acos_graph():
    input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())

    with Graph("acos_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.acos(x)
        graph.output(out)

Parameters:

x (TensorValue) – Input tensor with values in [-1, 1]. If values are outside this domain, they will be clamped to the valid range.

Returns:

Arccosine of the input in radians [0, π]. The result will have:

  • the same dtype as the input
  • the same shape as the input

Return type:

TensorValue

Raises:

  • Error – If the symbol doesn’t represent a tensor value.
  • Error – If the input is not a floating-point dtype.

add()

max.graph.ops.add(lhs, rhs)

source

Parameters:

Return type:

TensorValue

allgather()

max.graph.ops.allgather(inputs, signal_buffers, axis=0)

source

Collective allgather operation.

This op is a collective operation that takes tensors from different devices and produces tensors on those same devices. It gathers the inputs across devices and concatenates them along the specified dimension, then broadcasts the result back to the devices that the inputs came from.

Parameters:

Returns:

An iterable outputs which all hold the gathered output. Each output tensor contains the concatenation of all inputs along the specified dimension.

Return type:

list[TensorValue]

argmax()

max.graph.ops.argmax(x, axis=-1)

source

Reduces a symbolic tensor using an argmax operation.

If all elements of the input are identical, this returns the index of the first element on CPU and an arbitrary index on GPU.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor for the operation.
  • axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

Returns:

A symbolic tensor representing the result of the argmax operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension which will have size 1.

Return type:

TensorValue
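
A minimal sketch (the graph name and shape are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [2, 5], DeviceRef.CPU())

with Graph("argmax_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # Reduce along the last axis: the result shape is [2, 1]
    idx = ops.argmax(x, axis=-1)
    graph.output(idx)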

argmin()

max.graph.ops.argmin(x, axis=-1)

source

Reduces a symbolic tensor using an argmin operation.

If all elements of the input are identical, this returns the index of the first element on CPU and an arbitrary index on GPU.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor for the operation.
  • axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

Returns:

A symbolic tensor representing the result of the argmin operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension which will have size 1.

Return type:

TensorValue

argsort()

max.graph.ops.argsort(x, ascending=True)

source

Returns the indices that would sort a tensor.

This function returns the indices that would sort the input tensor along its first dimension. The returned indices are of type int64.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – Input tensor to be sorted.
  • ascending (bool) – If True (default), sort in ascending order. If False, sort in descending order.

Returns:

A tensor of indices of the same shape as the input tensor.

Return type:

TensorValue

as_interleaved_complex()

max.graph.ops.as_interleaved_complex(x)

source

Reshapes the input symbolic tensor as complex from alternating (real, imag).

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.

Returns:

A symbolic tensor with the values reinterpreted as complex numbers. The result has the same dimensions as the input except for the last, which is halved, followed by a final dimension of size 2 holding the (real, imag) components.

Return type:

TensorValue

assert_same_device()

max.graph.ops.assert_same_device(*values, **named_values)

source

Raises ValueError if any of the given values are not on the same device.

Parameters:

Return type:

None

atanh()

max.graph.ops.atanh(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

avg_pool2d()

max.graph.ops.avg_pool2d(input, kernel_size, stride=1, dilation=1, padding=0, ceil_mode=False, count_boundary=True)

source

Perform a 2D average pooling operation on the input tensor.

Applies a 2D average pooling operation to the input tensor with layout [N, H, W, C]. The pooling operation slides a window of size kernel_size over the spatial dimensions and computes the average value within each window.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor with shape [N, H, W, C].
  • kernel_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – The height and width of the sliding window.
  • stride (int | tuple[int, int]) – The stride of the sliding window. Can be a single integer applied to both spatial dimensions or a tuple (stride_h, stride_w). Defaults to 1.
  • dilation (int | tuple[int, int]) – The spacing between kernel elements. Can be a single integer or a tuple (dilation_h, dilation_w). Defaults to 1.
  • padding (int | tuple[int, int]) – Zero-padding added to both sides of each spatial dimension. Can be a single integer or a tuple (pad_h, pad_w). Defaults to 0.
  • ceil_mode (bool) – If True, uses ceil instead of floor when computing the output spatial shape. Defaults to False.
  • count_boundary (bool) – If True, includes padding elements in the divisor when computing the average. Defaults to True.

Returns:

A symbolic tensor with the average pooling applied, with shape [N, H_out, W_out, C].

Return type:

TensorValue
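
As a sketch (sizes are illustrative), pooling an 8x8 NHWC input with a 2x2 window and stride 2 halves the spatial dimensions:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [1, 8, 8, 3], DeviceRef.CPU())

with Graph("avg_pool_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # 2x2 window, stride 2: the output shape is [1, 4, 4, 3]
    out = ops.avg_pool2d(x, kernel_size=(2, 2), stride=2)
    graph.output(out)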

band_part()

max.graph.ops.band_part(x, num_lower=None, num_upper=None, exclude=False)

source

Masks out everything except a diagonal band of an input matrix.

Copies a tensor, setting everything outside the central diagonal band of each matrix to zero. All but the last two axes are treated as batch dimensions, and the last two axes define the sub-matrices.

Assuming the input has dimensions [I, J, …, M, N], the output tensor has the same shape as the input, and the values are given by

out[i, j, ..., m, n] = in_band(m, n) * input[i, j,  ..., m, n].

with the indicator function:

in_band(m, n) = (num_lower is None || (m - n) <= num_lower) &&
                (num_upper is None || (n - m) <= num_upper)

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to mask.
  • num_lower (int | None) – The number of diagonal bands to include below the central diagonal. If None, include the entire lower triangle.
  • num_upper (int | None) – The number of diagonal bands to include above the central diagonal. If None, include the entire upper triangle.
  • exclude (bool) – If true, invert the selection of elements to mask. Elements in the band are set to zero.

Returns:

A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.

Raises:

ValueError – If the input tensor rank is less than 2, or if num_lower/num_upper are out of bounds for statically known dimensions.

Return type:

TensorValue
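
For instance, a sketch (sizes are illustrative) that keeps the main diagonal plus one band above it in a 3x3 matrix:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [3, 3], DeviceRef.CPU())

with Graph("band_part_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # Keep the main diagonal and one superdiagonal; zero everything else
    out = ops.band_part(x, num_lower=0, num_upper=1)
    graph.output(out)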

bottom_k()

max.graph.ops.bottom_k(input, k, axis=-1)

source

Returns tensor with only the bottom K values along given axis.

Parameters:

Returns:

Bottom K values (ascending), Bottom K indices.

Return type:

tuple[TensorValue, TensorValue]

broadcast_to()

max.graph.ops.broadcast_to(x, shape, out_dims=None)

source

Broadcasts a symbolic tensor.

Broadcasts the input tensor to the specified shape. Dimensions in the input must be one or match the target dimension.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions.
  • shape (TensorValue | Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The new shape as a list of dimensions. Dynamic dimensions are not allowed.
  • out_dims (Iterable[int | str | Dim | integer[Any] | TypedAttr] | None) – Output dims used only for tensor-valued shape.

Returns:

A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.

Raises:

ValueError – if a tensor-valued shape is passed without out_dims.

Return type:

TensorValue
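
A minimal sketch (the graph name and shapes are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [1, 3], DeviceRef.CPU())

with Graph("broadcast_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # The size-1 dimension expands to 4: the output shape is [4, 3]
    out = ops.broadcast_to(x, (4, 3))
    graph.output(out)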

buffer_create()

max.graph.ops.buffer_create(type)

source

Creates a buffer of the given type.

Parameters:

type (BufferType) – The type of the resulting BufferValue

Returns:

A new BufferValue of the requested type.

Return type:

BufferValue

buffer_load()

max.graph.ops.buffer_load(x)

source

Loads the input buffer into a tensor.

It loads the in-place mutable buffer into an immutable tensor graph value. This is semantically equivalent to a copy from the mutable buffer x to the immutable value-semantic tensor output.

Parameters:

x (BufferValue) – The buffer to be loaded to a tensor.

Returns:

A tensor graph value representing a copy of the buffer contents.

Return type:

TensorValue

buffer_store()

max.graph.ops.buffer_store(destination, source)

source

Stores the input tensor into the in-out buffer.

It stores the immutable source tensor in the mutable destination buffer. This is semantically equivalent to a copy from source to destination.

Parameters:

Return type:

None

buffer_store_slice()

max.graph.ops.buffer_store_slice(destination, source, indices)

source

Stores the input tensor into a slice of the destination buffer.

It stores the immutable source tensor into the mutable destination buffer. This is semantically equivalent to copying source into the slice of the destination buffer at the index specified by indices.

Parameters:

Return type:

None

call()

max.graph.ops.call(graph, *args, prefix='')

source

Calls a previously defined graph with the provided arguments.

Use this function to invoke a subgraph built with add_subgraph() or build_subgraph(). The primary benefit is that the compiler processes the subgraph definition once, which reduces compile time significantly for models with repeated blocks.

Examples:

Call a subgraph and forward its outputs to the parent graph:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [10], DeviceRef.CPU())

with Graph("main", input_types=[input_type]) as graph:
    with graph.add_subgraph(
        "add_one", input_types=[input_type]
    ) as sub:
        x = sub.inputs[0].tensor
        one = ops.constant(1, DType.float32, device=DeviceRef.CPU())
        sub.output(ops.elementwise.add(x, one))

    result = ops.call(sub, graph.inputs[0])
    graph.output(*result)

Call a shared subgraph for each layer of a model, resolving different weights at each call site with prefix:

# Build the subgraph once from the first layer.
subgraph = self.layers[0].build_subgraph(
    "transformer_block",
    input_types=input_types,
    weight_prefix="layers.0.",
)

# Invoke it once per layer with layer-specific weights.
for idx in range(num_layers):
    outputs = ops.call(
        subgraph, *h, prefix=f"layers.{idx}."
    )

Parameters:

  • graph (Graph) – The subgraph to call.
  • *args (Value[Any]) – Arguments to pass to the subgraph. Must match the subgraph’s input types, excluding the chain value (handled internally).
  • prefix (str) – A string prepended to all weight names when the subgraph is invoked. Use this to distinguish repeated calls to the same subgraph. For example, if a transformer block references a weight named attention.wq, calling with prefix="layers.3." resolves it to layers.3.attention.wq in the weights registry. Leave empty if the subgraph contains no placeholder weights.

Returns:

A list of Value objects representing the subgraph’s outputs, excluding any internal chain values.

Return type:

list[Value[Any]]

cast()

max.graph.ops.cast(x, dtype)

source

Casts a symbolic tensor to a different data type.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input tensor to cast.
  • dtype (DType) – The target dtype to which the tensor is cast.

Returns:

A new symbolic tensor with the same shape as the input and the specified dtype.

Return type:

TensorValue
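
A minimal sketch (the graph name, shape, and target dtype are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [4], DeviceRef.CPU())

with Graph("cast_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # Same shape, new dtype
    out = ops.cast(x, DType.bfloat16)
    graph.output(out)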

chunk()

max.graph.ops.chunk(x, chunks, axis=0)

source

Chunk the tensor into an exact number of chunks along the specified dim.

Example:

>>> a = TensorValue([1, 2, 3, 4, 5])
>>> chunk(a, 2, 0)
[TensorValue([1, 2]), TensorValue([3, 4])]

Parameters:

Returns:

A list of chunk tensors.

Return type:

list[TensorValue]

concat()

max.graph.ops.concat(original_vals, axis=0)

source

Concatenates a list of symbolic tensors along an axis.

Joins multiple tensors along a specified dimension. This operation requires the functional API since it operates on multiple tensors. All input tensors must have the same rank and the same size in all dimensions except the concatenation axis.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create two 2x2 matrices
a = Tensor.constant([[1, 2], [3, 4]])
b = Tensor.constant([[5, 6], [7, 8]])

# Concatenate along axis 0 (rows) - stacks vertically
vertical = F.concat([a, b], axis=0)
print(f"Concatenated along axis 0: {vertical.shape}")
# Output: Concatenated along axis 0: [Dim(4), Dim(2)]
print(vertical)
# [[1, 2],
#  [3, 4],
#  [5, 6],
#  [7, 8]]

# Concatenate along axis 1 (columns) - joins horizontally
horizontal = F.concat([a, b], axis=1)
print(f"Concatenated along axis 1: {horizontal.shape}")
# Output: Concatenated along axis 1: [Dim(2), Dim(4)]
print(horizontal)
# [[1, 2, 5, 6],
#  [3, 4, 7, 8]]

Parameters:

  • original_vals (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – The list of symbolic tensor values to concatenate. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than axis.
  • axis (int) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, concat(vs, -1) will concatenate along the last dimension.

Returns:

A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will match each input tensor's for every dimension other than axis, whose size will equal the sum of all the input tensors' sizes along that dimension.

Return type:

TensorValue

cond()

max.graph.ops.cond(pred, out_types, then_fn, else_fn)

source

Conditionally execute one of two branches based on a boolean predicate.

This function provides conditional execution in the computation graph, where one of two branches is executed based on the runtime value of a boolean predicate. Both branches must return the same number and types of values as specified in out_types. Buffer mutations in branches are tracked automatically through the chain mechanism.

The predicate is evaluated at runtime to determine which branch to execute. Both branches are compiled but only the selected branch is executed based on the predicate value.

This example shows a basic conditional with return values:

from max.dtype import DType
from max.graph import ops
from max.graph.type import TensorType, DeviceRef

def then_fn():
    return ops.constant(1, DType.int32, device=DeviceRef.CPU())

def else_fn():
    return ops.constant(0, DType.int32, device=DeviceRef.CPU())

device = DeviceRef.CPU()
pred = ops.constant(True, DType.bool, device=device)
result = ops.cond(
    pred,
    [TensorType(DType.int32, [], device=device)],
    then_fn,
    else_fn
)

This example shows a conditional with buffer mutations, where branches don’t return values:

def then_fn():
    ops.inplace_custom("increment", device=buffer.device, values=[buffer])

def else_fn():
    ops.inplace_custom("decrement", device=buffer.device, values=[buffer])

ops.cond(pred, None, then_fn, else_fn)

This example shows a conditional with multiple return values:

def then_fn():
    a = ops.constant(1, DType.float32, device=device)
    b = ops.constant(2, DType.float32, device=device)
    return a, b

def else_fn():
    a = ops.constant(0, DType.float32, device=device)
    b = ops.constant(-1, DType.float32, device=device)
    return a, b

device = DeviceRef.CPU()
out_types = [
    TensorType(DType.float32, [], device=device),
    TensorType(DType.float32, [], device=device)
]
results = ops.cond(pred, out_types, then_fn, else_fn)

Parameters:

Returns:

List of output values from executed branch. Returns empty list when out_types is None.

Raises:

ValueError – If branches return different numbers of results or result types don’t match out_types.

Return type:

list[TensorValue]

constant()

max.graph.ops.constant(value, dtype=None, device=None)

source

Adds a node representing a constant operation.

The value of this constant will have the type TensorType with the same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions.

The constant will be loaded with the specified dtype. If the constant does not fit within the specified dtype, an error is raised.

Warning: Loading the constant could result in precision loss. For example, loading 16777217 as a float32 will result in 16777216.0.

Parameters:

Returns:

A graph value containing the constant data as an attribute.

Return type:

TensorValue
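
A minimal sketch (the graph name and value are illustrative, and this assumes a graph may be built with no inputs):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import DeviceRef

with Graph("constant_example") as graph:
    # A Python scalar becomes a rank-0 tensor
    c = ops.constant(3.14, DType.float32, device=DeviceRef.CPU())
    graph.output(c)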

constant_external()

max.graph.ops.constant_external(name, type)

source

Registers an external constant (weight) in the graph of a given type.

Two external constants with the same name and type refer to the same weight.

Two external constants with the same name and different types are incompatible and will fail compilation.

Parameters:

  • name (str) – The name of the external constant. This should be the fully-qualified weight name and must be unique.
  • type (TensorType) – The type of the constant value.

Returns:

A tensor value of the specified type, representing the weight value associated with the name at compile time.

Return type:

TensorValue

conv2d()

max.graph.ops.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.RSCF)

source

Computes the 2-D convolution product of the input with the given filter, bias, strides, dilations, paddings, and groups.

The op supports 2-D convolution, with the following layout assumptions:

  • input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
  • filter has layout RSCF, i.e., (height, width, in_channels / num_groups, out_channels)
  • bias has shape (out_channels,)

The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding with zeros before and after the indicated spatial dimensions of the input. In 2-D convolution, dim1 here represents H and dim2 represents W. In Python-like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:

input = [
  [1, 2, 3],
  [4, 5, 6]
]
# Shape is 2x3

padded_input = [
  [0, 0, 1, 2, 3, 0],
  [0, 0, 4, 5, 6, 0],
  [0, 0, 0, 0, 0, 0]
]
# Shape is 3x6

This op currently only supports strides and padding on the input.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NHWC input tensor to perform the convolution upon.
  • filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in RSCF layout: (height, width, in_channels / num_groups, out_channels).
  • stride (tuple[int, int]) – The stride of the convolution operation.
  • dilation (tuple[int, int]) – The spacing between the kernel points.
  • padding (tuple[int, int, int, int]) – The amount of padding applied to the input.
  • groups (int) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
  • bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Optional 1-D bias of shape (out_channels,).
  • input_layout (ConvInputLayout) – Layout of the input tensor (default NHWC).
  • filter_layout (FilterLayout) – Layout of the filter tensor (default RSCF).

Returns:

A symbolic tensor value with the convolution applied.

Return type:

TensorValue
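
A sketch (sizes are illustrative) of a 3x3 convolution over an NHWC input, padded by one pixel on each side so H and W are preserved:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

x_type = TensorType(DType.float32, [1, 32, 32, 3], DeviceRef.CPU())  # NHWC
w_type = TensorType(DType.float32, [3, 3, 3, 8], DeviceRef.CPU())    # RSCF

with Graph("conv_example", input_types=(x_type, w_type)) as graph:
    x, w = graph.inputs
    # Output shape is [1, 32, 32, 8]
    out = ops.conv2d(x, w, stride=(1, 1), padding=(1, 1, 1, 1))
    graph.output(out)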

conv2d_transpose()

max.graph.ops.conv2d_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output_paddings=(0, 0), bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.RSCF)

source

Computes the 2-D deconvolution of the input with the given filter, strides, dilations, paddings, and groups.

The op supports the transpose (gradient) of convolution, with the following layout assumptions: (note the out_channel is w.r.t. the original convolution)

  • input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
  • filter has layout RSCF, i.e., (kernel_height, kernel_width, out_channels, in_channels)
  • bias has shape (out_channels,)

The padding values are expected to take the form [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].

This op effectively computes the gradient of a convolution with respect to its input (as if the original convolution operation had the same filter and hyperparameters as this op). A visualization of the computation can be found in https://d2l.ai/chapter_computer-vision/transposed-conv.html.

The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding with zeros before and after the indicated spatial dimensions of the input. In 2-D ConvTranspose, dim1 here represents H_out and dim2 represents W_out. In Python-like syntax, padding a 2x4 spatial output with [0, 1, 2, 1] would yield:

output = [
  [1, 2, 3, 4],
  [5, 6, 7, 8]
]
# Shape is 2x4

padded_output = [
  [3],
]
# Shape is 1x1

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NHWC input tensor to perform the deconvolution upon.
  • filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in RSCF layout: (height, width, out_channels, in_channels).
  • stride (tuple[int, int]) – The stride of the sliding window for each dimension of the input. If a single value is given, it is replicated for the H and W dimensions. The N and C dimensions default to 0.
  • dilation (tuple[int, int]) – The spacing between the kernel points.
  • padding (tuple[int, int, int, int]) – The amount of padding applied to the input.
  • output_paddings (tuple[int, int]) – Resolves the ambiguity between multiple possible output shapes when any stride is greater than 1: output_paddings[i] zeros are appended to the end of the output's i-th spatial axis. Only output_paddings = (0, 0) is currently supported.
  • bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Tensor of shape (out_channels,).
  • input_layout (ConvInputLayout) – Layout of the input tensor (default NHWC).
  • filter_layout (FilterLayout) – Layout of the filter tensor (default RSCF).

Returns:

A symbolic tensor value with the convolution applied.

Return type:

TensorValue

conv3d()

max.graph.ops.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.QRSCF)

source

Computes the 3-D convolution product of the input with the given filter, strides, dilations, paddings, and groups.

The op supports 3-D convolution, with the following layout assumptions:

  • input has NDHWC layout, i.e., (batch_size, depth, height, width, in_channels)
  • filter has layout QRSCF, i.e., (depth, height, width, in_channels / num_groups, out_channels)

The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding with zeros before and after the indicated spatial dimensions of the input. In 3-D convolution, dim1 here represents D, dim2 represents H, and dim3 represents W. In Python-like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:

input = [
  [1, 2, 3],
  [4, 5, 6]
]
# Shape is 2x3

padded_input = [
  [0, 0, 1, 2, 3, 0],
  [0, 0, 4, 5, 6, 0],
  [0, 0, 0, 0, 0, 0]
]
# Shape is 3x6

This op currently only supports strides and padding on the input.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NDHWC input tensor to perform the convolution upon.
  • filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in QRSCF layout: (depth, height, width, in_channels / num_groups, out_channels).
  • stride (tuple[int, int, int]) – The stride of the convolution operation.
  • dilation (tuple[int, int, int]) – The spacing between the kernel points.
  • padding (tuple[int, int, int, int, int, int]) – The amount of padding applied to the input.
  • groups (int) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
  • bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Optional 1-D bias of shape (out_channels,).
  • input_layout (ConvInputLayout) – Layout of the input tensor (default NDHWC).
  • filter_layout (FilterLayout) – Layout of the filter tensor (default QRSCF).

Returns:

A symbolic tensor value with the convolution applied. Output shape = (batch_size, depth, height, width, out_channels).

Return type:

TensorValue

cos()

max.graph.ops.cos(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

cumsum()

max.graph.ops.cumsum(x, axis=-1, exclusive=False, reverse=False)

source

Computes the cumulative sum of the input tensor along the given axis.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to sum over.
  • axis (int) – The axis along which to compute the sum. If negative, indexes from the last dimension. For example, a value of -1 will compute the sum along the last dimension.
  • exclusive (bool) – If set, start at 0 and exclude the final element. Otherwise, start with the first element. Said another way, cumsum normally computes [sum(x[…, :i+1, …]) for i in range(x.shape[axis])]; with exclusive set, it computes [sum(x[…, :i, …]) for i in range(x.shape[axis])], so the first value is 0 and the final element is excluded from the sum.
  • reverse (bool) – If set, start from the end. In other words, the first element will be the total sum, with each element following counting downwards; or [sum(x[…, i:, …]) for i in range(x.shape[axis])].

Returns:

A symbolic tensor representing the result of the cumsum operation. The tensor will have the same type as the input tensor. The computed values will be the cumulative sum of the values along the given axis, according to the specified parameters:

  • if exclusive is set, the first value will be 0, and the last value will be excluded from the sum
  • if reverse is set, the sum will be computed starting at the back of the axis back to the front, rather than front-to-back

Raises:

ValueError – If x is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue
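
A sketch illustrating the three variants (the graph name and shape are illustrative; the commented values assume an input of [1, 2, 3, 4]):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [4], DeviceRef.CPU())

with Graph("cumsum_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]                         # e.g. [1, 2, 3, 4]
    inclusive = ops.cumsum(x)                   # [1, 3, 6, 10]
    exclusive = ops.cumsum(x, exclusive=True)   # [0, 1, 3, 6]
    reverse = ops.cumsum(x, reverse=True)       # [10, 9, 7, 4]
    graph.output(inclusive, exclusive, reverse)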

custom()

max.graph.ops.custom(name, device, values, out_types, parameters=None)

source

Creates a node to execute a custom graph operation in the graph.

The custom op should be registered by annotating a function with the @compiler.register decorator.

Parameters:

  • name (str) – The op name provided to @compiler.register.
  • values (Sequence[Value[Any]]) – The op function’s arguments.
  • out_types (Sequence[Type[Any]]) – The list of the op function's return types.
  • parameters (Mapping[str, bool | int | str | DType] | None) – Dictionary of extra parameters expected by the kernel.
  • device (Device | DeviceRef) – Device that the op is assigned to. This becomes a target parameter to the kernel.

Returns:

Symbolic values representing the outputs of the op in the graph. These correspond 1:1 with the types passed as out_types.

Return type:

list[Value[Any]]

dequantize()

max.graph.ops.dequantize(encoding, quantized)

source

Dequantizes a quantized tensor to floating point.

NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.

Parameters:

Returns:

The dequantized result (a floating point tensor).

Return type:

TensorValue

distributed_broadcast()

max.graph.ops.distributed_broadcast(input, signal_buffers)

source

Broadcast tensor from source GPU to all GPUs.

This op is a collective operation which broadcasts a tensor from the source GPU (where the input tensor resides) to all participating GPUs. Each GPU receives a copy of the input tensor.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Input tensor to broadcast. The device where this tensor resides becomes the root/source of the broadcast.
  • signal_buffers (Iterable[BufferValue | HasBufferValue]) – Device buffer values used for synchronization. The number of signal buffers determines the number of participating GPUs.

Returns:

List of output tensors, one per device. Each output tensor has the same shape and dtype as the input tensor.

Raises:

ValueError – If signal_buffers is empty, if input tensor device is not found in signal buffer devices, or if devices are not unique.

Return type:

list[TensorValue]

distributed_scatter()

max.graph.ops.distributed_scatter(input_chunks, signal_buffers)

source

Scatter different chunks from root GPU to device groups.

Each DP replica group receives a different input chunk. All TP devices within the same replica get the same chunk. Uses a pull-based approach where each GPU reads its chunk from the root GPU via P2P.

Parameters:

  • input_chunks (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – Input tensors to scatter, one per DP replica. All must reside on the same root device. The number of chunks determines dp_size.
  • signal_buffers (Iterable[BufferValue | HasBufferValue]) – Device buffer values used for synchronization. The number of signal buffers determines the number of participating GPUs (ngpus).

Returns:

List of output tensors, one per device. Each output tensor has the same shape and dtype as its replica’s input chunk.

Raises:

ValueError – If fewer than 2 signal buffers, if input chunks are not on the same device, or if devices are not unique.

Return type:

list[TensorValue]

div()

max.graph.ops.div(lhs, rhs)

source

Divides two symbolic tensors using true division (Python operator /).

For integer operands, this performs true division by promoting to float, matching Python’s / operator behavior. For floating-point operands, this performs standard floating-point division.

Creates a new op node to compute the division of two symbol tensor values and adds it to the graph, returning the symbolic result.

Parameters:

Returns:

A symbolic tensor value representing the output of the division. The result will have:

  • a floating-point dtype for integer operands, or the promoted dtype for mixed input types
  • the same shape as the broadcast of the two input shapes.

Raises:

  • Error – If the input values’ shapes are not compatible for broadcasting.
  • Error – If one of the input values has an unsupported dtype.
  • Error – If the two symbols are parts of different graphs.

Return type:

TensorValue

equal()

max.graph.ops.equal(lhs, rhs)

source

Parameters:

Return type:

TensorValue

erf()

max.graph.ops.erf(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

exp()

max.graph.ops.exp(x)

source

Computes the elementwise exp (exponential) function of a symbolic tensor.

Creates a new op node to compute the elementwise exponential function of a symbolic tensor and adds it to the graph, returning the symbolic result. The exp function is fundamental in neural networks, used in attention mechanisms, activation functions, and probability distributions.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create input tensor
x = Tensor.constant([0.0, 1.0, 2.0])

# Compute exponential
result = F.exp(x)
print(result)
# Output: [1.0, 2.718..., 7.389...]
# (e^0 = 1, e^1 ≈ 2.718, e^2 ≈ 7.389)

exp is defined as exp(x) = e^x, where e is Euler’s number.

Parameters:

Returns:

A new symbolic tensor value representing the output of the exp value computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

flatten()

max.graph.ops.flatten(x, start_dim=0, end_dim=-1)

source

Flattens the specified dims of a symbolic tensor.

The number and order of the elements in the tensor is unchanged. All dimensions from start_dim to end_dim (inclusive) are merged into a single output dim.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to flatten.
  • start_dim (int) – The first dimension to flatten. Supports negative indexing. Defaults to 0.
  • end_dim (int) – The last dimension to flatten (inclusive). Supports negative indexing. Defaults to -1.

Returns:

A symbolic tensor with the same elements as the input, but with dimensions start_dim through end_dim merged into one.

Raises:

  • IndexError – If start_dim or end_dim are out of range.
  • ValueError – If start_dim comes after end_dim.

Return type:

TensorValue
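
A minimal sketch (the graph name and shape are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [2, 3, 4], DeviceRef.CPU())

with Graph("flatten_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # Merge dimensions 1 through -1: the output shape is [2, 12]
    out = ops.flatten(x, start_dim=1)
    graph.output(out)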

floor()

max.graph.ops.floor(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

fold()

max.graph.ops.fold(input, output_size, kernel_size, stride=1, dilation=1, padding=0)

source

Combines an array of sliding blocks into a larger containing tensor.

The input tensor must have shape (N, C * kernel_sizes, L) where N is the batch dimension, C is the number of channels, kernel_sizes is the product of the kernel sizes, and L is the number of local blocks.

The resulting output tensor will have shape (N, C, output_shape[0], output_shape[1]).

L, the number of blocks, must be equivalent to: prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)

where d is over all spatial dimensions.

Parameters:

Returns:

The folded 4D tensor with shape (N, C, output_shape[0], output_shape[1]).

Return type:

TensorValue
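
As a sketch, the block-count constraint can be checked in plain Python (all sizes here are hypothetical):

from math import prod

# Hypothetical sizes for illustration
output_size = (4, 5)
kernel_size = (2, 2)
stride = (1, 1)
dilation = (1, 1)
padding = (0, 0)

# L must equal the number of sliding-window positions over the output
L = prod(
    (output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1)
    // stride[d] + 1
    for d in range(2)
)
# Here L == 3 * 4 == 12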

gather()

max.graph.ops.gather(input, indices, axis)

source

Selects elements out of an input tensor by index.

Parameters:

Returns:

A new symbolic tensor representing the result of the gather operation.

Return type:

TensorValue
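
A minimal sketch (the graph name, shapes, and index dtype are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [5, 8], DeviceRef.CPU())
indices_type = TensorType(DType.int64, [3], DeviceRef.CPU())

with Graph("gather_example", input_types=(input_type, indices_type)) as graph:
    x, idx = graph.inputs
    # Select 3 rows along axis 0: the output shape is [3, 8]
    rows = ops.gather(x, idx, axis=0)
    graph.output(rows)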

gather_nd()

max.graph.ops.gather_nd(input, indices, batch_dims=0)

source

Selects elements out of an input tensor by N-dimensional index.

This operation performs N-dimensional indexing into input using indices. Unlike gather(), which indexes along a single axis, gather_nd() allows indexing along multiple dimensions simultaneously.

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType

input_shape = ["a", "b", "c", "d", "e"]
indices_shape = ["a", "f", 3]
input_type = TensorType(DType.bfloat16, input_shape)
indices_type = TensorType(DType.int32, indices_shape)
with Graph("gather_nd", input_types=[input_type, indices_type]) as graph:
    input, indices = graph.inputs
    gathered = ops.gather_nd(input, indices, batch_dims=1)
    print(gathered.type)
# Output: TensorType(dtype=DType.bfloat16, shape=["a", "f", "e"])

In this example:

  • batch_dims is 1, so there’s 1 shared dimension at the beginning.
  • indices has an additional dimension “f” which becomes part of the output.
  • The last dimension of indices is the index vector; values in this vector are interpreted to be indices into “b”, “c”, and “d”.
  • Since batch_dims (1) + index size (3) < input.rank (5), the remaining dimensions (in this case “e”) are sliced into the output as features.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to select elements from.
  • indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of index values to use for selection. The last dimension of this tensor must be static. This dimension will be used to index or slice into input immediately following batch_dims initial dimensions. The size of this index dimension is the number of dimensions it specifies.
  • batch_dims (int) – The number of leading batch dimensions shared by input and indices; 0 by default. input and indices must exactly match up to their first batch_dims dimensions. This function does not broadcast.

Returns:

A new symbolic tensor representing the result of the gather operation. The output will have the same dtype as input, and will have shape depending on the inputs, in this order:

  • input.shape[:batch_dims] – The “broadcast” dimensions (though note that this function does not broadcast). These dimensions must be identical between input and indices.
  • indices.shape[batch_dims:-1] – The “gather” dimensions; this allows multi-dimensional tensors of indices. The last dimension is the index vector.
  • input.shape[batch_dims + indices.shape[-1]:] – The “slice” dimensions. If batch_dims < input.rank - indices.shape[-1] (again, this last is the index vector), then any following dimensions of the inputs are taken entirely as though slicing.

Return type:

TensorValue

gelu()

max.graph.ops.gelu(x, approximate='none')

source

Computes the elementwise gelu of a symbolic tensor.

Creates a new op node to compute the elementwise gelu of a symbolic tensor and adds it to the graph, returning the symbolic result.

For approximate == "none", the exact gelu function is computed.

For approximate == "tanh", the approximation:

gelu(x) = 0.5 * x * (1.0 + tanh(0.7978845608028654 * (x + 0.044715 * x**3)))

is used.

For approximate == "quick", the approximation:

gelu(x) = sigmoid(1.702 * x) * x

is used.

Parameters:

  • x (TensorValue) – The symbolic tensor to use as the input to the gelu computation.
  • approximate (str) – One of none, tanh, or quick.

Returns:

A new symbolic tensor value representing the output of the gelu computation.

Raises:

  • Error – If the symbol doesn’t represent a tensor value.
  • ValueError – If the approximation method is invalid.

greater()

max.graph.ops.greater(lhs, rhs)

source

Parameters:

Return type:

TensorValue

greater_equal()

max.graph.ops.greater_equal(lhs, rhs)

source

Parameters:

Return type:

TensorValue

group_norm()

max.graph.ops.group_norm(input, gamma, beta, num_groups, epsilon)

source

Performs group normalization.

Divides channels into groups and computes normalization statistics within each group. Useful for small batch sizes where batch normalization is unstable.

Parameters:

Returns:

A normalized tensor with the same shape as input.

Raises:

ValueError – If the input tensor has fewer than 2 dimensions.

Return type:

TensorValue

hann_window()

max.graph.ops.hann_window(window_length, device, periodic=True, dtype=float32)

source

Calculate a Hann window for a given length.

Hann window function:

H[n] = 1/2 * (1 - cos(2 * pi * n / (N - 1)))

where N is window_length.

Parameters:

  • window_length (int) – The length of the window.
  • device (DeviceRef) – The device to run the operation on.
  • periodic (bool) – If True, the returned window omits the last duplicate value of the symmetric window, making it suitable for use as a periodic window with functions like stft(). Equivalently, hann_window(L, periodic=True) == hann_window(L + 1, periodic=False)[:-1].
  • dtype (DType) – The desired data type of the output tensor.

Returns:

A 1-D tensor of size (window_length,) containing the window.

Raises:

  • ValueError – If window_length is negative.
  • TypeError – If window_length is not an integer.

Return type:

TensorValue

inplace_custom()

max.graph.ops.inplace_custom(name, device, values, out_types=None, parameters=None)

source

Creates a node to execute an in-place custom graph operation in the graph.

The custom op should be registered by annotating a function with the @compiler.register decorator.

Parameters:

  • name (str) – The op name provided to @compiler.register.
  • device (Device | DeviceRef) – Device that the op is assigned to. This becomes a target parameter to the kernel.
  • values (Sequence[Value[Any]]) – The op function’s arguments.
  • out_types (Sequence[Type[Any]] | None) – Optional sequence of output types for the op.
  • parameters (dict[str, bool | int | str | DType] | None) – Dictionary of extra parameters expected by the kernel.

Return type:

list[Value[Any]]

irfft()

max.graph.ops.irfft(input_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input_is_complex=False, buffer_size_mb=512)

source

Compute the inverse real FFT of the input tensor.

Parameters:

  • input_tensor (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input tensor to compute the inverse real FFT of.
  • n (int | None) – The size of the output tensor. Must be an int, and cannot be a symbolic Buffer. The input tensor will be padded or truncated to n // 2 + 1 along the specified axis.
  • axis (int) – The axis to compute the inverse real FFT of.
  • normalization (Normalization | str) – The normalization to apply to the output tensor. Can be “backward”, “ortho”, or “forward”. When “backward”, the output is divided by n. When “ortho”, the output is divided by sqrt(n). When “forward”, no normalization is applied.
  • input_is_complex (bool) – Whether the input tensor is already interleaved complex. The last dimension of the input tensor must be 2, and is excluded from the dimension referred to by axis.
  • buffer_size_mb (int) – The estimated size of a persistent buffer to use for storage of intermediate results. Needs to be the same across multiple calls to irfft within the same graph. Otherwise, multiple buffers will be allocated.

Returns:

The inverse real FFT of the input tensor. The shape of the output tensor is the same as the shape of the input tensor, except for the axis that the inverse real FFT is computed over, which is replaced by n.

is_inf()

max.graph.ops.is_inf(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

is_nan()

max.graph.ops.is_nan(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

layer_norm()

max.graph.ops.layer_norm(input, gamma, beta, epsilon)

source

Performs layer normalization.

Parameters:

Returns:

A graph tensor value with the normalization applied.

Raises:

  • ValueError – If gamma size doesn’t match the last dimension of input.
  • ValueError – If beta size doesn’t match the last dimension of input.
  • ValueError – If epsilon is not positive.

Return type:

TensorValue

log()

max.graph.ops.log(x)

source

Computes the elementwise natural logarithm of a symbolic tensor.

Creates a new op node to compute the elementwise natural logarithm of a symbolic tensor and adds it to the graph, returning the symbolic result. The natural logarithm is used in loss functions, normalization, and probability calculations in machine learning.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create input tensor (positive values only)
x = Tensor.constant([1.0, 2.718, 7.389, 20.0])

# Compute natural logarithm
result = F.log(x)
print(result)
# Output: [0.0, 1.0, 2.0, 2.996...]
# (log(1) = 0, log(e) = 1, log(e^2) = 2)

The natural logarithm function log is defined as the inverse of the exponential function exp(). In other words, it computes the value y in the equation x = e^y where e is Euler’s number.

log(x) is undefined for x <= 0 for real numbers. Complex numbers are currently unsupported.

Parameters:

Returns:

A new symbolic tensor value representing the output of the natural logarithm value computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

log1p()

max.graph.ops.log1p(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

logical_and()

max.graph.ops.logical_and(lhs, rhs)

source

Parameters:

Return type:

TensorValue

logical_not()

max.graph.ops.logical_not(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

logical_or()

max.graph.ops.logical_or(lhs, rhs)

source

Parameters:

Return type:

TensorValue

logical_xor()

max.graph.ops.logical_xor(lhs, rhs)

source

Parameters:

Return type:

TensorValue

logsoftmax()

max.graph.ops.logsoftmax(value, axis=-1)

source

Parameters:

Return type:

TensorValue

masked_scatter()

max.graph.ops.masked_scatter(input, mask, updates, out_dim)

source

Creates a new symbolic tensor where the updates are written to input where mask is true.

Parameters:

Returns:

A new symbolic tensor representing the result of the masked_scatter operation.

Return type:

TensorValue

matmul()

max.graph.ops.matmul(lhs, rhs)

source

Computes the matrix multiplication of two tensor graph values.

Performs general matrix multiplication with broadcasting. Matrix multiplication is fundamental to neural networks, used for linear transformations, attention mechanisms, and fully connected layers.

from max.experimental.tensor import Tensor

# Create two 2x2 matrices
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]])  # Shape: (2, 2)
w = Tensor.constant([[5.0, 6.0], [7.0, 8.0]])  # Shape: (2, 2)

# Matrix multiply using @ operator (uses matmul internally)
result = x @ w
print("Matrix multiplication result:")
print(result)
# Output: [[19.0, 22.0],
#          [43.0, 50.0]]
# Computed as: result[i,j] = sum(x[i,k] * w[k,j])

# Can also call directly via functional API
import max.experimental.functional as F
result2 = F.matmul(x, w)
# Same result as x @ w

If the lhs is 1D, it will be reshaped to 1xD. If the rhs is 1D, it will be reshaped to Dx1. In both cases, the additional 1 dimensions will be removed from the output shape.

For the multiplication, the innermost (rightmost) two dimensions are treated as a matrix. The lhs matrix has shape MxK and the rhs matrix has shape KxN; the output has shape MxN. The K dimensions must match in both matrices.

The remaining outer dimensions will be broadcasted.

Parameters:

Returns:

A tensor graph value representing the matrix product of lhs and rhs. For 2D inputs, the output shape is (M, N) where lhs is (M, K) and rhs is (K, N). For higher-dimensional inputs, batch dimensions are preserved and the operation is applied to the last two dimensions of each input.

Return type:

TensorValue

max()

max.graph.ops.max(x, y=None, /, axis=None)

source

Overload for ops.elementwise.max and ops.reduction.max.

  • If two tensors are provided, axis is ignored and returns an elementwise maximum.
  • If one tensor is provided, compute ops.reduction.max on the tensor and axis.

Parameters:

Return type:

TensorValue
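
A sketch contrasting the two forms (the graph name and shapes are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [2, 5], DeviceRef.CPU())

with Graph("max_example", input_types=(input_type, input_type)) as graph:
    x, y = graph.inputs
    elementwise = ops.max(x, y)     # elementwise maximum, shape [2, 5]
    reduced = ops.max(x, axis=-1)   # reduction along last axis, shape [2, 1]
    graph.output(elementwise, reduced)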

max_pool2d()

max.graph.ops.max_pool2d(input, kernel_size, stride=1, dilation=1, padding=0, ceil_mode=False)

source

Perform a 2D max pooling operation on the input tensor.

Applies a 2D max pooling operation to the input tensor with layout [N, H, W, C]. The pooling operation slides a window of size kernel_size over the spatial dimensions and selects the maximum value within each window.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor with shape [N, H, W, C].
  • kernel_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – The height and width of the sliding window.
  • stride (int | tuple[int, int]) – The stride of the sliding window. Can be a single integer applied to both spatial dimensions or a tuple (stride_h, stride_w). Defaults to 1.
  • dilation (int | tuple[int, int]) – The spacing between kernel elements. Can be a single integer or a tuple (dilation_h, dilation_w). Defaults to 1.
  • padding (int | tuple[int, int]) – Zero-padding added to both sides of each spatial dimension. Can be a single integer or a tuple (pad_h, pad_w). Defaults to 0.
  • ceil_mode (bool) – If True, uses ceil instead of floor when computing the output spatial shape. Defaults to False.

Returns:

A symbolic tensor with the max pooling applied, with shape [N, H_out, W_out, C].

Return type:

TensorValue

mean()

max.graph.ops.mean(x, axis=-1)

source

Reduces a symbolic tensor using a mean operation.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor for the operation.
  • axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

Returns:

A symbolic tensor representing the result of the mean operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension which will have size 1.

Return type:

TensorValue

min()

max.graph.ops.min(x, y=None, /, axis=None)

source

Overload for ops.elementwise.min and ops.reduction.min.

  • If two tensors are provided, axis is ignored and returns an elementwise minimum.
  • If one tensor is provided, compute ops.reduction.min on the tensor and axis.

Parameters:

Return type:

TensorValue

mod()

max.graph.ops.mod(lhs, rhs)

source

Parameters:

Return type:

TensorValue

mul()

max.graph.ops.mul(lhs, rhs)

source

Parameters:

Return type:

TensorValue

negate()

max.graph.ops.negate(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

nonzero()

max.graph.ops.nonzero(x, out_dim)

source

Returns the indices of all nonzero elements in a tensor.

Returns a tensor of indices of the nonzero values in the given tensor. The return value is a 2D tensor of shape [out_dim x rank_in], where out_dim is the number of nonzero elements in the input tensor, and rank_in is the rank of the input tensor. Indices are generated in row-major order.

Parameters:

Returns:

A symbolic tensor of indices

Raises:

ValueError – If x is scalar, or if x is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

not_equal()

max.graph.ops.not_equal(lhs, rhs)

source

Parameters:

Return type:

TensorValue

outer()

max.graph.ops.outer(lhs, rhs)

source

Computes the outer product of two symbolic vectors.

Parameters:

Returns:

A symbolic tensor representing the outer product of the two input vectors. It will have rank 2, with the dimension sizes being the number of elements of lhs and rhs respectively.

Return type:

TensorValue
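
A minimal sketch (the graph name and shapes are illustrative):

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

lhs_type = TensorType(DType.float32, [3], DeviceRef.CPU())
rhs_type = TensorType(DType.float32, [4], DeviceRef.CPU())

with Graph("outer_example", input_types=(lhs_type, rhs_type)) as graph:
    a, b = graph.inputs
    # Rank-2 result of shape [3, 4]
    out = ops.outer(a, b)
    graph.output(out)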

pad()

max.graph.ops.pad(input, paddings, mode='constant', value=0)

source

Pads a tensor along every dimension.

Adds padding to the input tensor using the specified padding values and mode.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to pad.

  • paddings (Iterable[int]) – Sequence of padding values. For a tensor with rank N, paddings must contain 2*N non-negative integers in the order [pad_before_dim0, pad_after_dim0, pad_before_dim1, pad_after_dim1, ...].

  • mode (Literal['constant', 'reflect', 'edge']) –

    The padding mode. Supported values:

    • "constant" - fill padded cells with value.
    • "reflect" - reflect values about the content-region edges (excludes the boundary element, equivalent to numpy.pad with mode='reflect').
    • "edge" - repeat the nearest boundary element (equivalent to numpy.pad with mode='edge').
  • value (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The constant fill value (only used when mode='constant'). Defaults to 0.

Returns:

A symbolic tensor with the same dtype as input, padded along each dimension according to paddings.

Raises:

ValueError – If mode is not one of the supported values, or if any padding value is negative.

Return type:

TensorValue
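
A minimal sketch (the graph name and sizes are illustrative); note the [before, after] ordering per dimension:

from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef

input_type = TensorType(DType.float32, [2, 3], DeviceRef.CPU())

with Graph("pad_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    # [before_dim0, after_dim0, before_dim1, after_dim1]: output shape is [3, 6]
    out = ops.pad(x, [0, 1, 2, 1], mode="constant", value=0)
    graph.output(out)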

parallel()

max.graph.ops.parallel(inputs, body_fn, *, extra_inputs=None, chain=None)

source

Execute a function in parallel for each input via mo.parallel.

The body function receives a representative TensorValue (typed like the first input) and should return one TensorValue. The runtime dispatches the body once per input, substituting the actual per-launch tensor.

When extra_inputs are provided (e.g. signal buffers for bundled collectives), the parallel op uses tupled syntax and the body function receives an additional BufferValue argument per extra-input group.

When chain is provided, the parallel region is sequenced relative to prior ops and the returned out_chain represents completion of all parallel launches. A chain is required when extra_inputs contain buffers that need ordering guarantees.

Examples:

# Simple elementwise (no chain needed):
results = ops.parallel([gpu0, gpu1], lambda x: ops.relu(x))

# Bundled allreduce with chain:
results, out_chain = ops.parallel(
    tensors, body_fn, extra_inputs=signal_bufs, chain=in_chain
)

Parameters:

  • inputs (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – Tensors to dispatch over. All must share the same shape and dtype; device labels must match (IDs may differ).
  • body_fn (Callable[[...], TensorValue | Iterable[TensorValue]]) – Callable that takes one TensorValue (and optionally one BufferValue per extra-input group) and returns one TensorValue result.
  • extra_inputs (Iterable[BufferValue | HasBufferValue] | None) – Optional per-device buffer values (e.g. signal buffers). When provided, must have the same length as inputs.
  • chain (_ChainValue | None) – Optional chain value for sequencing. Required when extra_inputs contain buffers. Typically obtained from graph._merge_chains(...).

Returns:

When chain is provided, a tuple (tensor_results, out_chain); when chain is omitted, tensor_results alone.

permute()

max.graph.ops.permute(x, dims)

source

Permutes all dimensions of a symbolic tensor.

Parameters:

Returns:

A new symbolic tensor with the dimensions permuted to match the passed in order. It has the same elements and dtype, but the order of the elements is different according to the permutation.

Return type:

TensorValue
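
A minimal sketch (shape chosen for illustration):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def permute_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3, 4), device=DeviceRef.CPU())

    with Graph("permute_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Reorder axes (0, 1, 2) -> (2, 0, 1): output shape is (4, 2, 3).
        out = ops.permute(x, [2, 0, 1])
        graph.output(out)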

pow()

max.graph.ops.pow(lhs, rhs)

source

Parameters:

Return type:

TensorValue

print()

max.graph.ops.print(value, label='debug_tensor')

source

Prints the value of a tensor or a string during graph execution.

This function outputs the current value of a tensor and is primarily used for debugging within the Max Engine and its graph execution framework. It is particularly useful for verifying that the intermediate results of your computations are as expected.

By printing the tensor values, you can visualize the data flowing through the graph, which helps in understanding how the operations are transforming the data.

The label argument names the printed output, making it easier to identify which tensor’s value is being printed, especially when there are multiple print statements in a complex graph.

import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def add_tensors(a: np.ndarray, b: np.ndarray) -> None:
    input_type = TensorType(dtype=DType.float32, shape=(1,), device=DeviceRef.CPU())
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        ops.print(out, label="addition_output")  # Pass the output tensor here

        graph.output(out)
        print("final graph:", graph)

Parameters:

  • value (str | TensorValue) – The value to print. Can be either a string or a TensorValue.
  • label (str) – A label to identify the printed value. Defaults to debug_tensor.

Return type:

None

prod()

max.graph.ops.prod(x, axis=-1)

source

Reduces a symbolic tensor using a product operation.

Computes the product of elements along a specified axis.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor for the operation.
  • axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

Returns:

A symbolic tensor representing the result of the product operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension which will have size 1.

Return type:

TensorValue
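
For example, a minimal sketch; note that the reduced axis is kept with size 1:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def prod_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3), device=DeviceRef.CPU())

    with Graph("prod_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.prod(x, axis=-1)  # (2, 3) -> (2, 1)
        graph.output(out)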

qmatmul()

max.graph.ops.qmatmul(encoding, config, lhs, *rhs)

source

Performs matrix multiplication between floating point and quantized tensors.

This quantizes the lhs floating point value to match the encoding of the rhs quantized value, performs matmul, and then dequantizes the result. Beware that, compared to a regular matmul op, this one expects the rhs value to be transposed. For example, if the lhs shape is [32, 64], and the quantized rhs shape is also [32, 64], then the output shape is [32, 32].

That is, this function returns the result from:

dequantize(quantize(lhs) @ transpose(rhs))

The last two dimensions in lhs are treated as matrices and multiplied by rhs (which must be a 2D tensor). Any remaining dimensions in lhs are broadcast dimensions.

NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.

Parameters:

  • encoding (QuantizationEncoding) – The quantization encoding to use.
  • config (QuantizationConfig | None) – Optional quantization config; required for some encodings (for example, GPTQ).
  • lhs (TensorValue) – The non-quantized, left-hand-side of the matmul.
  • rhs (TensorValue) – The transposed and quantized right-hand-side tensor(s).

Returns:

The dequantized result (a floating point tensor).

Return type:

TensorValue

range()

max.graph.ops.range(start, stop, step=1, out_dim=None, *, dtype, device)

source

Creates a sequence of numbers from start to stop (exclusive) with step.

The start, stop, and step values must share the same element type; dtype and device are required keyword arguments.

Note the following restrictions on input values:

  1. step must be non-zero.
  2. stop - start must be zero or have the same sign as step.

Parameters:

Returns:

A symbolic tensor value containing the defined range of values.

Return type:

TensorValue
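
For example, a minimal sketch producing the values [0, 2, 4, 6, 8]:

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

def range_graph():
    with Graph("range_graph") as graph:
        # Static bounds and step, so the output dimension (5) is inferred.
        out = ops.range(0, 10, 2, dtype=DType.int32, device=DeviceRef.CPU())
        graph.output(out)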

rebind()

max.graph.ops.rebind(x, shape, message='', layout=None)

source

Rebinds a symbolic tensor to a specified set of dimensions.

This does not mutate the symbolic tensor passed in; instead it adds a runtime assertion that the input’s symbolic shape is equivalent to the specified shape. For example, if the input tensor shape has dynamic/unknown sizes, this will assert the fixed sizes that may be required for a subsequent operation.

Parameters:

Returns:

A symbolic tensor with the same elements and shape as the given tensor, but with the symbolic shape asserted to the specified shape.

Return type:

TensorValue
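
For example, a minimal sketch that pins a symbolic batch dimension to a fixed size (the dimension name and message are illustrative):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def rebind_graph():
    # The first dimension is symbolic (named "batch").
    input_type = TensorType(dtype=DType.float32, shape=("batch", 4), device=DeviceRef.CPU())

    with Graph("rebind_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Runtime assertion that batch == 8, so later ops can rely on it.
        out = ops.rebind(x, shape=(8, 4), message="expected batch size 8")
        graph.output(out)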

relu()

max.graph.ops.relu(x)

source

Computes the elementwise ReLU (Rectified Linear Unit) of a symbolic tensor.

Creates a new op node to compute the elementwise ReLU of a symbolic tensor and adds it to the graph, returning the symbolic result. ReLU is defined as relu(x) = max(0, x), setting all negative values to zero while leaving positive values unchanged.

ReLU is one of the most common activation functions in neural networks due to its computational efficiency and effectiveness in addressing the vanishing gradient problem.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create input with negative and positive values
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply ReLU activation
result = F.relu(x)
print(result)
# Output: [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
# Negative values become 0, positive values unchanged

Parameters:

Returns:

A new symbolic tensor value representing the output of the relu computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

repeat_interleave()

max.graph.ops.repeat_interleave(x, repeats, axis=None, out_dim=None)

source

Repeats elements of a tensor along the given dimension.

Modeled after torch.repeat_interleave, with some constraints on the accepted arguments (see the parameter descriptions below).

For example, given repeats=2 and the following input:

# Input tensor with shape (2, 2)
input = TensorValue(x)  # Contains [[1.0, 2.0], [3.0, 4.0]]

repeat_interleave with axis=0:

# Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]

repeat_interleave with axis=1:

# Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
# Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]

repeat_interleave with axis=None (the default):

# Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2)  # axis = None
# Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]

repeat_interleave with repeats=[2, 3] and axis=0:

repeat_value = TensorValue([2, 3])

# Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor.
  • repeats (int | TensorValue) – The number of repetitions for each element.
  • axis (int | None) – The dimension along which to repeat values. If axis is not specified or None (the default), flatten the input array and repeat the flattened values.
  • out_dim (int | str | Dim | integer[Any] | TypedAttr | None) – Optional symbolic dimension for the output size (for graph validation).

Returns:

A symbolic tensor with the elements interleaved.

Raises:

ValueError – If repeats is non-positive or if axis is out of range.

Return type:

TensorValue

reshape()

max.graph.ops.reshape(x, shape)

source

Reshapes a symbolic tensor.

The number and order of the elements in the tensor is unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same.

If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified elements. Its size is the number of elements in the original tensor divided by the product of the other dimensions in shape.

Parameters:

Returns:

A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.

Raises:

ValueError – if input and target shapes’ number of elements mismatch.

Return type:

TensorValue
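
For example, a minimal sketch using -1 as the inferred dimension:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def reshape_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 6), device=DeviceRef.CPU())

    with Graph("reshape_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # 12 elements / 4 = 3, so -1 resolves to 3: (2, 6) -> (3, 4).
        out = ops.reshape(x, (-1, 4))
        graph.output(out)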

resize()

max.graph.ops.resize(input, shape, interpolation=InterpolationMode.BILINEAR)

source

Resize the input tensor to the given shape.

This function resizes a tensor using the specified interpolation method. The tensor is expected to have NCHW format (batch, channels, height, width).

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to resize. Must have rank 4 in NCHW format.
  • shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape of length 4 corresponding to (N, C, H, W).
  • interpolation (InterpolationMode) – Desired interpolation enum defined by InterpolationMode. Defaults to InterpolationMode.BILINEAR.

Returns:

A resized tensor with the shape specified by the shape argument.

Raises:

ValueError – If the input doesn’t have rank 4, shape has wrong number of elements, or unsupported interpolation mode is specified.

Return type:

TensorValue
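
A minimal sketch upscaling an NCHW image with the default bilinear interpolation (shapes are illustrative):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def resize_graph():
    input_type = TensorType(dtype=DType.float32, shape=(1, 3, 4, 4), device=DeviceRef.CPU())

    with Graph("resize_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # (N, C, H, W) = (1, 3, 4, 4) -> (1, 3, 8, 8)
        out = ops.resize(x, shape=(1, 3, 8, 8))
        graph.output(out)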

resize_bicubic()

max.graph.ops.resize_bicubic(input, size)

source

Resize a tensor using bicubic interpolation.

Produces an output tensor whose dimensions are given by size, using a 4x4-pixel cubic convolution filter (a = -0.75) with half_pixel coordinate mapping. Input must be rank-4 NCHW.

Parameters:

Returns:

A new symbolic tensor with shape size and the same dtype as input.

Raises:

ValueError – If input doesn’t have rank 4, or if size doesn’t have length 4.

Return type:

TensorValue

resize_linear()

max.graph.ops.resize_linear(input, size, coordinate_transform_mode=0, antialias=False)

source

Resize a tensor using linear (bilinear) interpolation.

Produces an output tensor whose spatial dimensions are given by size using separable 1-D linear filters. The operation maps output coordinates back to input coordinates according to coordinate_transform_mode.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to resize.

  • size (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape. Must have the same rank as input.

  • coordinate_transform_mode (int) –

    How to map an output coordinate to an input coordinate. Allowed values:

    • 0 (half_pixel, default): shifts by 0.5 before scaling, consistent with most deep-learning frameworks.
    • 1 (align_corners): aligns the corner pixels of input and output so that the first and last coordinates are preserved exactly.
    • 2 (asymmetric): no shift; equivalent to floor-dividing coordinates by the scale factor.
    • 3 (half_pixel_1D): like half_pixel but only applied to the last spatial dimension.
  • antialias (bool) – When True, applies an antialiasing filter when the output is smaller than the input (i.e. when downscaling), which reduces aliasing artifacts by widening the tent filter support by 1 / scale. Has no effect when upscaling.

Returns:

A new symbolic tensor with shape size and the same dtype as input.

Raises:

ValueError – If coordinate_transform_mode is not 0-3, or if size has a different rank than input.

Return type:

TensorValue

resize_nearest()

max.graph.ops.resize_nearest(input, size, coordinate_transform_mode=0, round_mode=0)

source

Resize a tensor using nearest-neighbor interpolation.

Produces an output tensor whose dimensions are given by size by selecting the nearest input sample for each output coordinate.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to resize.

  • size (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape. Must have the same rank as input.

  • coordinate_transform_mode (int) –

    How to map an output coordinate to an input coordinate. Allowed values:

    • 0 (half_pixel, default).
    • 1 (align_corners).
    • 2 (asymmetric).
    • 3 (half_pixel_1D).
  • round_mode (int) –

    How to round the mapped coordinate to select the nearest input sample. Allowed values:

    • 0 (HalfDown, default): ceil(x - 0.5).
    • 1 (HalfUp): floor(x + 0.5).
    • 2 (Floor): floor(x).
    • 3 (Ceil): ceil(x).

Returns:

A new symbolic tensor with shape size and the same dtype as input.

Raises:

ValueError – If coordinate_transform_mode is not 0-3, round_mode is not 0-3, or size has a different rank than input.

Return type:

TensorValue

rms_norm()

max.graph.ops.rms_norm(input, weight, epsilon, weight_offset=0.0, multiply_before_cast=False)

source

Performs Root Mean Square layer normalization.

Computes output = input / rms(input) * weight where rms(x) = sqrt(mean(x^2) + epsilon).

When multiply_before_cast is False (Llama-style), the input is cast to the output dtype before multiplication by the weight. When True (Gemma-style), the multiplication is performed before the cast.

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to normalize.
  • weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The weight tensor whose shape must match the last dimension of input.
  • epsilon (float) – A small value added to the denominator for numerical stability.
  • weight_offset (float) – A value added to the weight before normalization. Typically 1 for Gemma-like normalization and 0 otherwise.
  • multiply_before_cast (bool) – Whether to multiply before casting to the output dtype.

Returns:

A normalized tensor with the same shape and dtype as input.

Raises:

ValueError – If weight shape doesn’t match the last dimension of input.

Return type:

TensorValue
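
A minimal sketch; the weight’s shape matches the input’s last dimension:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def rms_norm_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 8), device=DeviceRef.CPU())
    weight_type = TensorType(dtype=DType.float32, shape=(8,), device=DeviceRef.CPU())

    with Graph("rms_norm_graph", input_types=(input_type, weight_type)) as graph:
        x, w = graph.inputs
        out = ops.rms_norm(x, w, epsilon=1e-6)
        graph.output(out)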

round()

max.graph.ops.round(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

rsqrt()

max.graph.ops.rsqrt(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

scatter()

max.graph.ops.scatter(input, updates, indices, axis=-1)

source

Creates a new symbolic tensor where the updates are written to input according to indices.

Parameters:

Returns:

A new symbolic tensor representing the result of the scatter operation.

Raises:

ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

scatter_add()

max.graph.ops.scatter_add(input, updates, indices, axis=-1)

source

Creates a new symbolic tensor by accumulating updates into input at indices.

Produces an output tensor by scattering elements from updates into input according to indices, summing values at duplicate indices. For a 2-D input with axis=0 the update rule is:

output[indices[i][j]][j] += updates[i][j]

and with axis=1:

output[i][indices[i][j]] += updates[i][j]
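
A small worked example of the axis=0 rule (the values here are illustrative):

# input   = [[0, 0], [0, 0]]   (shape (2, 2))
# updates = [[1, 2], [3, 4]]
# indices = [[0, 0], [1, 0]]
#
# (i=0, j=1) and (i=1, j=1) both target output[0][1], so their
# updates accumulate: output = [[1, 2 + 4], [3, 0]] = [[1, 6], [3, 0]]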

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Raises:

ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

scatter_max()

max.graph.ops.scatter_max(input, updates, indices, axis=-1)

source

Creates a new symbolic tensor by scattering the maximum of updates into input.

Produces an output tensor by scattering elements from updates into input according to indices, keeping the maximum at duplicate indices. For a 2-D input with axis=0 the update rule is:

output[indices[i][j]][j] = max(output[indices[i][j]][j], updates[i][j])

and with axis=1:

output[i][indices[i][j]] = max(output[i][indices[i][j]], updates[i][j])

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Raises:

ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

scatter_min()

max.graph.ops.scatter_min(input, updates, indices, axis=-1)

source

Creates a new symbolic tensor by scattering the minimum of updates into input.

Produces an output tensor by scattering elements from updates into input according to indices, keeping the minimum at duplicate indices. For a 2-D input with axis=0 the update rule is:

output[indices[i][j]][j] = min(output[indices[i][j]][j], updates[i][j])

and with axis=1:

output[i][indices[i][j]] = min(output[i][indices[i][j]], updates[i][j])

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Raises:

ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

scatter_mul()

max.graph.ops.scatter_mul(input, updates, indices, axis=-1)

source

Creates a new symbolic tensor by scattering the product of updates into input.

Produces an output tensor by scattering elements from updates into input according to indices, multiplying values at duplicate indices. For a 2-D input with axis=0 the update rule is:

output[indices[i][j]][j] *= updates[i][j]

and with axis=1:

output[i][indices[i][j]] *= updates[i][j]

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Raises:

ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue

scatter_nd()

max.graph.ops.scatter_nd(input, updates, indices)

source

Creates a new symbolic tensor where the updates are scattered into input at specified indices.

Parameters:

Returns:

A new symbolic tensor representing the result of the scatter_nd operation.

Return type:

TensorValue

scatter_nd_add()

max.graph.ops.scatter_nd_add(input, updates, indices)

source

Creates a new symbolic tensor by accumulating updates into input at N-D indices.

Produces an output tensor by scattering slices from updates into a copy of input according to N-dimensional index vectors, summing values at duplicate index positions. Each index vector is the last dimension of indices and selects a slice (or scalar) in input.

Example for input.shape = [4, 2], indices.shape = [3, 1] (1-D partial indexing, writes whole rows):

output[indices[i, 0], :] += updates[i, :]

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Return type:

TensorValue

scatter_nd_max()

max.graph.ops.scatter_nd_max(input, updates, indices)

source

Creates a new symbolic tensor by scattering the maximum of updates into input at N-D indices.

Produces an output tensor by scattering slices from updates into a copy of input according to N-dimensional index vectors, keeping the maximum at duplicate index positions. Each index vector is the last dimension of indices and selects a slice (or scalar) in input.

Example for input.shape = [4, 2], indices.shape = [3, 1] (1-D partial indexing, writes whole rows):

output[indices[i, 0], :] = max(output[indices[i, 0], :], updates[i, :])

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Return type:

TensorValue

scatter_nd_min()

max.graph.ops.scatter_nd_min(input, updates, indices)

source

Creates a new symbolic tensor by scattering the minimum of updates into input at N-D indices.

Produces an output tensor by scattering slices from updates into a copy of input according to N-dimensional index vectors, keeping the minimum at duplicate index positions. Each index vector is the last dimension of indices and selects a slice (or scalar) in input.

Example for input.shape = [4, 2], indices.shape = [3, 1] (1-D partial indexing, writes whole rows):

output[indices[i, 0], :] = min(output[indices[i, 0], :], updates[i, :])

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Return type:

TensorValue

scatter_nd_mul()

max.graph.ops.scatter_nd_mul(input, updates, indices)

source

Creates a new symbolic tensor by scattering the product of updates into input at N-D indices.

Produces an output tensor by scattering slices from updates into a copy of input according to N-dimensional index vectors, multiplying values at duplicate index positions. Each index vector is the last dimension of indices and selects a slice (or scalar) in input.

Example for input.shape = [4, 2], indices.shape = [3, 1] (1-D partial indexing, writes whole rows):

output[indices[i, 0], :] *= updates[i, :]

Parameters:

Returns:

A new symbolic tensor with the same shape and dtype as input.

Return type:

TensorValue

shape_to_tensor()

max.graph.ops.shape_to_tensor(shape)

source

Converts a shape to a tensor.

This is useful for using a shape attribute in an op that expects a tensor value.

Parameters:

shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – the shape attribute of a tensor value.

Returns:

The TensorValue containing the same value as shape.

Return type:

TensorValue

Example:

>>> x = ops.constant(np.zeros((1,)), DType.int64, device=DeviceRef.CPU())
>>> ops.stack([
...     x,
...     ops.shape_to_tensor(x.shape),
... ])
TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])

shard_and_stack()

max.graph.ops.shard_and_stack(inputs, devices, axis=0)

source

Shards a list of input tensors along a specified axis, producing multiple outputs.

This operation takes multiple input tensors, splits each along the specified axis into len(devices) chunks, and returns one output tensor per device. Each output contains the chunks at the corresponding index stacked from all inputs along a new dimension 0.

This is useful for distributing model weights across multiple devices in tensor parallel configurations.

For example, with 2 inputs A and B, axis=0, and 2 devices:

  • Input A shape [10, D], Input B shape [10, D]
  • Output 0: stack([A[0:5], B[0:5]]) -> shape [2, 5, D] on devices[0]
  • Output 1: stack([A[5:10], B[5:10]]) -> shape [2, 5, D] on devices[1]

With axis=1 and 2 devices:

  • Input A shape [D, 10], Input B shape [D, 10]
  • Output 0: stack([A[:, 0:5], B[:, 0:5]]) -> shape [2, D, 5] on devices[0]
  • Output 1: stack([A[:, 5:10], B[:, 5:10]]) -> shape [2, D, 5] on devices[1]

Parameters:

  • inputs (Sequence[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – A list of symbolic tensors to shard. All tensors must have the same shape, dtype, and device.
  • devices (Sequence[Device | DeviceRef]) – Target devices for each output tensor. The number of devices determines the number of splits. Each output tensor will be placed on the corresponding device. This enables direct host-to-device transfer without intermediate CPU storage.
  • axis (int) – The axis along which to split each input tensor. Defaults to 0. Supports negative indexing (for example, -1 for last axis).

Returns:

A list of len(devices) tensors, each with shape [num_inputs, D0, …, D_axis // len(devices), …, D_{n-1}], where the input shape is [D0, …, D_axis, …, D_{n-1}]. Output i contains the stacked chunks at position i from all input tensors, placed on devices[i].

Raises:

ValueError – If inputs list is empty, if devices list is empty, if input tensors don’t have matching shapes, if the dimension size at the axis is not evenly divisible by len(devices), or if axis is out of bounds.

Return type:

list[TensorValue]

sigmoid()

max.graph.ops.sigmoid(x)

source

Computes the elementwise sigmoid activation of a symbolic tensor.

Creates a new op node to compute the elementwise sigmoid of a symbolic tensor and adds it to the graph, returning the symbolic result. Sigmoid is defined as sigmoid(x) = 1 / (1 + exp(-x)), mapping all input values to the range (0, 1).

The sigmoid function is commonly used for binary classification tasks and as an activation function in neural networks, particularly in output layers for probability prediction.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply sigmoid activation
result = F.sigmoid(x)
print(result)
# Output: [[0.119, 0.269, 0.5], [0.731, 0.881, 0.953]]
# All values mapped to range (0, 1)

Parameters:

x (TensorValue) – The symbolic tensor to use as the input to the sigmoid computation.

Returns:

A new symbolic tensor value representing the output of the sigmoid computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

silu()

max.graph.ops.silu(x)

source

Computes the elementwise silu of a symbolic tensor.

Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result.

silu is defined as silu(x) = x * sigmoid(x).

Parameters:

x (TensorValue) – The symbolic tensor to use as the input to the silu computation.

Returns:

A new symbolic tensor value representing the output of the silu computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue
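
A minimal sketch, mirroring the graph-building pattern used for the other activations on this page:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def silu_graph():
    input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())

    with Graph("silu_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.silu(x)  # equivalent to x * ops.sigmoid(x)
        graph.output(out)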

sin()

max.graph.ops.sin(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

slice_tensor()

max.graph.ops.slice_tensor(x, indices)

source

Slices out a subtensor view of the input tensor based on indices.

The semantics of slice_tensor() follow NumPy slicing semantics with the following restrictions:

  • Slice indices must not index out of [-dim - 1, dim - 1] for negative step, or [-dim, dim] for positive step.

# Reverse a tensor.
slice_tensor(x, [slice(None, None, -1)])

# Unsqueeze the second last dimension of a tensor.
slice_tensor(x, [..., None, slice(None)])

Returns:

The sliced subtensor of x.

Parameters:

Return type:

TensorValue

softmax()

max.graph.ops.softmax(value, axis=-1)

source

Parameters:

Return type:

TensorValue
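
A minimal sketch; softmax normalizes along the given axis so that each slice sums to 1:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def softmax_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 5), device=DeviceRef.CPU())

    with Graph("softmax_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Each row of the (2, 5) output sums to 1.
        out = ops.softmax(x, axis=-1)
        graph.output(out)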

split()

max.graph.ops.split(x, split_sizes, axis=0)

source

Splits the input tensor into multiple tensors along a given dimension.

Parameters:

Returns:

A list of tensors with the same length as split_sizes, where each tensor has the same shape as the input except along the split dimension axis, where the size is given by the corresponding element in split_sizes.

Return type:

list[TensorValue]
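
A minimal sketch, assuming split_sizes must sum to the size of the split axis:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def split_graph():
    input_type = TensorType(dtype=DType.float32, shape=(5, 4), device=DeviceRef.CPU())

    with Graph("split_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        first, second = ops.split(x, split_sizes=[2, 3], axis=0)
        # first has shape (2, 4); second has shape (3, 4).
        graph.output(first, second)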

sqrt()

max.graph.ops.sqrt(x)

source

Computes the elementwise square root of a symbolic tensor.

Creates a new op node to compute the elementwise square root of a symbolic tensor and adds it to the graph, returning the symbolic result. Square root is commonly used in normalization operations, distance calculations, and implementing mathematical operations like standard deviation.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create tensor with positive values
x = Tensor.constant([1.0, 4.0, 9.0, 16.0])

# Compute square root
result = F.sqrt(x)
print(result)
# Output: [1.0, 2.0, 3.0, 4.0]

# Note: sqrt requires non-negative values
# For tensors with negative values, use abs first:
y = Tensor.constant([1.0, -4.0, 9.0, -16.0])
result2 = F.sqrt(F.abs(y))
print(result2)
# Output: [1.0, 2.0, 3.0, 4.0]

Parameters:

Returns:

A new symbolic tensor value representing the output of the sqrt computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

squeeze()

max.graph.ops.squeeze(x, axis)

source

Removes a size-1 dimension from a symbolic tensor.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to squeeze.
  • axis (int) – The dimension to remove from the input’s shape. If negative, this indexes from the end of the tensor. For example, squeeze(v, -1) squeezes the last dimension.

Returns:

A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.

Return type:

TensorValue
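
A minimal sketch (unsqueeze, documented below, is the inverse operation):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def squeeze_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 1, 3), device=DeviceRef.CPU())

    with Graph("squeeze_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.squeeze(x, axis=1)  # (2, 1, 3) -> (2, 3)
        graph.output(out)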

stack()

max.graph.ops.stack(values, axis=0)

source

Stacks a list of tensors along a new axis.

Parameters:

  • values (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension.
  • axis (int) – The axis along which to stack. If negative, indexes relative to the end of the tensor shape plus 1. For instance, stack(vs, -1) will create and stack along a new axis as the last dimension, and stack(vs, -2) will create and stack along a new dimension which is inserted immediately before the last dimension.

Returns:

A new symbolic tensor representing the result of the stack. It will have rank n+1 where n is the rank of each input tensor. Its size on each dimension other than axis will be the same as each input tensors’, with the new axis inserted. Along the new dimension it will have size len(values).

Return type:

TensorValue
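
A minimal sketch stacking two rank-2 inputs into a rank-3 result:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def stack_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3), device=DeviceRef.CPU())

    with Graph("stack_graph", input_types=(input_type, input_type)) as graph:
        a, b = graph.inputs
        # Two (2, 3) inputs stacked on a new axis 0 give shape (2, 2, 3).
        out = ops.stack([a, b], axis=0)
        graph.output(out)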

sub()

max.graph.ops.sub(lhs, rhs)

source

Parameters:

Return type:

TensorValue

sum()

max.graph.ops.sum(x, axis=-1)

source

Reduces a symbolic tensor using a sum operation.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor for the operation.
  • axis (int) – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

Returns:

A symbolic tensor representing the result of the sum operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension which will have size 1.

Return type:

TensorValue

tanh()

max.graph.ops.tanh(x)

source

Computes the elementwise tanh (hyperbolic tangent) of a symbolic tensor.

Creates a new op node to compute the elementwise tanh of a symbolic tensor and adds it to the graph, returning the symbolic result. Tanh is defined as tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)), mapping all input values to the range (-1, 1).

The tanh function is commonly used as an activation function in recurrent neural networks (RNNs) and as a hidden layer activation in feedforward networks. Unlike sigmoid which maps to (0, 1), tanh is zero-centered, which can help with gradient flow during training.

import max.experimental.functional as F
from max.experimental.tensor import Tensor

# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])

# Apply tanh activation
result = F.tanh(x)
print(result)
# Output: [[-0.964, -0.762, 0.0], [0.762, 0.964, 0.995]]
# All values mapped to range (-1, 1)

Parameters:

Returns:

A new symbolic tensor value representing the output of the tanh computation.

Raises:

Error – If the symbol doesn’t represent a tensor value.

Return type:

TensorValue

tile()

max.graph.ops.tile(x, repeats)

source

Returns a new tensor by tiling the input along each dimension.

The input is copied N_i times on the i-th dimension, where N_i = repeats[i]. The i-th dimension of the output shape is the i-th dimension of the input shape multiplied by N_i.

Parameters:

Returns:

A symbolic tensor whose i-th dimension size equals x.shape[i] * repeats[i].

Raises:

ValueError – If the length of repeats does not match the rank of x, or if any repeat value is not positive. Also raised for GPU inputs when strict_device_placement=DevicePlacementPolicy.Error.

Return type:

TensorValue
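
A minimal sketch; repeats must have one entry per input dimension:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def tile_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3), device=DeviceRef.CPU())

    with Graph("tile_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.tile(x, repeats=[2, 1])  # (2, 3) -> (4, 3)
        graph.output(out)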

top_k()

max.graph.ops.top_k(input, k, axis=-1)

source

Returns tensor with only top K values along given axis.

Parameters:

Returns:

A tuple (values, indices) containing the top K values and their indices along the given axis.

Return type:

tuple[TensorValue, TensorValue]
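
A minimal sketch; both outputs have size k along the chosen axis:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def top_k_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 5), device=DeviceRef.CPU())

    with Graph("top_k_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        values, indices = ops.top_k(x, k=2, axis=-1)  # both have shape (2, 2)
        graph.output(values, indices)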

transfer_to()

max.graph.ops.transfer_to(x, device)

source

Inserts a device transfer node into the compiled graph.

Moves x to device at execution time. This is a graph-level operation: it operates on symbolic TensorValue objects during graph tracing and is baked into the compiled graph as an mo.transfer MLIR op.

This is distinct from to(), which is a pre-compilation operation that moves stored weight tensors on the Python host before the graph is built. Use transfer_to when you need to route an activation tensor between devices inside forward() (for example, host-to-device input staging, device-to-host output retrieval, or cross-GPU tensor movement for multi-device models).

Host↔device transfers (CPU↔GPU) use the graph’s immutable root chain so they can be hoisted to model initialization by the optimizer. Device-to-device transfers (GPU↔GPU) join both per-device chains to prevent reordering that would deadlock multi-device collectives. If source and destination device are identical, this is a no-op.

Parameters:

Returns:

A new TensorValue on the specified device.

Return type:

TensorValue
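
A minimal sketch of cross-GPU routing inside a graph (assuming two GPUs are available and that DeviceRef.GPU accepts a device id):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def transfer_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4,), device=DeviceRef.GPU(id=0))

    with Graph("transfer_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Device-to-device move, baked into the graph as an mo.transfer op.
        y = ops.transfer_to(x, DeviceRef.GPU(id=1))
        graph.output(y)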

transpose()

max.graph.ops.transpose(x, axis_1, axis_2)

source

Transposes two axes of a symbolic tensor.

For more information, see transpose().

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to transpose.
  • axis_1 (int) – One of the two axes to transpose. If negative, this indexes from the end of the tensor. For example, transpose(v, -1, -2) transposes the last two axes.
  • axis_2 (int) – The other axis to transpose. May also be negative to index from the end of the tensor.

Returns:

A new symbolic tensor with the two specified axes transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.

Return type:

TensorValue

trunc()

max.graph.ops.trunc(x)

source

Parameters:

x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)

Return type:

TensorValue

unsqueeze()

max.graph.ops.unsqueeze(x, axis)

source

Inserts a size-1 dimension into a symbolic tensor.

Parameters:

  • x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to unsqueeze.
  • axis (int) – The index at which to insert a new dimension into the input’s shape. Elements at that index or higher are shifted back. If negative, it indexes relative to the rank of the tensor plus 1. For example, unsqueeze(v, -1) adds a new dimension at the end, and unsqueeze(v, -2) inserts the dimension immediately before the last dimension.

Returns:

A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result’s shape at the axis dimension is a static dimension of size 1.

Return type:

TensorValue

where()

max.graph.ops.where(condition, x, y)

source

Returns element-wise condition ? x : y for input tensors condition, x, and y.

Parameters:

Returns:

A new symbolic tensor holding values selected elementwise from x or y, based on the elements in condition.

Return type:

TensorValue
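
A minimal sketch with a boolean condition input:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def where_graph():
    cond_type = TensorType(dtype=DType.bool, shape=(3,), device=DeviceRef.CPU())
    value_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())

    with Graph("where_graph", input_types=(cond_type, value_type, value_type)) as graph:
        cond, x, y = graph.inputs
        out = ops.where(cond, x, y)  # picks x where cond is True, else y
        graph.output(out)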

while_loop()

max.graph.ops.while_loop(initial_values, predicate, body)

source

Execute a loop until the predicate evaluates to false.

Both the predicate and body functions must take as arguments the same number and types of values as specified in initial_values. The predicate function must return only a boolean scalar tensor of type bool. The body function must return a list of values matching the types of initial_values (or may return a single value directly if there is only one).

The following example demonstrates a basic while loop with a single argument:

from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x):
        return x < 10

    def body(x):
        return x + 1

    result = ops.while_loop(x, pred, body)
    print(result)

The following example shows a while loop with multiple arguments:

from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())
    y = ops.constant(5, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x, y):
        return ops.logical_and(x < 10, y < 15)

    def body(x, y):
        return [x + 1, y + 1]

    results = ops.while_loop((x, y), pred, body)
    print(results)

Parameters:

  • initial_values (Iterable[Value[Any]] | Value[Any]) – Initial values for loop arguments. Must be non-empty.
  • predicate (Callable[[...], TensorValue]) – Callable that takes loop arguments and returns a boolean scalar tensor of type bool determining loop continuation.
  • body (Callable[[...], Value[Any] | Iterable[Value[Any]]]) – Callable that takes loop arguments and returns updated values matching the types of initial_values.

Returns:

List of output values from the final loop iteration.

Return type:

list[TensorValue]