Python module
max.experimental.functional
Distributed functional API — PyTree-centric rule-based SPMD dispatch.
Explicit per-op SPMD dispatch via spmd_dispatch.
Creation and random ops are standalone (no tensor inputs).
Usage:
from max.experimental import functional as F
y = F.matmul(a, b)
z = F.add(x, y)
w = F.transfer_to(z, new_mapping)
All logic lives in submodules:
- utils.py — shared helpers
- collective_ops.py — collectives + transfer_to
- spmd_ops.py — spmd_dispatch engine + explicit op functions
- creation_ops.py — full/ones/zeros/uniform/gaussian
Custom ops (custom, inplace_custom) are defined here because
they combine graph ops with extension loading and don’t fit a submodule.
Any
class max.experimental.functional.Any(*args, **kwargs)
Bases: object
Special type indicating an unconstrained type.
- Any is compatible with every type.
- Any is assumed to have all methods.
- All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
Coroutine
class max.experimental.functional.Coroutine
Bases: Awaitable
close()
close()
Raise GeneratorExit inside coroutine.
send()
abstract send(value)
Send a value into the coroutine. Return next yielded value or raise StopIteration.
throw()
abstract throw(typ, val=None, tb=None)
Raise an exception in the coroutine. Return next yielded value or raise StopIteration.
DType
class max.experimental.functional.DType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
The tensor data type.
align
property align
Returns the alignment requirement of the data type in bytes.
The alignment specifies the memory boundary that values of this data type must be aligned to for optimal performance and correctness.
bfloat16
bfloat16 = 80
16-bit bfloat16 (Brain Float) format. 1 sign bit, 8 exponent bits, 7 mantissa bits.
bool
bool = 1
Boolean data type. Stores True or False values.
float16
float16 = 79
16-bit IEEE 754 half-precision floating-point. 1 sign bit, 5 exponent bits, 10 mantissa bits.
float32
float32 = 81
32-bit IEEE 754 single-precision floating-point. 1 sign bit, 8 exponent bits, 23 mantissa bits.
float4_e2m1fn
float4_e2m1fn = 64
4-bit floating-point with 2 exponent bits and 1 mantissa bit, finite values only.
float64
float64 = 82
64-bit IEEE 754 double-precision floating-point. 1 sign bit, 11 exponent bits, 52 mantissa bits.
float8_e4m3fn
float8_e4m3fn = 75
8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only.
float8_e4m3fnuz
float8_e4m3fnuz = 76
8-bit floating-point with 4 exponent bits and 3 mantissa bits, finite values only, no negative zero.
float8_e5m2
float8_e5m2 = 77
8-bit floating-point with 5 exponent bits and 2 mantissa bits.
float8_e5m2fnuz
float8_e5m2fnuz = 78
8-bit floating-point with 5 exponent bits and 2 mantissa bits, finite values only, no negative zero.
float8_e8m0fnu
float8_e8m0fnu = 73
8-bit floating-point with 8 exponent bits and 0 mantissa bits, finite values only.
from_numpy()
from_numpy()
Converts a NumPy dtype to the corresponding DType.
-
Parameters:
-
dtype (np.dtype) – The NumPy dtype to convert.
-
Returns:
-
The corresponding DType enum value.
-
Return type:
-
DType
-
Raises:
-
ValueError – If the input dtype is not supported.
int16
int16 = 137
16-bit signed integer, range -32,768 to 32,767.
int32
int32 = 139
32-bit signed integer, range -2,147,483,648 to 2,147,483,647.
int64
int64 = 141
64-bit signed integer, range -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807.
int8
int8 = 135
8-bit signed integer, range -128 to 127.
is_float()
is_float(self) → bool
Checks if the data type is a floating-point type.
is_float8()
is_float8(self) → bool
Checks if the data type is an 8-bit floating-point type.
is_half()
is_half(self) → bool
Checks if the data type is a half-precision floating-point type.
is_integral()
is_integral(self) → bool
Checks if the data type is an integer type.
is_signed_integral()
is_signed_integral(self) → bool
Checks if the data type is a signed integer type.
is_unsigned_integral()
is_unsigned_integral(self) → bool
Checks if the data type is an unsigned integer type.
size_in_bits
property size_in_bits
Returns the size of the data type in bits.
This indicates how many bits are required to store a single value of this data type in memory.
size_in_bytes
property size_in_bytes
Returns the size of the data type in bytes.
This indicates how many bytes are required to store a single value of this data type in memory.
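These properties and predicates compose with the NumPy conversions above; a small sketch (assuming standard NumPy interop as described in from_numpy() and to_numpy()):
import numpy as np
from max.dtype import DType
assert DType.float32.size_in_bits == 32
assert DType.float32.size_in_bytes == 4
assert DType.float16.is_half() and DType.float16.is_float()
assert DType.int8.is_signed_integral()
# Round-trip through NumPy
assert DType.from_numpy(np.dtype(np.int16)) == DType.int16
assert DType.int16.to_numpy() == np.dtype(np.int16)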
to_numpy()
to_numpy()
Converts this DType to the corresponding NumPy dtype.
-
Returns:
-
The corresponding NumPy dtype object.
-
Return type:
-
np.dtype
-
Raises:
-
ValueError – If the dtype is not supported.
-
Parameters:
-
self (DType)
uint16
uint16 = 136
16-bit unsigned integer, range 0 to 65,535.
uint32
uint32 = 138
32-bit unsigned integer, range 0 to 4,294,967,295.
uint64
uint64 = 140
64-bit unsigned integer, range 0 to 18,446,744,073,709,551,615.
uint8
uint8 = 134
8-bit unsigned integer, range 0 to 255.
Device
class max.experimental.functional.Device
Bases: object
Represents a compute device available for tensor operations.
This is the base class for CPU and Accelerator.
Do not instantiate this class directly; use CPU for host
devices or Accelerator for GPU devices.
from max import driver
cpu = driver.CPU()
gpu = driver.Accelerator()
api
property api
Returns the API used to program the device.
Possible values are:
- cpu for host devices.
- cuda for NVIDIA GPUs.
- hip for AMD GPUs.
from max import driver
device = driver.CPU()
device.api
architecture_name
property architecture_name
Returns the architecture name of the device.
Examples of possible values:
- gfx90a, gfx942 for AMD GPUs.
- sm_80, sm_86 for NVIDIA GPUs.
- CPU devices raise an exception.
from max import driver
device = driver.Accelerator()
device.architecture_name
can_access()
can_access(self, other: max.driver.Device) → bool
Checks if this device can directly access memory of another device.
from max import driver
gpu0 = driver.Accelerator(id=0)
gpu1 = driver.Accelerator(id=1)
if gpu0.can_access(gpu1):
print("GPU0 can directly access GPU1 memory.")cpu
cpu = <nanobind.nb_func object>
default_stream
property default_stream
Returns the default stream for this device.
The default stream is initialized when the device object is created.
-
Returns:
-
The default execution stream for this device.
-
Return type:
id
property id
Returns a zero-based device id.
For a CPU device this is always 0.
For GPU accelerators this is the id of the device relative to this host.
Along with the label, an id can uniquely identify a device,
e.g. gpu:0, gpu:1.
from max import driver
device = driver.Accelerator()
device_id = device.id
-
Returns:
-
The device ID.
-
Return type:
-
int
is_compatible
property is_compatible
Returns whether this device is compatible with MAX.
-
Returns:
-
True if the device is compatible with MAX, False otherwise.
-
Return type:
-
bool
is_host
property is_host
Whether this device is the CPU (host) device.
from max import driver
device = driver.CPU()
device.is_host
label
property label
Returns device label.
Possible values are:
- cpu for host devices.
- gpu for accelerators.
from max import driver
device = driver.CPU()
device.label
stats
property stats
Returns utilization data for the device.
from max import driver
device = driver.CPU()
stats = device.stats
-
Returns:
-
A dictionary containing device utilization statistics.
-
Return type:
synchronize()
synchronize(self) → None
Ensures all operations on this device complete before returning.
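For example (a minimal sketch; assumes an accelerator is available):
from max import driver
device = driver.Accelerator()
# ... enqueue work on the device ...
device.synchronize() # blocks until all enqueued operations finish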
-
Raises:
-
ValueError – If any enqueued operations had an internal error.
DeviceRef
class max.experimental.functional.DeviceRef(device_type, id=0)
Bases: object
A symbolic device representation.
DeviceRef type representation consists of a DeviceKind and an id. This is a direct representation of the device attribute in MLIR.
The following example demonstrates how to create and use device references:
from max.graph import DeviceRef
# Create a GPU device reference (default id=0)
gpu_device = DeviceRef.GPU()
print(gpu_device) # Outputs: gpu:0
# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device) # Outputs: cpu:1
-
Parameters:
-
- device_type (DeviceKind)
- id (int)
CPU()
static CPU(id=0)
Creates a CPU device reference.
GPU()
static GPU(id=0)
Creates a GPU device reference.
device_type
device_type: DeviceKind
from_device()
static from_device(device)
Converts a Device or DeviceRef to a DeviceRef.
from_mlir()
static from_mlir(attr)
Returns a device reference from an MLIR attribute.
-
Parameters:
-
attr (DeviceRefAttr)
-
Return type:
-
DeviceRef
id
id: int
is_cpu()
is_cpu()
Returns True if the device is a CPU device.
-
Return type:
-
bool
is_gpu()
is_gpu()
Returns True if the device is a GPU device.
-
Return type:
-
bool
to_device()
to_device()
Converts a device reference to a concrete driver Device.
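For example (a brief sketch; assumes the referenced device exists on this host):
from max.graph import DeviceRef
cpu_ref = DeviceRef.CPU()
cpu_device = cpu_ref.to_device() # concrete max.driver.Device for cpu:0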
-
Return type:
-
Device
to_mlir()
to_mlir()
Returns an MLIR attribute representing the device.
-
Return type:
-
DeviceRefAttr
Graph
class max.experimental.functional.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], kernel_library=None, module=None, strict_device_placement=DevicePlacementPolicy.Warn, **kwargs)
Bases: object
Represents a single MAX graph.
A Graph defines a model’s computation. You build a graph by
composing operations that describe how input tensors are transformed into
outputs. Unlike imperative code that executes operations, a Graph
captures the data flow between operations, which allows MAX to optimize and
parallelize execution at compile time. Operations run on the compiled object.
The following code examples show two different strategies for constructing graphs.
Use the context manager: Use Graph as a context manager to
define the active graph. Inside the with block, retrieve inputs from
inputs, call ops to build nodes, and set the graph output with
output(). Ops called inside the block find the active graph
automatically. Ops called outside the block fail because there is no active
graph.
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, Weight
W = Weight("W", DType.float32, [3, 2], DeviceRef.CPU())
b = Weight("b", DType.float32, [2], DeviceRef.CPU())
with Graph(
"linear_relu",
input_types=[TensorType(DType.float32, ["batch", 3], device=DeviceRef.CPU())],
) as graph:
x = graph.inputs[0].tensor
y = x @ W + b
graph.output(y)
Use the graph constructor: Pass a callable as the forward argument.
The graph automatically passes the input TensorValue to the
callable and records the return value as the graph output. Under the hood,
this still opens and closes a graph context.
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, Weight, ops
class Linear:
def __init__(self, in_dim: int, out_dim: int):
self.weight = Weight("W", DType.float32, [in_dim, out_dim], DeviceRef.CPU())
self.bias = Weight("b", DType.float32, [out_dim], DeviceRef.CPU())
def __call__(self, x: TensorValue) -> TensorValue:
return ops.matmul(x, self.weight) + self.bias
linear_layer = Linear(2, 2)
graph = Graph(
"linear",
linear_layer,
input_types=[TensorType(DType.float32, (2,), DeviceRef.CPU())],
)
These examples only use the max.graph package, but most models also
use Module and other building blocks from max.nn.
To learn more, see Build a model graph with Module.
-
Parameters:
-
- name (str) – A name for the graph.
- forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
- input_types (Iterable[Type[Any]]) – A sequence of Type instances that describe each graph input. These are typically TensorType instances. You can also include BufferType instances for mutable in-place inputs.
- path (Path | None) – The path to a saved graph (internal use only).
- custom_extensions (Iterable[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
- kernel_library (KernelLibrary | None) – Optional pre-built kernel library to use. Defaults to None (a new library is created from custom_extensions if needed).
- module (mlir.Module | None) – Optional existing MLIR module (internal use only). Defaults to None.
- strict_device_placement (DevicePlacementPolicy)
add_subgraph()
add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[], devices=[])
Creates a reusable subgraph for the current graph.
A subgraph is the graph equivalent of a function: you define a block of ops once and call it from the parent graph as many times as you need. Use a subgraph when a block of computation repeats, for example, a transformer layer that appears 62 times in a model. Wrapping it in a subgraph lets the compiler process the definition once instead of once per repetition, which can cut compile time by 50x or more.
Trade-offs to keep in mind:
- Memory: Allocations inside a subgraph can’t be shared with allocations outside it, so peak memory may be slightly higher.
- Kernel fusion: The compiler can’t fuse ops across the subgraph boundary, which may reduce throughput marginally.
For models with a Module, prefer
build_subgraph(), which handles weight prefixes
automatically.
Examples:
Define a subgraph that adds 1 to every element, then call it on a graph input:
from max.dtype import DType
from max.graph import Graph, ops
from max.graph.type import TensorType, DeviceRef
input_type = TensorType(DType.float32, [10], DeviceRef.CPU())
with Graph("main", input_types=[input_type]) as graph:
with graph.add_subgraph(
"add_one", input_types=[input_type]
) as sub:
x = sub.inputs[0].tensor
one = ops.constant(1, DType.float32, device=DeviceRef.CPU())
sub.output(ops.elementwise.add(x, one))
result = ops.call(sub, graph.inputs[0])
graph.output(*result)
-
Parameters:
-
- name (str) – The name identifier for the subgraph. Must be unique within the parent graph. Use the same name when calling the subgraph with call().
- forward (Callable[[...], None | Value[Any] | Iterable[Value[Any]]] | None) – An optional callable that defines the subgraph’s forward pass. When provided, the subgraph is built immediately.
- input_types (Iterable[Type[Any]]) – The tensor types for the subgraph’s inputs. A chain type is added automatically for operation sequencing.
- path (Path | None) – An optional path to a saved subgraph definition to load from disk.
- custom_extensions (Iterable[Path]) – Paths to custom op libraries (.mojopkg files or Mojo source directories) to load for the subgraph.
- devices (Iterable[DeviceRef]) – Devices this subgraph targets.
-
Returns:
-
A Graph instance registered as a subgraph of this graph.
-
Return type:
-
Graph
add_weight()
add_weight(weight, force_initial_weight_on_host=True)
Adds a weight to the graph.
If the weight is in the graph already, return the existing value.
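A minimal sketch of explicit weight registration inside a graph context (mirroring the Weight usage in the class-level example above; the graph and weight names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, Weight
w = Weight("w", DType.float32, [2, 2], DeviceRef.CPU())
input_type = TensorType(DType.float32, [2, 2], device=DeviceRef.CPU())
with Graph("with_weight", input_types=[input_type]) as graph:
    w_value = graph.add_weight(w) # TensorValue backed by the weight
    graph.output(graph.inputs[0].tensor + w_value)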
-
Parameters:
-
Returns:
-
A TensorValue that contains this weight.
-
Raises:
-
ValueError – If a weight with the same name already exists in the graph.
-
Return type:
-
TensorValue
always_ready_chain
property always_ready_chain: _ChainValue
A graph-global, immutable chain that is always ready.
Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (for example, host→device transfers for staging).
current
current
debug
debug = <max.graph.graph.GraphDebugConfig object>
device_chains
device_chains: _DeviceChainMap
empty_module()
static empty_module()
Create a new module to hold one or more graphs.
-
Return type:
-
Module
inputs
The input values of the graph.
-
Returns:
-
A sequence of Value objects corresponding to the input_types passed at construction, excluding internal chain values.
kernel_libraries_paths
Returns the list of extra kernel libraries paths for the custom ops.
output()
output(*outputs)
Sets the output values of the graph and finalizes construction.
Call this once after building all ops. The graph can’t be executed
until output() has been called. Subsequent calls to
output_types read back the types of the values passed here.
Examples:
Build a graph that doubles its input and set the output:
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops
from max.graph.type import TensorType
input_type = TensorType(DType.float32, [4], DeviceRef.CPU())
with Graph("double", input_types=[input_type]) as graph:
x = graph.inputs[0].tensor
two = ops.constant(2.0, DType.float32, device=DeviceRef.CPU())
graph.output(ops.elementwise.mul(x, two))
-
Parameters:
-
outputs (Value[Any] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The output values of the graph. Each value may be a Value or any TensorValueLike.
-
Return type:
-
None
output_types
The types of the graph output values.
MLIRThreadPoolExecutor
class max.experimental.functional.MLIRThreadPoolExecutor(max_workers=None, thread_name_prefix='', initializer=None, initargs=())
Bases: ThreadPoolExecutor
Initializes a new ThreadPoolExecutor instance.
-
Parameters:
-
- max_workers – The maximum number of threads that can be used to execute the given calls.
- thread_name_prefix – An optional name prefix to give our threads.
- initializer – A callable used to initialize worker threads.
- initargs – A tuple of arguments to pass to the initializer.
submit()
submit(fn, /, *args, **kwargs)
Submits a callable to be executed with the given arguments.
Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the execution of the callable.
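Usage follows the standard concurrent.futures pattern; for example:
executor = MLIRThreadPoolExecutor(max_workers=2)
future = executor.submit(pow, 2, 10)
print(future.result()) # 1024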
-
Returns:
-
A Future representing the given call.
-
Parameters:
-
- fn (Callable[[~P], R])
- args (~P)
- kwargs (~P)
-
Return type:
-
Future[R]
Mapping
class max.experimental.functional.Mapping
Bases: Collection
A Mapping is a generic container for associating key/value pairs.
This class provides concrete generic implementations of all methods except for __getitem__, __iter__, and __len__.
get()
get(k) → D[k] if k in D, else d. d defaults to None.
items()
items() → a set-like object providing a view on D's items
keys()
keys() → a set-like object providing a view on D's keys
values()
values() → an object providing a view on D's values
Path
class max.experimental.functional.Path(*args, **kwargs)
Bases: PathBase, PurePath
PurePath subclass that can make system calls.
Path represents a filesystem path but unlike PurePath, also offers methods to do system calls on path objects. Depending on your system, instantiating a Path will return either a PosixPath or a WindowsPath object. You can also instantiate a PosixPath or WindowsPath directly, but cannot instantiate a WindowsPath on a POSIX system or vice versa.
absolute()
absolute()
Return an absolute version of this path. No normalization or symlink resolution is performed.
Use resolve() to resolve symlinks and remove ‘..’ segments.
as_uri()
as_uri()
Return the path as a URI.
chmod()
chmod(mode, *, follow_symlinks=True)
Change the permissions of the path, like os.chmod().
expanduser()
expanduser()
Return a new path with expanded ~ and ~user constructs (as returned by os.path.expanduser)
from_uri()
classmethod from_uri(uri)
Return a new path from the given ‘file’ URI.
glob()
glob(pattern, *, case_sensitive=None, recurse_symlinks=False)
Iterate over this subtree and yield all existing files (of any kind, including directories) matching the given relative pattern.
group()
group(*, follow_symlinks=True)
Return the group name of the file gid.
hardlink_to()
hardlink_to(target)
Make this path a hard link pointing to the same file as target.
Note the order of arguments (self, target) is the reverse of os.link’s.
is_junction()
is_junction()
Whether this path is a junction.
is_mount()
is_mount()
Check if this path is a mount point
iterdir()
iterdir()
Yield path objects of the directory contents.
The children are yielded in arbitrary order, and the special entries ‘.’ and ‘..’ are not included.
mkdir()
mkdir(mode=511, parents=False, exist_ok=False)
Create a new directory at this given path.
open()
open(mode='r', buffering=-1, encoding=None, errors=None, newline=None)
Open the file pointed to by this path and return a file object, as the built-in open() function does.
owner()
owner(*, follow_symlinks=True)
Return the login name of the file owner.
read_text()
read_text(encoding=None, errors=None, newline=None)
Open the file in text mode, read it, and close the file.
readlink()
readlink()
Return the path to which the symbolic link points.
rename()
rename(target)
Rename this path to the target path.
The target path may be absolute or relative. Relative paths are interpreted relative to the current working directory, not the directory of the Path object.
Returns the new Path instance pointing to the target path.
replace()
replace(target)
Rename this path to the target path, overwriting if that path exists.
The target path may be absolute or relative. Relative paths are interpreted relative to the current working directory, not the directory of the Path object.
Returns the new Path instance pointing to the target path.
resolve()
resolve(strict=False)
Make the path absolute, resolving all symlinks on the way and also normalizing it.
rglob()
rglob(pattern, *, case_sensitive=None, recurse_symlinks=False)
Recursively yield all existing files (of any kind, including directories) matching the given relative pattern, anywhere in this subtree.
rmdir()
rmdir()
Remove this directory. The directory must be empty.
stat()
stat(*, follow_symlinks=True)
Return the result of the stat() system call on this path, like os.stat() does.
symlink_to()
symlink_to(target, target_is_directory=False)
Make this path a symlink pointing to the target path. Note the order of arguments (link, target) is the reverse of os.symlink.
touch()
touch(mode=438, exist_ok=True)
Create this file with the given access mode, if it doesn’t exist.
unlink()
unlink(missing_ok=False)
Remove this file or link. If the path is a directory, use rmdir() instead.
walk()
walk(top_down=True, on_error=None, follow_symlinks=False)
Walk the directory tree from this directory, similar to os.walk().
write_text()
write_text(data, encoding=None, errors=None, newline=None)
Open the file in text mode, write to it, and close the file.
Sequence
class max.experimental.functional.Sequence
Bases: Reversible, Collection
All the operations on a read-only sequence.
Concrete subclasses must override __new__ or __init__, __getitem__, and __len__.
count()
count(value) → integer -- return number of occurrences of value
index()
index(value) → integer -- return first index of value.
Raises ValueError if the value is not present.
Supporting start and stop arguments is optional, but recommended.
Tensor
class max.experimental.functional.Tensor(data=None, *, dtype=None, device=None, storage=None, state=None)
Bases: DLPackArray, HasTensorValue
A multi-dimensional array with eager execution and automatic compilation.
The Tensor class provides a high-level interface for numerical computations with automatic compilation and optimization via the MAX runtime. Operations on tensors execute eagerly while benefiting from lazy evaluation and graph-based optimizations behind the scenes.
Key Features:
- Eager execution: Operations execute immediately with automatic compilation.
- Lazy evaluation: Computation may be deferred until results are needed.
- High performance: Uses the Mojo compiler and optimized kernels.
- Familiar API: Supports common array operations and indexing.
- Device flexibility: Works seamlessly across CPU and accelerators.
Creating Tensors:
Create tensors using the constructor, factory methods like ones(),
zeros(), arange(), or from other array libraries via
from_dlpack().
from max.experimental import tensor
# Create tensors from data (like torch.tensor())
x = tensor.Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = tensor.Tensor.zeros((2, 3))
# Perform operations
result = x + y # Eager execution with automatic compilation
# Access values
print(result.shape) # (2, 3)
print(result.dtype) # DType.float32
Implementation Notes:
Tensors use lazy evaluation internally - they don’t always hold concrete
data in memory. A tensor may be “unrealized” (not yet computed) until its
value is actually needed (e.g., when converting to other formats or calling
item()). This allows the runtime to optimize sequences of
operations efficiently.
Operations on tensors build a computation graph behind the scenes, which is compiled and executed when needed. All illegal operations fail immediately with clear error messages, ensuring a smooth development experience.
Interoperability:
Tensors support the DLPack protocol for zero-copy data exchange with NumPy,
PyTorch, JAX, and other array libraries. Use from_dlpack() to import
arrays and standard DLPack conversion for export.
Creates a tensor from data or from internal storage.
When called with data, constructs a tensor from a scalar, nested
list, or DLPack-compatible array (matching PyTorch’s torch.tensor()
semantics). When called without data, requires exactly one of
storage or state for internal construction.
For DLPack-compatible arrays (NumPy, PyTorch, etc.) the array’s own
dtype is preserved by default; no silent precision conversion
happens. For Python scalars and nested lists, dtype defaults to
DType.float32 on CPU and DType.bfloat16 on accelerators.
from max.experimental.tensor import Tensor
from max.dtype import DType
# Create from scalar
x = Tensor(42, dtype=DType.int32)
# Create from nested list
y = Tensor([[1.0, 2.0], [3.0, 4.0]])
# Create from NumPy array; dtype is inherited from the array
import numpy as np
z = Tensor(np.array([1, 2, 3], dtype=np.int16)) # stays int16
-
Parameters:
-
- data (DLPackArray | NestedArray | Number | None) – The value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array (NumPy, PyTorch, etc.). If not provided, exactly one of storage or state must be supplied.
- dtype (DType | None) – The data type for the tensor elements. For DLPack arrays this defaults to the array’s own dtype; passing a conflicting value raises ValueError. For Python scalars/lists this defaults to DType.float32 on CPU and DType.bfloat16 on accelerators.
- device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU. Only valid when data is provided.
- storage (driver.Buffer | None) – Internal backing buffer for a realized tensor. Mutually exclusive with data.
- state (RealizationState | None) – Internal realization state for an unrealized tensor. Mutually exclusive with data.
Return type:
T
property T: Tensor
Returns a tensor with the last two dimensions transposed.
This is equivalent to calling transpose(-1, -2), which swaps
the last two dimensions of the tensor. For a 2D matrix, this produces
the standard matrix transpose.
from max.experimental.tensor import Tensor
from max.dtype import DType
# Create a 2x3 matrix
x = Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]
# Use .T property (equivalent to transpose(-1, -2))
y = x.T
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)
-
Returns:
-
A tensor with the last two dimensions transposed.
arange()
classmethod arange(start=0, stop=None, step=1, out_dim=None, *, dtype=None, device=None)
Creates a tensor with evenly spaced values within a given interval.
Returns a new 1D tensor containing a sequence of values starting from
start (inclusive) and ending before stop (exclusive), with values
spaced by step. This is similar to Python’s built-in range()
function and NumPy’s arange().
from max.experimental import tensor
from max.dtype import DType
# Create a range from 0 to 10 (exclusive)
x = tensor.Tensor.arange(10)
# Result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Create a range from 5 to 15 with step 2
y = tensor.Tensor.arange(5, 15, 2)
# Result: [5, 7, 9, 11, 13]
# Use a specific dtype
z = tensor.Tensor.arange(0, 5, dtype=DType.float32)
# Result: [0.0, 1.0, 2.0, 3.0, 4.0]
# Create a range with float step (like numpy/pytorch)
w = tensor.Tensor.arange(0.0, 1.0, 0.2)
# Result: [0.0, 0.2, 0.4, 0.6, 0.8]
# Create a descending range with negative step
v = tensor.Tensor.arange(5, 0, -1, dtype=DType.float32)
# Result: [5.0, 4.0, 3.0, 2.0, 1.0]
-
Parameters:
-
- start (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The starting value of the sequence. If stop is not provided, this becomes the stop value and start defaults to 0.
- stop (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – The end value of the sequence (exclusive). If not specified, the sequence ends at start and begins at 0.
- step (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The spacing between values in the sequence. Must be non-zero.
- out_dim (int | str | Dim | integer[Any] | TypedAttr | None) – The expected output dimension. Required when start, stop, or step are tensors rather than scalar literals. If not specified, the output dimension is computed from the scalar values of the inputs.
- dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
- device (Device | DeviceMapping | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
-
Returns:
-
A 1D tensor containing the evenly spaced values.
-
Return type:
argmax()
argmax(axis=-1)
Finds the indices of the maximum values along an axis.
Returns a tensor containing the indices of the maximum values along the specified axis. This is useful for finding the position of the largest element, such as determining predicted classes in classification.
from max.experimental import tensor
# Create a 2x4 tensor
x = tensor.Tensor(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)
# Find argmax along last axis (within each row)
indices = x.argmax(axis=-1)
# Result: [1, 2] (index 1 in first row, index 2 in second row)
# Find argmax over all elements
index = x.argmax(axis=None)
# Result: 6 (flattened index of maximum value 4.2)
broadcast_to()
broadcast_to(shape)
Broadcasts the tensor to the specified shape.
Returns a tensor broadcast to the target shape, following NumPy broadcasting semantics. Dimensions of size 1 in the input can be expanded to match larger dimensions in the target shape.
This is equivalent to PyTorch’s torch.broadcast_to() and
torch.Tensor.expand().
from max.experimental import tensor
# Create a tensor with shape (3, 1)
x = tensor.Tensor.ones([3, 1])
# Broadcast to (3, 4) - expands the second dimension
y = x.broadcast_to([3, 4])
print(y.shape) # (3, 4)
# Add a new leading dimension
w = x.broadcast_to([2, 3, 1])
print(w.shape) # (2, 3, 1)
buffers
The underlying per-shard driver buffers.
Returns one buffer for non-distributed tensors, N buffers for a distributed tensor with N shards.
-
Raises:
-
TypeError – If the tensor is unrealized (lazy/symbolic).
cast()
cast(dtype)
Casts the tensor to a different data type.
Returns a new tensor with the same values but a different data type.
This is useful for type conversions between different numeric types,
such as converting float32 to int32 for indexing operations or
float32 to bfloat16 for memory-efficient computations.
from max.experimental import tensor
from max.dtype import DType
# Create a float32 tensor
x = tensor.Tensor([1.7, 2.3, 3.9], dtype=DType.float32)
print(x.dtype) # DType.float32
# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(y.dtype) # DType.int32
# Values: [1, 2, 3]
clip()
clip(*, min=None, max=None)
Clips values outside a range to the boundaries of the range.
from max.experimental import tensor
# Create a 2x4 tensor
x = tensor.Tensor(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)
# Find max along last axis (within each row)
clipped_above = x.clip(max=3.)
# Result: [[1.2, 3., 2.1, 0.8], [2.3, 1.9, 3, 3.]]
clipped_below = x.clip(min=3.)
# Result: [[3., 3.5, 3., 3.], [3., 3., 4.2, 3.1]]
-
Parameters:
-
- min (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – The minimum value of the range. If not specified, do not clip values for being too small.
- max (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – The maximum value of the range. If not specified, do not clip values for being too large.
-
Returns:
-
A tensor containing the values clipped to the specified range.
-
Return type:
constant()
classmethod constant(value, *, dtype=None, device=None)
Creates a tensor from a scalar, array, or nested list.
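For example (a brief sketch):
from max.experimental.tensor import Tensor
from max.dtype import DType
# From a scalar
x = Tensor.constant(3.14, dtype=DType.float32)
# From a nested list
y = Tensor.constant([[1, 2], [3, 4]], dtype=DType.int32)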
-
Parameters:
-
- value (DLPackArray | Sequence[float | number[Any] | Sequence[Number | NestedArray]] | float | number[Any]) – The constant value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array.
- dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
- device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
-
Returns:
-
A new tensor containing the constant value(s).
-
Return type:
device
property device: Device
Gets the device where the tensor is stored.
Returns the device (CPU or accelerator) where the tensor’s data is located. Raises for distributed tensors that span multiple devices.
-
Returns:
-
The device where the tensor is stored.
-
Return type:
driver_tensor
property driver_tensor: Buffer
A pointer to the underlying memory.
Raises if the tensor is unrealized or sharded.
dtype
property dtype: DType
Gets the data type of the tensor elements.
-
Returns:
-
The data type of the tensor elements.
-
Return type:
from_dlpack()
classmethod from_dlpack(array)
Creates a tensor from a DLPack array.
Constructs a tensor by importing data from any object that supports the DLPack protocol (such as NumPy arrays and PyTorch tensors). This enables zero-copy interoperability with other array libraries.
import numpy as np
from max.experimental import tensor
# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# Convert to MAX tensor via DLPack
x = tensor.Tensor.from_dlpack(np_array)
-
Parameters:
-
array (DLPackArray) – Any object supporting the DLPack protocol, such as NumPy arrays, PyTorch tensors, or JAX arrays.
-
Returns:
-
A new tensor containing the data from the DLPack array.
-
Return type:
from_graph_value()
classmethod from_graph_value(value)
Creates a tensor from a graph value.
Constructs a tensor from an existing graph value, which can be either
a TensorValue or BufferValue. This
is used for converting graph level values into tensor objects.
The new tensor is registered as unrealized, backed by the current
realization context.
from_shard_values()
classmethod from_shard_values(shard_values, mapping=None)
Creates a tensor from one or more per-shard graph values.
For a single shard value with no mapping, behaves like
from_graph_value(). For multiple shard values, a
DeviceMapping is required
and the result is a distributed tensor.
-
Parameters:
-
- shard_values (Sequence[BufferValue | TensorValue]) – Per-device graph values (TensorValue or BufferValue). One per device in the mesh.
- mapping (DeviceMapping | None) – Device mapping describing how shards map to mesh
devices and their placements. Required when
len(shard_values) > 1.
-
Returns:
-
A tensor backed by the provided shard values.
-
Raises:
-
- ValueError – If multiple shard values are given without a mapping.
- TypeError – If any shard value is not a graph value.
-
Return type:
full()
classmethod full(shape, value, *, dtype=None, device=None)
Creates a tensor filled with a specified value.
Returns a new tensor with the given shape where all elements are initialized to the specified value. This is useful for creating tensors with uniform values other than zero or one.
from max.experimental import tensor
from max.dtype import DType
# Create a 3x3 tensor filled with 7
x = tensor.Tensor.full((3, 3), value=7, dtype=DType.int32)
# Create a 2x4 tensor filled with pi
y = tensor.Tensor.full((2, 4), value=3.14159)
-
Parameters:
-
- shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
- value (float | number[Any]) – The scalar value to fill the tensor with.
- dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
- device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU. Pass a DeviceMapping to create a distributed tensor.
-
Returns:
-
A new tensor with the specified shape filled with the given value.
-
Return type:
full_like()
classmethod full_like(input, value)
Creates a tensor filled with a value, matching a given tensor’s properties.
Returns a new tensor filled with the specified value that matches the
shape, data type, and device of the input tensor. This behaves like
NumPy’s full_like and PyTorch’s full_like.
from max.experimental import tensor
# Create a reference tensor
ref = tensor.Tensor.ones([2, 3])
# Create tensor filled with 5.0 matching the reference tensor
x = tensor.Tensor.full_like(ref, value=5.0)
-
Parameters:
-
- input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
- value (float | number[Any]) – The scalar value to fill the tensor with.
-
Returns:
-
A new tensor filled with the specified value, matching the properties of the input.
-
Return type:
graph_values
property graph_values: tuple[BufferValue | TensorValue, ...]
Returns per-shard graph values directly from the realization state.
For unrealized tensors (both distributed and single-device), returns
the underlying GraphValue objects (TensorValue | BufferValue) without
wrapping in intermediate Tensor objects.
For realized tensors, creates graph values via __tensorvalue__()
on each shard.
This is the primary way to access graph-level shard values for custom dispatch rules and SPMD loops.
is_distributed
property is_distributed: bool
Returns True if this tensor spans multiple devices.
item()
item()
Gets the scalar value from a single-element tensor.
Extracts and returns the scalar value from a tensor containing exactly one element. The tensor is realized if needed and transferred to CPU before extracting the value.
For replicated distributed tensors, the value is read from the first shard (all shards hold identical data).
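For example (a brief sketch):
from max.experimental import tensor
from max.dtype import DType
x = tensor.Tensor([42.0], dtype=DType.float32)
value = x.item() # 42.0, returned as a Python float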
-
Returns:
-
The scalar value from the tensor. The return type matches the tensor’s dtype (e.g., float for float32, int for int32).
-
Raises:
-
- TypeError – If the tensor contains more than one element.
- ValueError – If the tensor is distributed and not fully replicated.
-
Return type:
local_shards
Returns per-device shard views as independent unsharded Tensors.
Each returned Tensor is a lightweight, standalone, unsharded Tensor
backed by a single shard’s storage or graph value. They can be
passed directly to F.* ops or used as Module parameters.
For realized sharded tensors, each shard wraps one driver.Buffer.
For unrealized sharded tensors, each shard wraps one GraphValue
from the shared RealizationState.
For unsharded tensors, returns a 1-tuple containing self.
mapping
property mapping: DeviceMapping
Returns the device mapping describing where this tensor lives.
materialize()
materialize()
Gather a distributed tensor into a single local tensor.
Allreduces Partial axes, allgathers Sharded axes, and transfers
the result to CPU. Returns self unchanged for non-distributed
tensors.
-
Return type:
max()
max(axis=-1)
Computes the maximum values along an axis.
Returns a tensor containing the maximum values along the specified axis. This is useful for reduction operations and finding peak values in data.
from max.experimental import tensor
# Create a 2x4 tensor
x = tensor.Tensor(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)
# Find max along last axis (within each row)
row_max = x.max(axis=-1)
# Result: [3.5, 4.2]
# Find max along first axis (within each column)
col_max = x.max(axis=0)
# Result: [2.3, 3.5, 4.2, 3.1]
# Find max over all elements
overall_max = x.max(axis=None)
# Result: 4.2 (maximum value across all elements)
mean()
mean(axis=-1)
Computes the mean values along an axis.
Returns a tensor containing the arithmetic mean of values along the specified axis. This is useful for computing averages, normalizing data, or aggregating statistics.
from max.experimental import tensor
# Create a 2x4 tensor
x = tensor.Tensor(
[[2.0, 4.0, 6.0, 8.0], [1.0, 3.0, 5.0, 7.0]],
)
# Compute mean along last axis (within each row)
row_mean = x.mean(axis=-1)
# Result: [5.0, 4.0] (mean of each row)
# Compute mean along first axis (within each column)
col_mean = x.mean(axis=0)
# Result: [1.5, 3.5, 5.5, 7.5] (mean of each column)
# Compute mean over all elements
overall_mean = x.mean(axis=None)
# Result: 4.5 (mean of all elements)
mesh
property mesh: DeviceMesh
Returns the device mesh.
min()
min(axis=-1)
Computes the minimum values along an axis.
Returns a tensor containing the minimum values along the specified axis. This is useful for reduction operations and finding the smallest values in data.
from max.experimental import tensor
# Create a 2x4 tensor
x = tensor.Tensor(
[[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)
# Find min along last axis (within each row)
row_min = x.min(axis=-1)
# Result: [0.8, 1.9]
# Find min along first axis (within each column)
col_min = x.min(axis=0)
# Result: [1.2, 1.9, 2.1, 0.8]
# Find min over all elements
overall_min = x.min(axis=None)
# Result: 0.8 (minimum value across all elements)
num_elements()
num_elements()
Gets the total number of elements in the tensor.
Computes the product of all dimensions in the tensor’s shape to determine the total number of elements.
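For example (a brief sketch):
from max.experimental import tensor
x = tensor.Tensor.ones([2, 3, 4])
print(x.num_elements()) # 24 (2 * 3 * 4)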
-
Returns:
-
The total number of elements in the tensor.
-
Return type:
-
int
num_shards
property num_shards: int
Returns the number of shards (1 for an unsharded tensor).
ones()
classmethod ones(shape, *, dtype=None, device=None)
Creates a tensor filled with ones.
Returns a new tensor with the specified shape where all elements are initialized to one.
from max.experimental import tensor
# Create a 2x3 tensor of ones
x = tensor.Tensor.ones((2, 3))
-
Parameters:
-
- shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor.
- dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
- device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
-
Returns:
-
A new tensor with the specified shape filled with ones.
-
Return type:
ones_like()
classmethod ones_like(input)
Creates a tensor of ones matching a given tensor’s properties.
Returns a new tensor filled with ones that matches the shape, data type,
and device of the input tensor. This behaves like NumPy’s ones_like
and PyTorch’s ones_like.
from max.experimental import tensor
# Create a reference tensor
ref = tensor.Tensor.zeros([3, 4])
# Create ones tensor matching the reference tensor
x = tensor.Tensor.ones_like(ref)
# Result: 3x4 tensor of ones with dtype float32
-
Parameters:
-
input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
-
Returns:
-
A new tensor filled with ones matching the properties of the input.
-
Return type:
permute()
permute(dims)
Permutes the dimensions of the tensor.
Returns a tensor with its dimensions reordered according to the
specified permutation. This is useful for changing the layout of
multi-dimensional data, such as converting between different tensor
layout conventions (e.g., from [batch, channels, height, width]
to [batch, height, width, channels]).
from max.experimental.tensor import Tensor
from max.dtype import DType
# Create a 3D tensor (batch_size=2, channels=3, length=4)
x = Tensor(
[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],
dtype=DType.int32,
)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3), Dim(4)]
# Rearrange to (batch, length, channels)
y = x.permute([0, 2, 1])
print(f"Permuted shape: {y.shape}")
# Output: Permuted shape: [Dim(2), Dim(4), Dim(3)]
placements
property placements: tuple[Placement, ...]
Returns per-axis placement descriptors.
For NamedMapping,
this converts to placements on the fly. Raises
ConversionError
if the spec contains compiler-only annotations.
prod()
prod(axis=-1)
Computes the product of values along an axis.
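For example (a brief sketch mirroring the other axis reductions on this page):
from max.experimental import tensor
x = tensor.Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
row_prod = x.prod(axis=-1)
# Result: [6.0, 120.0] (product within each row)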
range_like()
classmethod range_like(type)
Creates a range tensor matching a given type’s properties.
Returns a new tensor containing sequential indices along the last dimension, broadcasted to match the shape of the specified tensor type. Each row (along the last dimension) contains values from 0 to the dimension size minus one. This is useful for creating position indices or coordinate tensors.
from max.experimental import tensor
from max.graph import DeviceRef, TensorType
from max.dtype import DType
# Create a reference tensor type with shape (2, 4)
ref_type = TensorType(DType.int32, (2, 4), device=DeviceRef.CPU())
# Create range tensor matching the reference type
x = tensor.Tensor.range_like(ref_type)
# Result: [[0, 1, 2, 3],
# [0, 1, 2, 3]]
-
Parameters:
-
type (TensorType) – The tensor type to match. The returned tensor will have the same shape, dtype, and device as this type, with values representing indices along the last dimension.
-
Returns:
-
A new tensor with sequential indices broadcasted to match the input type’s shape.
-
Return type:
rank
property rank: int
Gets the number of dimensions in the tensor.
Returns the rank (number of dimensions) of the tensor. For example, a scalar has rank 0, a vector has rank 1, and a matrix has rank 2.
-
Returns:
-
The number of dimensions in the tensor.
-
Return type:
-
int
real
property real: bool
Returns True if this tensor is realized (has concrete storage).
For sharded tensors this is all-or-nothing: either every shard is
realized (_state is None) or none are.
realize
property realize: Tensor
Force the tensor to realize if it is not already.
reshape()
reshape(shape)
Reshapes the tensor to a new shape.
Returns a tensor with the same data but a different shape. The total number of elements must remain the same. This is useful for changing tensor dimensions for different operations, such as flattening a multi-dimensional tensor or converting a 1D tensor into a matrix.
from max.experimental import tensor
from max.dtype import DType
# Create a 2x3 tensor
x = tensor.Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(x.shape) # (2, 3)
# Flatten to 1D
y = x.reshape((6,))
print(y.shape) # (6,)
# Values: [1, 2, 3, 4, 5, 6]
shape
property shape: Shape
Gets the global shape of the tensor.
For sharded tensors this returns the logical global shape (not the per-shard shape). If no explicit global shape was set, it is derived from the first shard’s shape, placements, and mesh.
-
Returns:
-
The shape of the tensor.
-
Return type:
split()
split(split_size_or_sections, axis=0)
Splits the tensor into multiple tensors along a given dimension.
This method supports two modes, matching PyTorch’s behavior:
- If split_size_or_sections is an int, splits into chunks of that size (the last chunk may be smaller if not evenly divisible).
- If split_size_or_sections is a list of ints, splits into chunks with exactly those sizes (must sum to the dimension size).
from max.experimental import tensor
# Create a 10x4 tensor
x = tensor.Tensor.ones([10, 4])
# Split into chunks of size 3 (last chunk is size 1)
chunks = x.split(3, axis=0)
# Result: 4 tensors with shapes [3,4], [3,4], [3,4], [1,4]
# Split into exact sizes
chunks = x.split([2, 3, 5], axis=0)
# Result: 3 tensors with shapes [2,4], [3,4], [5,4]
squeeze()
squeeze(axis)
Removes a size-1 dimension from the tensor.
Returns a tensor with the specified size-1 dimension removed. This is useful for removing singleton dimensions from tensors after operations that may have added them.
from max.experimental import tensor
# Create a tensor with a size-1 dimension
x = tensor.Tensor.ones([4, 1, 6])
print(x.shape) # (4, 1, 6)
# Squeeze out the size-1 dimension
y = x.squeeze(axis=1)
print(y.shape) # (4, 6)
-
Parameters:
-
axis (int) – The dimension to remove from the tensor’s shape. If negative, this indexes from the end of the tensor. The dimension at this axis must have size 1.
-
Returns:
-
A tensor with the specified dimension removed.
-
Return type:
-
Raises:
-
ValueError – If the dimension at the specified axis is not size 1.
state
property state: RealizationState | None
Returns the realization state (unsharded tensors only).
storage
Returns the single backing buffer (unsharded tensors only).
sum()
sum(axis=-1)
Computes the sum of values along an axis.
Returns a tensor containing the sum of values along the specified axis. This is a fundamental reduction operation used for aggregating data, computing totals, and implementing other operations like mean.
from max.experimental import tensor
# Create a 2x3 tensor
x = tensor.Tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
)
# Sum along last axis (within each row)
row_sum = x.sum(axis=-1)
# Result: [6.0, 15.0] (sum of each row)
# Sum along first axis (within each column)
col_sum = x.sum(axis=0)
# Result: [5.0, 7.0, 9.0] (sum of each column)
# Sum over all elements
total = x.sum(axis=None)
# Result: 21.0 (sum of all elements)
to()
to(target)
Transfers the tensor to a different device, mesh, or mapping.
This method supports three target types:
- Device: Transfers a single-device tensor to the target device. For realized tensors, performs a direct driver-level transfer via to(). For unrealized tensors, inserts a transfer_to() op into the computation graph.
- DeviceMapping: Reassigns the tensor’s device mesh and placements. For single-device mappings, equivalent to .to(device). For multi-device mappings on an unsharded tensor, distributes the tensor across the mesh using the shard collective.
- DeviceMesh: Replaces the device mesh while keeping existing placements. For unsharded tensors targeting a multi-device mesh, creates a fully replicated mapping. For distributed tensors, transfers shards to the new mesh devices.
from max.experimental import tensor
from max.driver import CPU, Accelerator
# Create a tensor on CPU
x = tensor.Tensor.ones((2, 3), device=CPU())
print(x.device) # CPU
# Transfer to accelerator
y = x.to(Accelerator())
print(y.device) # Accelerator(0)
# Same-device transfer is a no-op
z = y.to(y.device)
assert z is y
-
Parameters:
-
target (Device | DeviceMesh | DeviceMapping) –
The target for the tensor. Can be:
- Device: Target device for transfer.
- DeviceMesh: New mesh, keeping existing placements (or fully replicated for unsharded tensors).
- DeviceMapping: New mesh and placements; triggers shard collective for multi-device.
-
Returns:
-
A tensor on the specified target. Returns self if no transfer is needed.
-
Return type:
to_numpy()
to_numpy()
Convert this tensor to a NumPy array.
Materializes distributed tensors and transfers to CPU if needed.
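For example (a brief sketch):
from max.experimental import tensor
from max.dtype import DType
x = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=DType.float32)
np_array = x.to_numpy()
print(np_array.shape) # (2, 2)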
transpose()
transpose(dim1, dim2)
Returns a tensor that is a transposed version of input.
The given dimensions dim1 and dim2 are swapped.
from max.experimental.tensor import Tensor
from max.dtype import DType
# Create a 2x3 matrix
x = Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]
print(x)
# Transpose dimensions 0 and 1 to get a 3x2 matrix
y = x.transpose(0, 1)
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)
type
property type: TensorType
Gets the tensor type information.
-
Returns:
-
The type information for the tensor.
-
Return type:
-
Raises:
-
TypeError – If the tensor is distributed.
unsqueeze()
unsqueeze(axis)
Inserts a size-1 dimension into the tensor.
Returns a tensor with a new size-1 dimension inserted at the specified
position. This is the inverse of squeeze() and is useful for
adding dimensions needed for broadcasting or matrix operations.
from max.experimental import tensor
# Create a 1D tensor
x = tensor.Tensor([1.0, 2.0, 3.0])
print(x.shape) # (3,)
# Add dimension at the end
y = x.unsqueeze(axis=-1)
print(y.shape) # (3, 1)
# Add dimension at the beginning
z = x.unsqueeze(axis=0)
print(z.shape) # (1, 3)
zeros()
classmethod zeros(shape, *, dtype=None, device=None)
Creates a tensor filled with zeros.
Returns a new tensor with the specified shape where all elements are initialized to zero. The tensor is created with eager execution and automatic compilation.
from max.experimental import tensor
# Create a 2x3 tensor of zeros
x = tensor.Tensor.zeros((2, 3))
# Result: [[0.0, 0.0, 0.0],
# [0.0, 0.0, 0.0]]
# Create a 1D tensor using default dtype and device
y = tensor.Tensor.zeros((5,))
-
Parameters:
-
- shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
- dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
- device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.
-
Returns:
-
A new tensor with the specified shape filled with zeros.
-
Return type:
zeros_like()
classmethod zeros_like(input)
Creates a tensor of zeros matching a given tensor’s properties.
Returns a new tensor filled with zeros that matches the shape, data type,
and device of the input tensor. This behaves like NumPy’s zeros_like
and PyTorch’s zeros_like.
from max.experimental import tensor
# Create a reference tensor
ref = tensor.Tensor.ones([3, 4])
# Create zeros tensor matching the reference tensor
x = tensor.Tensor.zeros_like(ref)
# Result: 3x4 tensor of zeros with dtype float32
-
Parameters:
-
input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
-
Returns:
-
A new tensor filled with zeros matching the properties of the input.
-
Return type:
TensorType
class max.experimental.functional.TensorType(dtype, shape, device, _layout=None)
Bases: _TensorTypeBase[TensorType]
A symbolic tensor type.
Use TensorType to declare the expected dtype, shape, and target
device of tensor values that flow through a graph during model
execution. Unlike an eager tensor, a TensorType holds no data. It is a
purely symbolic description of a value’s type at a specific point in the
computation. The graph compiler uses this information for shape inference
and optimization during graph construction.
The following example shows how to create a tensor type and access its properties:
from max.graph import TensorType, DeviceRef
from max.dtype import DType
# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3), device=DeviceRef.CPU())
print(tensor_type.dtype) # Outputs: DType.float32
print(tensor_type.shape) # Outputs: [2, 3]
A shape’s dimensions can be static (integers), symbolic (strings), or algebraic (expressions over symbolic dimensions). In each case the rank is known at graph construction time.
Pass TensorType instances to load()
or Module.compile() (experimental) to define the input types of a
graph or model.
-
Parameters:
-
- dtype (DType) – The data type of the tensor elements.
- shape (Shape) – The shape of the tensor, expressed as a Shape.
- device (DeviceRef) – The device the tensor is located on. Use DeviceRef.CPU() or DeviceRef.GPU() to create a device reference.
- _layout (FilterLayout | None)
as_buffer()
as_buffer()
Returns the analogous buffer type.
-
Return type:
from_mlir()
classmethod from_mlir(type)
Constructs a tensor type from an MLIR type.
-
Parameters:
-
type (TensorType) – The MLIR Type to parse into a tensor type.
-
Returns:
-
The tensor type represented by the MLIR Type value.
-
Return type:
to_mlir()
to_mlir()
Converts to an mlir.Type instance.
-
Returns:
-
An mlir.Type in the specified context.
-
Return type:
-
TensorType
Type
class max.experimental.functional.Type
Bases: Generic[MlirType]
The type of any value in a MAX graph.
Every value in the graph has a type, and that type is represented by a Type.
This type may be inspected to get finer-grained types and learn more
about an individual Value.
The following example shows how to work with types in a graph:
from max.graph import Graph, TensorType
from max.dtype import DType
with Graph() as g:
    # Create a tensor type describing a float32 2x3 value
    tensor_type = TensorType(DType.float32, [2, 3])
    # The type can be inspected to get information about the value
    print(f"Tensor element type: {tensor_type.dtype}")  # Outputs: DType.float32
    print(f"Tensor shape: {tensor_type.shape}")  # Outputs: [2, 3]
from_mlir()
static from_mlir(t)
Constructs a type from an MLIR type.
to_mlir()
to_mlir()
Converts to an mlir.Type instance.
-
Returns:
-
An mlir.Type in the specified Context.
-
Return type:
-
MlirType
TypeVar
class max.experimental.functional.TypeVar
Bases: object
Type variable.
The preferred way to construct a type variable is via the dedicated syntax for generic functions, classes, and type aliases:
class Sequence[T]: # T is a TypeVar
    ...
This syntax can also be used to create bound and constrained type variables:
# S is a TypeVar bound to str
class StrSequence[S: str]:
...
# A is a TypeVar constrained to str or bytes
class StrOrBytesSequence[A: (str, bytes)]:
    ...
Type variables can also have defaults:
class IntDefault[T = int]:
    ...
However, if desired, reusable type variables can also be constructed manually, like so:
T = TypeVar('T') # Can be anything
S = TypeVar('S', bound=str) # Can be any subtype of str
A = TypeVar('A', str, bytes) # Must be exactly str or bytes
D = TypeVar('D', default=int) # Defaults to int
Type variables exist primarily for the benefit of static type checkers. They serve as the parameters for generic types as well as for generic function and type alias definitions.
The variance of type variables is inferred by type checkers when they
are created through the type parameter syntax and when
infer_variance=True is passed. Manually created type variables may
be explicitly marked covariant or contravariant by passing
covariant=True or contravariant=True. By default, manually
created type variables are invariant. See PEP 484 and PEP 695 for more
details.
has_default()
has_default()
Return whether the default value for this type variable is set.
abs()
max.experimental.functional.abs(x)
Computes the elementwise absolute value of the input tensor.
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
acos()
max.experimental.functional.acos(x)
Computes the arccosine (inverse cosine) of the input tensor.
Returns values in the range [0, π] for inputs in [-1, 1].
Creates a new op node to compute the elementwise arccosine of a symbolic tensor and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def acos_graph():
    input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())
    with Graph("acos_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.acos(x)
        graph.output(out)
-
Parameters:
-
x (TensorValue) – Input tensor with values in [-1, 1]. If values are outside this domain, they will be clamped to the valid range.
-
Returns:
-
Arccosine of the input in radians [0, π]. The result will have:
- the same dtype as the input
- the same shape as the input
-
Return type:
-
Raises:
-
- Error – If the symbol doesn’t represent a tensor value.
- Error – If the input is not a floating-point dtype.
add()
max.experimental.functional.add(lhs, rhs)
Computes the elementwise sum of the two input tensors.
allgather()
max.experimental.functional.allgather(t, tensor_axis=0, mesh_axis=0)
All-gather: Sharded → Replicated along mesh_axis.
allreduce_sum()
max.experimental.functional.allreduce_sum(t, mesh_axis=0)
All-reduce sum: Partial → Replicated along mesh_axis.
arange()
max.experimental.functional.arange(start, stop, step=1, out_dim=None, *, dtype=None, device=None)
Create a 1-D tensor with values from start to stop (exclusive) by step.
-
Parameters:
-
- start (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- stop (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- step (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- out_dim (int | str | Dim | integer[Any] | TypedAttr | None)
- dtype (DType | None)
- device (Device | DeviceMapping | DeviceRef | None)
-
Return type:
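A minimal usage sketch (illustrative, not from the library docs; printed values assume the default dtype and device selection):
from max.experimental import functional as F
# Values 0, 1, 2, 3, 4
x = F.arange(0, 5)
# Fractional step: 1.0, 1.5, 2.0
y = F.arange(1.0, 2.5, 0.5)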
argmax()
max.experimental.functional.argmax(x, axis=-1)
Returns the indices of the maximum values along an axis.
When axis is None, flattens to 1-D first.
Distributed via SPMD. See max.graph.ops.argmax() for details.
argmin()
max.experimental.functional.argmin(x, axis=-1)
Returns the indices of the minimum values along an axis.
When axis is None, flattens to 1-D first.
Distributed via SPMD. See max.graph.ops.argmin() for details.
argsort()
max.experimental.functional.argsort(x, ascending=True)
Returns the indices that would sort a tensor.
This function returns the indices that would sort the input tensor along
its first dimension. The returned indices are of type int64.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – Input tensor to be sorted.
- ascending (bool) – If True (default), sort in ascending order. If False, sort in descending order.
-
Returns:
-
A tensor of indices of the same shape as the input tensor.
-
Return type:
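For illustration, a small sketch of sorting indices (assumes the eager Tensor API shown elsewhere in this module):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([3.0, 1.0, 2.0])
idx = F.argsort(x)                        # int64 indices [1, 2, 0]
idx_desc = F.argsort(x, ascending=False)  # int64 indices [0, 2, 1]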
as_interleaved_complex()
max.experimental.functional.as_interleaved_complex(x)
Reshapes the input symbolic tensor as complex from alternating (real, imag).
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.
-
Returns:
-
A symbolic tensor representing the complex-valued tensor, but with the values pulled out as complex numbers. The result has the same dimensions for all dimensions except the last dimension, which is halved, and then a final dimension of size 2 representing the complex value.
-
Return type:
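A minimal sketch of the shape transformation (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Last dimension holds alternating (real, imag) pairs
x = Tensor.zeros([4, 8])
y = F.as_interleaved_complex(x)
# y.shape is (4, 4, 2): the last dim is halved, with a trailing size-2 complex axis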
atanh()
max.experimental.functional.atanh(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
avg_pool2d()
max.experimental.functional.avg_pool2d(input, kernel_size, stride=1, dilation=1, padding=0, ceil_mode=False, count_boundary=True)
Perform a 2D average pooling operation on the input tensor.
Applies a 2D average pooling operation to the input tensor with layout
[N, H, W, C]. The pooling operation slides a window of size
kernel_size over the spatial dimensions and computes the average
value within each window.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor with shape [N, H, W, C].
- kernel_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – The height and width of the sliding window.
- stride (int | tuple[int, int]) – The stride of the sliding window. Can be a single integer applied to both spatial dimensions or a tuple (stride_h, stride_w). Defaults to 1.
- dilation (int | tuple[int, int]) – The spacing between kernel elements. Can be a single integer or a tuple (dilation_h, dilation_w). Defaults to 1.
- padding (int | tuple[int, int]) – Zero-padding added to both sides of each spatial dimension. Can be a single integer or a tuple (pad_h, pad_w). Defaults to 0.
- ceil_mode (bool) – If True, uses ceil instead of floor when computing the output spatial shape. Defaults to False.
- count_boundary (bool) – If True, includes padding elements in the divisor when computing the average. Defaults to True.
-
Returns:
-
A symbolic tensor with the average pooling applied, with shape
[N, H_out, W_out, C].
-
Return type:
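A minimal sketch of the expected shapes (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# NHWC input: batch 1, 4x4 spatial grid, 1 channel
x = Tensor.ones([1, 4, 4, 1])
y = F.avg_pool2d(x, kernel_size=(2, 2), stride=2)
# y has shape [1, 2, 2, 1]; each output element averages a 2x2 window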
band_part()
max.experimental.functional.band_part(x, num_lower=None, num_upper=None, exclude=False)
Masks out everything except a diagonal band of an input matrix.
Copies a tensor setting everything outside the central diagonal band of the matrices to zero, where all but the last two axes are effectively batches, and the last two axes define sub matrices.
Assumes the input has dimensions [I, J, …, M, N], then the output tensor has the same shape as the input, and the values are given by
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
with the indicator function:
in_band(m, n) = (num_lower is None || (m - n) <= num_lower) &&
                (num_upper is None || (n - m) <= num_upper)
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to mask.
- num_lower (int | None) – The number of diagonal bands to include below the central diagonal. If None, include the entire lower triangle.
- num_upper (int | None) – The number of diagonal bands to include above the central diagonal. If None, include the entire upper triangle.
- exclude (bool) – If true, invert the selection of elements to mask. Elements in the band are set to zero.
-
Returns:
-
A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.
-
Raises:
-
ValueError – If the input tensor rank is less than 2, or if num_lower/num_upper are out of bounds for statically known dimensions.
-
Return type:
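For illustration, a small sketch (assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
# Keep only the main diagonal
diag = F.band_part(x, num_lower=0, num_upper=0)
# Keep the lower triangle (the diagonal and everything below it)
lower = F.band_part(x, num_lower=None, num_upper=0)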
bottom_k()
max.experimental.functional.bottom_k(input, k, axis=-1)
Returns tensor with only the bottom K values along given axis.
-
Parameters:
-
Returns:
-
Bottom K values (ascending), Bottom K indices.
-
Return type:
broadcast_to()
max.experimental.functional.broadcast_to(x, shape, out_dims=None)
Broadcasts a symbolic tensor.
Broadcasts the input tensor to the specified shape. Dimensions in the input must be one or match the target dimension.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions.
- shape (TensorValue | Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The new shape as a list of dimensions. Dynamic dimensions are not allowed.
- out_dims (Iterable[int | str | Dim | integer[Any] | TypedAttr] | None) – Output dims used only for tensor-valued shape.
-
Returns:
-
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as
shape.
-
Raises:
-
ValueError – if a tensor-valued shape is passed without out_dims.
-
Return type:
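A minimal sketch (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[1.0], [2.0]])  # shape (2, 1)
y = F.broadcast_to(x, (2, 3))        # shape (2, 3): [[1, 1, 1], [2, 2, 2]]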
buffer_store()
max.experimental.functional.buffer_store(destination, source)
Sets a tensor buffer to new values. Distributed via SPMD.
See max.graph.ops.buffer_store() for details.
buffer_store_slice()
max.experimental.functional.buffer_store_slice(destination, source, indices)
Sets a slice of a tensor buffer to new values. Distributed via SPMD.
See max.graph.ops.buffer_store_slice() for details.
cast()
max.experimental.functional.cast(x, dtype)
Casts a symbolic tensor to a different data type.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input tensor to cast.
- dtype (DType) – The target dtype to which the tensor is cast.
-
Returns:
-
A new symbolic tensor with the same shape as the input and the specified dtype.
-
Return type:
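A minimal sketch (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
from max.dtype import DType
x = Tensor.constant([1.5, 2.5])
y = F.cast(x, DType.float64)  # same values, wider dtype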
chunk()
max.experimental.functional.chunk(x, chunks, axis=0)
Chunk the tensor into an exact number of chunks along the specified dim.
Example:
>>> a = TensorValue([1, 2, 3, 4])
>>> chunk(a, 2, 0)
[TensorValue([1, 2]), TensorValue([3, 4])]
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The tensor to chunk.
- chunks (int) – The number of chunks to split the tensor into. chunks must statically evenly divide x.shape[axis].
- axis (int) – The axis to split the tensor along.
-
Returns:
-
A list of chunks tensors of equal size along axis.
-
Return type:
clamp()
max.experimental.functional.clamp(x, lower_bound, upper_bound)
Clamps tensor values to a specified range.
Returns max(min(x, upper_bound), lower_bound).
-
Parameters:
-
Return type:
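A minimal sketch, applicable to both clamp() and the clip() alias below (assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([-2.0, 0.5, 3.0])
y = F.clamp(x, 0.0, 1.0)  # [0.0, 0.5, 1.0]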
clip()
max.experimental.functional.clip(x, lower_bound, upper_bound)
Clamps tensor values to a specified range.
Returns max(min(x, upper_bound), lower_bound).
-
Parameters:
-
Return type:
complex_mul()
max.experimental.functional.complex_mul(lhs, rhs)
Multiply two complex valued tensors.
Complex numbers are represented as a 2-dimensional vector in the last dimension.
-
Parameters:
-
- lhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A complex number valued symbolic tensor.
- rhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A complex number valued symbolic tensor.
-
Returns:
-
The result of multiplying the input values as a complex number valued symbolic tensor.
-
Return type:
concat()
max.experimental.functional.concat(original_vals, axis=0)
Concatenates a list of symbolic tensors along an axis.
Joins multiple tensors along a specified dimension. This operation requires the functional API since it operates on multiple tensors. All input tensors must have the same rank and the same size in all dimensions except the concatenation axis.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create two 2x2 matrices
a = Tensor.constant([[1, 2], [3, 4]])
b = Tensor.constant([[5, 6], [7, 8]])
# Concatenate along axis 0 (rows) - stacks vertically
vertical = F.concat([a, b], axis=0)
print(f"Concatenated along axis 0: {vertical.shape}")
# Output: Concatenated along axis 0: [Dim(4), Dim(2)]
print(vertical)
# [[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]]
# Concatenate along axis 1 (columns) - joins horizontally
horizontal = F.concat([a, b], axis=1)
print(f"Concatenated along axis 1: {horizontal.shape}")
# Output: Concatenated along axis 1: [Dim(2), Dim(4)]
print(horizontal)
# [[1, 2, 5, 6],
# [3, 4, 7, 8]]
-
Parameters:
-
- original_vals (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – The list of symbolic tensor values to concatenate. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than axis.
- axis (int) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, concat(vs, -1) will concatenate along the last dimension.
-
Returns:
-
A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will be the same as each input tensor’s for each dimension other than axis, which will have size equal to the sum of all tensor’s size for that dimension.
-
Return type:
cond()
max.experimental.functional.cond(pred, out_types, then_fn, else_fn)
ops.cond requires a CPU predicate — inserts a transfer when needed.
-
Parameters:
-
- pred (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- out_types (Iterable[Type[Any]] | None)
- then_fn (Callable[[], Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None])
- else_fn (Callable[[], Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None])
-
Return type:
constant()
max.experimental.functional.constant(value, dtype=None, device=None)
Create a constant tensor from a scalar, nested list, or DLPack array.
For DLPack arrays, the array’s own dtype is preserved when dtype is
None (matching ops.constant semantics). For Python scalars and
nested lists, dtype defaults to float32 on CPU / bfloat16 on
accelerators.
Inside a realization context, emits ops.constant per device.
constant_external()
max.experimental.functional.constant_external(name, type, device=None)
Create a constant tensor from external (weight) data.
External constants are loaded at graph compile time. Supports distributed placement via DeviceMapping.
-
Parameters:
-
- name (str)
- type (TensorType)
- device (Device | DeviceMapping | DeviceRef | None)
-
Return type:
conv2d()
max.experimental.functional.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.RSCF)
Computes the 2-D convolution product of the input with the given filter, bias, strides, dilations, paddings, and groups.
The op supports 2-D convolution, with the following layout assumptions:
- input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
- filter has layout RSCF, i.e., (height, width, in_channels / num_groups, out_channels)
- bias has shape (out_channels,)
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding 0’s before and after the indicated spatial dimensions in input. In 2-D convolution, dim1 here represents H and dim2 represents W. In Python-like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:
input = [
[1, 2, 3],
[4, 5, 6]
]
# Shape is 2x3
padded_input = [
[0, 0, 1, 2, 3, 0],
[0, 0, 4, 5, 6, 0],
[0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
This op currently only supports strides and padding on the input.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NHWC input tensor to perform the convolution upon.
- filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in RSCF layout: (height, width, in_channels / num_groups, out_channels).
- stride (tuple[int, int]) – The stride of the convolution operation.
- dilation (tuple[int, int]) – The spacing between the kernel points.
- padding (tuple[int, int, int, int]) – The amount of padding applied to the input.
- groups (int) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Optional 1-D bias of shape (out_channels,).
- input_layout (ConvInputLayout) – Layout of the input tensor (default NHWC).
- filter_layout (FilterLayout) – Layout of the filter tensor (default RSCF).
-
Returns:
-
A symbolic tensor value with the convolution applied.
-
Return type:
conv2d_transpose()
max.experimental.functional.conv2d_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output_paddings=(0, 0), bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.RSCF)
Computes the 2-D deconvolution of the input with the given filter, strides, dilations, paddings, and groups.
The op supports the transpose (gradient) of convolution, with the following layout assumptions: (note the out_channel is w.r.t. the original convolution)
- input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
- filter has layout RSCF, i.e., (kernel_height, kernel_width, out_channels, in_channels)
- bias has shape (out_channels,)
The padding values are expected to take the form [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].
This op effectively computes the gradient of a convolution with respect to its input (as if the original convolution operation had the same filter and hyperparameters as this op). A visualization of the computation can be found in https://d2l.ai/chapter_computer-vision/transposed-conv.html.
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding 0’s before and after the indicated spatial dimensions in input. In 2D ConvTranspose, dim1 here represents H_out and dim2 represents W_out. In Python-like syntax, padding a 2x4 spatial output with [0, 1, 2, 1] would yield:
output = [
[1, 2, 3, 4],
[5, 6, 7, 8]
]
# Shape is 2x4
padded_input = [
[3],
]
# Shape is 1x1
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NHWC input tensor to perform the deconvolution upon.
- filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in RSCF layout: (height, width, out_channels, in_channels).
- stride (tuple[int, int]) – The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension. By default the N and C dimensions are set to 0.
- dilation (tuple[int, int]) – The spacing between the kernel points.
- padding (tuple[int, int, int, int]) – The amount of padding applied to the input.
- output_paddings (tuple[int, int]) – this argument is meant to resolve the ambiguity of multiple potential output shapes when any stride is greater than 1. Basically, we’ll add output_paddings[i] number of zeros at the end of output’s ith axis. We only support output_paddings = 0.
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Tensor of shape (out_channels,).
- input_layout (ConvInputLayout) – Layout of the input tensor (default NHWC).
- filter_layout (FilterLayout) – Layout of the filter tensor (default RSCF).
-
Returns:
-
A symbolic tensor value with the convolution applied.
-
Return type:
conv3d()
max.experimental.functional.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input_layout=ConvInputLayout.NHWC, filter_layout=FilterLayout.QRSCF)
Computes the 3-D convolution product of the input with the given filter, strides, dilations, paddings, and groups.
The op supports 3-D convolution, with the following layout assumptions:
- input has NDHWC layout, i.e., (batch_size, depth, height, width, in_channels)
- filter has layout QRSCF, i.e., (depth, height, width, in_channels / num_groups, out_channels)
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding 0’s before and after the indicated spatial dimensions in input. In 3-D convolution, dim1 here represents D, dim2 represents H and dim3 represents W. In Python-like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:
input = [
[1, 2, 3],
[4, 5, 6]
]
# Shape is 2x3
padded_input = [
[0, 0, 1, 2, 3, 0],
[0, 0, 4, 5, 6, 0],
[0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
This op currently only supports strides and padding on the input.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An NDHWC input tensor to perform the convolution upon.
- filter (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The convolution filter in QRSCF layout: (depth, height, width, in_channels / num_groups, out_channels).
- stride (tuple[int, int, int]) – The stride of the convolution operation.
- dilation (tuple[int, int, int]) – The spacing between the kernel points.
- padding (tuple[int, int, int, int, int, int]) – The amount of padding applied to the input.
- groups (int) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
- bias (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Optional 1-D bias of shape (out_channels,).
- input_layout (ConvInputLayout) – Layout of the input tensor (default NDHWC).
- filter_layout (FilterLayout) – Layout of the filter tensor (default QRSCF).
-
Returns:
-
A symbolic tensor value with the convolution applied. Output shape = (batch_size, depth, height, width, out_channels).
-
Return type:
cos()
max.experimental.functional.cos(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
cumsum()
max.experimental.functional.cumsum(x, axis=-1, exclusive=False, reverse=False)
Computes the cumulative sum of the input tensor along the given axis.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to sum over.
- axis (int) – The axis along which to compute the sum. If negative, indexes from the last dimension. For example, a value of -1 will compute the sum along the last dimension.
- exclusive (bool) – If set, start at 0 and exclude the final element. Otherwise, start with the first element. Said another way, cumsum computes [sum(x[…, :i, …]) for i in range(x.shape[axis])]. If exclusive is set, the bounds are instead range(1, x.shape[axis]).
- reverse (bool) – If set, start from the end. In other words, the first element will be the total sum, with each element following counting downwards; or [sum(x[…, i:, …]) for i in range(x.shape[axis])].
-
Returns:
-
A symbolic tensor representing the result of the cumsum operation. The tensor will have the same type as the input tensor. The computed values will be the cumulative sum of the values along the given axis, according to the specified parameters:
- if exclusive is set, the first value will be 0, and the last value will be excluded from the sum
- if reverse is set, the sum will be computed starting at the back of the axis back to the front, rather than front-to-back
-
Raises:
-
ValueError – If x is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
Return type:
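For illustration, a small sketch of the three modes (assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([1.0, 2.0, 3.0, 4.0])
F.cumsum(x)                  # [1.0, 3.0, 6.0, 10.0]
F.cumsum(x, exclusive=True)  # [0.0, 1.0, 3.0, 6.0]
F.cumsum(x, reverse=True)    # [10.0, 9.0, 7.0, 4.0]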
custom()
max.experimental.functional.custom(name, device, values, out_types, parameters=None, custom_extensions=None)
Apply a custom op with optional custom extension loading.
dequantize()
max.experimental.functional.dequantize(encoding, quantized)
Dequantizes a quantized tensor to floating point.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
-
- encoding (QuantizationEncoding) – The quantization encoding to use.
- quantized (TensorValue) – The quantized tensor to dequantize.
-
Returns:
-
The dequantized result (a floating point tensor).
-
Return type:
div()
max.experimental.functional.div(lhs, rhs)
Computes the elementwise quotient of lhs divided by rhs.
elementwise_max()
max.experimental.functional.elementwise_max(lhs, rhs)
elementwise_min()
max.experimental.functional.elementwise_min(lhs, rhs)
ensure_context()
max.experimental.functional.ensure_context()
Ensure a realization context exists for Tensor <-> TensorValue conversion.
Three execution contexts are supported:
- Eager (EagerRealizationContext) — created automatically when no context is active and we are not inside a Graph. On exit, realize_all() is called so that all symbolic graph ops executed within the block are compiled and realized into concrete tensors.
- Lazy (LazyRealizationContext) — set externally via with lazy():. When already active this function re-uses it; tensors remain unrealized until explicitly awaited.
- Graph (GraphRealizationContext) — created automatically when we are inside a Graph.current context. Tensors stay symbolic.
If a context of any kind already exists, it is re-used as-is.
-
Return type:
-
Generator[None]
equal()
max.experimental.functional.equal(lhs, rhs)
erf()
max.experimental.functional.erf(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
exp()
max.experimental.functional.exp(x)
Computes the elementwise exp (exponential) function of a symbolic tensor.
Creates a new op node to compute the elementwise exponential function of a symbolic tensor and adds it to the graph, returning the symbolic result. The exp function is fundamental in neural networks, used in attention mechanisms, activation functions, and probability distributions.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create input tensor
x = Tensor.constant([0.0, 1.0, 2.0])
# Compute exponential
result = F.exp(x)
print(result)
# Output: [1.0, 2.718..., 7.389...]
# (e^0 = 1, e^1 ≈ 2.718, e^2 ≈ 7.389)
exp is defined as exp(x) = e^x, where e is Euler’s number.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The symbolic tensor to use as the input to the exp function computation.
-
Returns:
-
A new symbolic tensor value representing the output of the exp value computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
flatten()
max.experimental.functional.flatten(x, start_dim=0, end_dim=-1)
Flattens the specified dims of a symbolic tensor.
The number and order of the elements in the tensor is unchanged.
All dimensions from start_dim to end_dim (inclusive) are merged
into a single output dim.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to flatten.
- start_dim (int) – The first dimension to flatten. Supports negative indexing. Defaults to 0.
- end_dim (int) – The last dimension to flatten (inclusive). Supports negative indexing. Defaults to -1.
-
Returns:
-
A symbolic tensor with the same elements as the input, but with dimensions
start_dim through end_dim merged into one.
-
Raises:
-
- IndexError – If start_dim or end_dim are out of range.
- ValueError – If start_dim comes after end_dim.
-
Return type:
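A minimal sketch of the shape behavior (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.zeros([2, 3, 4])
a = F.flatten(x)               # shape (24,)
b = F.flatten(x, start_dim=1)  # shape (2, 12)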
floor()
max.experimental.functional.floor(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
fold()
max.experimental.functional.fold(input, output_size, kernel_size, stride=1, dilation=1, padding=0)
Combines an array of sliding blocks into a larger containing tensor.
The input tensor must have shape (N, C * kernel_sizes, L) where N is
the batch dimension, C is the number of channels, kernel_sizes is
the product of the kernel sizes, and L is the number of local blocks.
The resulting output tensor will have shape
(N, C, output_shape[0], output_shape[1]).
L, the number of blocks, must be equivalent to:
prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)
where d is over all spatial dimensions.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The 3D tensor to fold with shape (N, C * kernel_sizes, L).
- output_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – Spatial dimensions of the output tensor. Must be a tuple of two ints.
- kernel_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – The size of the sliding blocks. Must be a tuple of two ints.
- stride (int | tuple[int, int]) – The stride of the sliding blocks in the input dimension (can be an int or a tuple of two ints).
- dilation (int | tuple[int, int]) – The spacing between the kernel elements (can be an int or a tuple of two ints).
- padding (int | tuple[int, int]) – 0-paddings to be added on both sides of the inputs (can be an int or a tuple of two ints).
-
Returns:
-
The folded 4D tensor with shape
(N, C, output_shape[0], output_shape[1]). -
Return type:
full()
max.experimental.functional.full(shape, value, *, dtype=None, device=None)
Create a tensor filled with value, optionally distributed across devices.
full_like()
max.experimental.functional.full_like(like, value)
Create a tensor filled with value, matching the shape and dtype of like.
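A minimal sketch covering both functions (illustrative):
import max.experimental.functional as F
x = F.full((2, 3), 7.0)   # 2x3 tensor, every element 7.0
y = F.full_like(x, 0.5)   # same shape, dtype, and device, filled with 0.5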
functional()
max.experimental.functional.functional(graph_op, rule=None)
Wrap a graph op to work with Tensor inputs and optional SPMD sharding.
Non-distributed path: calls graph_op directly inside an
ensure_context() block. Graph ops accept TensorValueLike
(which Tensor satisfies), so args pass through unchanged.
Results are converted back to Tensors.
Distributed path (when any input is distributed and a rule is
provided): extracts TensorLayouts, calls the rule, transfers args
to match, then dispatches per-shard via spmd_dispatch.
Uses functools.wraps so the wrapped function inherits the graph
op’s name and docstring.
gather()
max.experimental.functional.gather(input, indices, axis)
Selects elements out of an input tensor by index.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to select elements from.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of index values to use for selection.
- axis (int) – The dimension which
indices indexes from input. If negative, indexes relative to the end of the input tensor. For instance, gather(input, indices, axis=-1) will index against the last dimension of input.
-
Returns:
-
A new symbolic tensor representing the result of the gather operation.
-
Return type:
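For illustration, a small sketch selecting rows by index (assumes the eager Tensor API; F.constant is used to build integer indices):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
from max.dtype import DType
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
idx = F.constant([2, 0], dtype=DType.int32)
rows = F.gather(x, idx, axis=0)  # [[5.0, 6.0], [1.0, 2.0]]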
gather_nd()
max.experimental.functional.gather_nd(input, indices, batch_dims=0)
Selects elements out of an input tensor by N-dimensional index.
This operation performs N-dimensional indexing into input using indices.
Unlike gather(), which indexes along a single axis, gather_nd() allows
indexing along multiple dimensions simultaneously.
from max.dtype import DType
from max.graph import Graph, TensorType, ops

input_shape = ["a", "b", "c", "d", "e"]
indices_shape = ["a", "f", 3]
input_type = TensorType(DType.bfloat16, input_shape)
indices_type = TensorType(DType.int32, indices_shape)
with Graph("gather_nd", input_types=[input_type, indices_type]) as graph:
    input, indices = graph.inputs
    gathered = ops.gather_nd(input, indices, batch_dims=1)
    print(gathered.type)
    # Output: TensorType(dtype=DType.bfloat16, shape=["a", "f", "e"])
In this example:
- batch_dims is 1, so there’s 1 shared dimension at the beginning.
- indices has an additional dimension “f” which becomes part of the output.
- The last dimension of indices is the index vector; values in this vector are interpreted to be indices into “b”, “c”, and “d”.
- Since batch_dims (1) + index size (3) < input.rank (5), the remaining dimensions (in this case “e”) are sliced into the output as features.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to select elements from.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of index values to use for selection.
The last dimension of this tensor must be static. This dimension
will be used to index or slice into
input immediately following batch_dims initial dimensions. The size of this index dimension is the number of dimensions it specifies.
- batch_dims (int) – The number of leading batch dimensions shared by input and indices; 0 by default. input and indices must exactly match up to their first batch_dims dimensions. This function does not broadcast.
-
Returns:
-
A new symbolic tensor representing the result of the gather operation. The output will have the same dtype as
input, and will have shape depending on the inputs, in this order:
- input.shape[:batch_dims] – The “broadcast” dimensions (though note that this function does not broadcast). These dimensions must be identical between input and indices.
- indices.shape[batch_dims:-1] – The “gather” dimensions; this allows multi-dimensional tensors of indices. The last dimension is the index vector.
- input.shape[batch_dims + indices.shape[-1]:] – The “slice” dimensions. If batch_dims < input.rank - indices.shape[-1] (again, this last is the index vector), then any following dimensions of the inputs are taken entirely as though slicing.
-
Return type:
gaussian()
max.experimental.functional.gaussian(shape=(), mean=0.0, std=1.0, *, dtype=None, device=None)
Sample from a Gaussian (normal) distribution with given mean and std.
gaussian_like()
max.experimental.functional.gaussian_like(like, mean=0.0, std=1.0)
Sample Gaussian values matching the shape and dtype of like.
-
Parameters:
-
- like (Tensor | TensorType | DistributedTensorType)
- mean (float)
- std (float)
-
Return type:
gelu()
max.experimental.functional.gelu(x, approximate='none')
Computes the elementwise gelu of a symbolic tensor.
Creates a new op node to compute the elementwise gelu of a symbolic tensor and adds it to the graph, returning the symbolic result.
For approximate == "none", the exact gelu function is computed.
For approximate == "tanh", the approximation:
gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))
is used.
For approximate == "quick", the approximation:
gelu(x) = x * sigmoid(1.702 * x)
is used.
-
Parameters:
-
- x (TensorValue) – The symbolic tensor to use as the input to the gelu computation.
- approximate (str) – One of
none, tanh, or quick.
-
Returns:
-
A new symbolic tensor value representing the output of the gelu computation.
-
Raises:
-
- Error – If the symbol doesn’t represent a tensor value.
- ValueError – If the approximation method is invalid.
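A minimal sketch (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([-1.0, 0.0, 1.0])
y_exact = F.gelu(x)                     # exact formulation
y_tanh = F.gelu(x, approximate="tanh")  # faster tanh approximation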
greater()
max.experimental.functional.greater(lhs, rhs)
greater_equal()
max.experimental.functional.greater_equal(lhs, rhs)
group_norm()
max.experimental.functional.group_norm(input, gamma, beta, num_groups, epsilon)
Performs group normalization.
Divides channels into groups and computes normalization statistics within each group. Useful for small batch sizes where batch normalization is unstable.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor of shape [N, C, ...] to normalize.
- gamma (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The scale parameter of shape [C].
- beta (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The bias parameter of shape [C].
- num_groups (int) – The number of groups to divide the channels into.
- epsilon (float) – A small value added to the denominator for numerical stability.
-
Returns:
-
A normalized tensor with the same shape as
input.
-
Raises:
-
ValueError – If the input tensor has fewer than 2 dimensions.
-
Return type:
hann_window()
max.experimental.functional.hann_window(window_length, *, periodic=True, dtype=None, device=None)
Create a Hann window of the given window_length.
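A minimal sketch (illustrative; dtype and device fall back to the module defaults):
import max.experimental.functional as F
w = F.hann_window(8)                      # periodic window, length 8
w_sym = F.hann_window(8, periodic=False)  # symmetric variant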
in_graph_context()
max.experimental.functional.in_graph_context()
Return True when executing inside a Graph.current context.
-
Return type:
inplace_custom()
max.experimental.functional.inplace_custom(name, device, values, out_types=None, parameters=None, custom_extensions=None)
Apply an in-place custom op with optional custom extension loading.
irfft()
max.experimental.functional.irfft(input_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input_is_complex=False, buffer_size_mb=512)
Compute the inverse real FFT of the input tensor.
-
Parameters:
-
- input_tensor (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue) – The input tensor to compute the inverse real FFT of.
- n (int | None) – The size of the output tensor. Must be an int, and cannot be a symbolic Buffer. The input tensor will be padded or truncated to n // 2 + 1 along the specified axis.
- axis (int) – The axis to compute the inverse real FFT of.
- normalization (Normalization | str) – The normalization to apply to the output tensor. Can be “backward”, “ortho”, or “forward”. When “backward”, the output is divided by n. When “ortho”, the output is divided by sqrt(n). When “forward”, no normalization is applied.
- input_is_complex (bool) – Whether the input tensor is already interleaved complex. The last dimension of the input tensor must be 2, and is excluded from the dimension referred to by axis.
- buffer_size_mb (int) – The estimated size of a persistent buffer to use for storage of intermediate results. Needs to be the same across multiple calls to irfft within the same graph. Otherwise, multiple buffers will be allocated.
-
Returns:
-
The inverse real FFT of the input tensor. The shape of the output tensor is the same as the shape of the input tensor, except for the axis that the inverse real FFT is computed over, which is replaced by n.
is_inf()
max.experimental.functional.is_inf(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
is_nan()
max.experimental.functional.is_nan(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
layer_norm()
max.experimental.functional.layer_norm(input, gamma, beta, epsilon)
Performs layer normalization.
-
Parameters:
-
- input (TensorValue) – The input tensor to normalize.
- gamma (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The gamma parameter of the normalization.
- beta (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The beta parameter of the normalization.
- epsilon (float) – The epsilon parameter of the normalization.
-
Returns:
-
A graph tensor value with the normalization applied.
-
Raises:
-
- ValueError – If gamma size doesn’t match the last dimension of input.
- ValueError – If beta size doesn’t match the last dimension of input.
- ValueError – If epsilon is not positive.
-
Return type:
lazy()
max.experimental.functional.lazy()
Context manager for lazy (deferred) tensor evaluation.
Within this context, tensor operations are recorded but not executed.
Tensors remain unrealized until explicitly awaited via
await tensor.realize or until their values are needed.
Example:
from max.experimental import functional as F
from max.experimental.tensor import Tensor

with F.lazy():
    a = Tensor.zeros([5, 5])
    b = a + 1
# b is unrealized — no compilation has happened yet
await b.realize
-
Return type:
-
Generator[None]
log()
max.experimental.functional.log(x)
Computes the elementwise natural logarithm of a symbolic tensor.
Creates a new op node to compute the elementwise natural logarithm of a symbolic tensor and adds it to the graph, returning the symbolic result. The natural logarithm is used in loss functions, normalization, and probability calculations in machine learning.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create input tensor (positive values only)
x = Tensor.constant([1.0, 2.718, 7.389, 20.0])
# Compute natural logarithm
result = F.log(x)
print(result)
# Output: [0.0, 1.0, 2.0, 2.996...]
# (log(1) = 0, log(e) = 1, log(e^2) = 2)
The natural logarithm function log is defined as the inverse of the
exponential function exp(). In other words, it computes the value y in
the equation x = e^y where e is Euler’s number.
log(x) is undefined for x <= 0 for real numbers. Complex numbers
are currently unsupported.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The symbolic tensor to use as the input to the natural logarithm computation.
-
Returns:
-
A new symbolic tensor value representing the output of the natural logarithm value computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
log1p()
max.experimental.functional.log1p(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
logical_and()
max.experimental.functional.logical_and(lhs, rhs)
logical_not()
max.experimental.functional.logical_not(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
logical_or()
max.experimental.functional.logical_or(lhs, rhs)
logical_xor()
max.experimental.functional.logical_xor(lhs, rhs)
logsoftmax()
max.experimental.functional.logsoftmax(value, axis=-1)
-
Parameters:
-
Return type:
masked_scatter()
max.experimental.functional.masked_scatter(input, mask, updates, out_dim)
Creates a new symbolic tensor by writing updates values into input at the positions where mask is true.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to write elements to.
- mask (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of boolean values to update.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of elements to write to input.
- out_dim (int | str | Dim | integer[Any] | TypedAttr) – The new data-dependent dimension.
-
Returns:
-
A new symbolic tensor representing the result of the masked_scatter operation.
-
Return type:
matmul()
max.experimental.functional.matmul(lhs, rhs)
Computes the matrix multiplication of two tensor graph values.
Performs general matrix multiplication with broadcasting. Matrix multiplication is fundamental to neural networks, used for linear transformations, attention mechanisms, and fully connected layers.
from max.experimental.tensor import Tensor
# Create two 2x2 matrices
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]]) # Shape: (2, 2)
w = Tensor.constant([[5.0, 6.0], [7.0, 8.0]]) # Shape: (2, 2)
# Matrix multiply using @ operator (uses matmul internally)
result = x @ w
print("Matrix multiplication result:")
print(result)
# Output: [[19.0, 22.0],
# [43.0, 50.0]]
# Computed as: result[i,j] = sum(x[i,k] * w[k,j])
# Can also call directly via functional API
import max.experimental.functional as F
result2 = F.matmul(x, w)
# Same result as x @ w
If the lhs is 1D, it will be reshaped to 1xD.
If the rhs is 1D, it will be reshaped to Dx1.
In both cases, the additional 1 dimensions will be removed from the
output shape.
For the multiplication, the innermost (rightmost) 2 dimensions are treated
as a matrix.
The lhs matrix will have the shape MxK.
The rhs matrix will have the shape KxN.
The output will have the shape MxN.
The K dimensions must be equivalent in both matrices.
The remaining outer dimensions will be broadcasted.
-
Parameters:
-
- lhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The left-hand side input tensor.
- rhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The right-hand side input tensor.
-
Returns:
-
A tensor graph value representing the matrix product of
lhs and rhs. For 2D inputs, the output shape is (M, N) where lhs is (M, K) and rhs is (K, N). For higher-dimensional inputs, batch dimensions are preserved and the operation is applied to the last two dimensions of each input.
Return type:
max()
max.experimental.functional.max(x, y=None, /, axis=-1)
Computes the elementwise maximum of x and y when y is given; otherwise reduces x along axis, returning its maximum values.
max_pool2d()
max.experimental.functional.max_pool2d(input, kernel_size, stride=1, dilation=1, padding=0, ceil_mode=False)
Perform a 2D max pooling operation on the input tensor.
Applies a 2D max pooling operation to the input tensor with layout
[N, H, W, C]. The pooling operation slides a window of size
kernel_size over the spatial dimensions and selects the maximum
value within each window.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor with shape [N, H, W, C].
- kernel_size (tuple[int | str | Dim | integer[Any] | TypedAttr, int | str | Dim | integer[Any] | TypedAttr]) – The height and width of the sliding window.
- stride (int | tuple[int, int]) – The stride of the sliding window. Can be a single integer applied to both spatial dimensions or a tuple (stride_h, stride_w). Defaults to 1.
- dilation (int | tuple[int, int]) – The spacing between kernel elements. Can be a single integer or a tuple (dilation_h, dilation_w). Defaults to 1.
- padding (int | tuple[int, int]) – Zero-padding added to both sides of each spatial dimension. Can be a single integer or a tuple (pad_h, pad_w). Defaults to 0.
- ceil_mode (bool) – If True, uses ceil instead of floor when computing the output spatial shape. Defaults to False.
-
Returns:
-
A symbolic tensor with the max pooling applied, with shape
[N, H_out, W_out, C].
-
Return type:
mean()
max.experimental.functional.mean(x, axis=-1)
Computes the mean of the input tensor’s elements along the given axis.
min()
max.experimental.functional.min(x, y=None, /, axis=-1)
Computes the elementwise minimum of x and y when y is given; otherwise reduces x along axis, returning its minimum values.
mod()
max.experimental.functional.mod(lhs, rhs)
mul()
max.experimental.functional.mul(lhs, rhs)
Computes the elementwise product of the two input tensors.
negate()
max.experimental.functional.negate(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
non_maximum_suppression()
max.experimental.functional.non_maximum_suppression(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, out_dim='num_selected')
Filters boxes with high intersection-over-union (IoU).
Applies greedy non-maximum suppression independently per (batch, class) pair. For each pair the algorithm:
- Discards boxes whose score is at or below score_threshold.
- Sorts remaining boxes by score in descending order.
- Greedily selects boxes, suppressing any later candidate whose IoU with an already-selected box exceeds iou_threshold.
- Stops after max_output_boxes_per_class selections per pair.
Boxes use [y1, x1, y2, x2] corner format. Coordinates may be
normalised or absolute; the op handles both.
-
Parameters:
-
- boxes (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Input boxes tensor of shape [batch, num_boxes, 4] (float).
- scores (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Per-class scores of shape [batch, num_classes, num_boxes] (float, same dtype as boxes).
- max_output_boxes_per_class (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Scalar int64 tensor — maximum number of boxes to select per (batch, class) pair.
- iou_threshold (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Scalar float tensor — IoU suppression threshold.
- score_threshold (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Scalar float tensor — minimum score to consider.
- out_dim (str) – Name for the dynamic output dimension (number of selected boxes). Defaults to "num_selected".
-
Returns:
-
An int64 tensor of shape
[out_dim, 3] where each row is [batch_index, class_index, box_index].
Return type:
nonzero()
max.experimental.functional.nonzero(x, out_dim)
Returns the indices of all nonzero elements in a tensor.
Returns a tensor of indices of the nonzero values in the given tensor. The
return value is a 2D tensor of shape [out_dim x rank_in], where
out_dim is the number of nonzero elements in the input tensor, and
rank_in is the rank of the input tensor. Indices are generated in
row-major order.
-
Parameters:
-
Returns:
-
A symbolic tensor of indices
-
Raises:
-
ValueError – If x is scalar, or if x is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
Return type:
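A minimal sketch (illustrative; assumes the eager Tensor API; the out_dim name "nnz" is arbitrary):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[0.0, 1.0], [2.0, 0.0]])
idx = F.nonzero(x, "nnz")
# idx has shape [nnz, 2] with rows [0, 1] and [1, 0], in row-major order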
normal()
max.experimental.functional.normal(shape=(), mean=0.0, std=1.0, *, dtype=None, device=None)
Sample from a Gaussian (normal) distribution with given mean and std.
normal_like()
max.experimental.functional.normal_like(like, mean=0.0, std=1.0)
Sample Gaussian values matching the shape and dtype of like.
-
Parameters:
-
- like (Tensor | TensorType | DistributedTensorType)
- mean (float)
- std (float)
-
Return type:
not_equal()
max.experimental.functional.not_equal(lhs, rhs)
ones()
max.experimental.functional.ones(shape, *, dtype=None, device=None)
Create an all-ones tensor, optionally distributed across devices.
ones_like()
max.experimental.functional.ones_like(like)
Create an all-ones tensor matching the shape and dtype of like.
-
Parameters:
-
like (Tensor | TensorType | DistributedTensorType)
-
Return type:
outer()
max.experimental.functional.outer(lhs, rhs)
Computes the outer product of two symbolic vectors.
-
Parameters:
-
- lhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The left side of the product. Whatever its shape, it will be flattened to a rank-1 vector.
- rhs (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The right side of the product. Whatever its shape, it will be flattened to a rank-1 vector. Must have the same number of elements as lhs.
-
Returns:
-
A symbolic tensor representing the outer product of the two input vectors. It will have rank 2, with the dimension sizes being the number of elements of lhs and rhs respectively.
-
Return type:
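A minimal sketch (illustrative; assumes the eager Tensor API):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
a = Tensor.constant([1.0, 2.0])
b = Tensor.constant([3.0, 4.0, 5.0])
p = F.outer(a, b)  # shape (2, 3): [[3, 4, 5], [6, 8, 10]]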
pad()
max.experimental.functional.pad(input, paddings, mode='constant', value=0)
Pads a tensor along every dimension.
Adds padding to the input tensor using the specified padding values and mode.
-
Parameters:
-
-
input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to pad.
-
paddings (Iterable[int]) – Sequence of padding values. For a tensor with rank N, paddings must contain 2*N non-negative integers in the order
[pad_before_dim0, pad_after_dim0, pad_before_dim1, pad_after_dim1, ...]. -
mode (Literal['constant', 'reflect', 'edge']) –
The padding mode. Supported values:
"constant"- fill padded cells withvalue."reflect"- reflect values about the content-region edges (excludes the boundary element, equivalent tonumpy.padwithmode='reflect')."edge"- repeat the nearest boundary element (equivalent tonumpy.padwithmode='edge').
-
value (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The constant fill value (only used when
mode='constant'). Defaults to 0.
-
-
Returns:
-
A symbolic tensor with the same dtype as input, padded along each dimension according to paddings.
-
Raises:
-
ValueError – If mode is not one of the supported values, or if any padding value is negative.
-
Return type:
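A minimal sketch of the paddings ordering (expected shape noted in comments, not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
# One (before, after) pair per dimension: dim0 gets (1, 0), dim1 gets (0, 1)
y = F.pad(x, [1, 0, 0, 1], mode="constant", value=0)
# Expected shape (3, 3): one new row of zeros on top, one new column on the right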
permute()
max.experimental.functional.permute(x, dims)
Permutes all dimensions of a symbolic tensor.
-
Parameters:
-
Returns:
-
A new symbolic tensor with the dimensions permuted to match the passed in order. It has the same elements and dtype, but the order of the elements is different according to the permutation.
-
Return type:
pow()
max.experimental.functional.pow(lhs, rhs)
prod()
max.experimental.functional.prod(x, axis=-1)
qmatmul()
max.experimental.functional.qmatmul(encoding, config, lhs, *rhs)
Performs matrix multiplication between floating point and quantized tensors.
This quantizes the lhs floating point value to match the encoding of the
rhs quantized value, performs matmul, and then dequantizes the result.
Beware that, compared to a regular matmul op, this one expects the rhs
value to be transposed. For example, if the lhs shape is [32, 64], and
the quantized rhs shape is also [32, 64], then the output shape is
[32, 32].
That is, this function returns the result from:
dequantize(quantize(lhs) @ transpose(rhs))
The last two dimensions in lhs are treated as matrices and multiplied
by rhs (which must be a 2D tensor). Any remaining dimensions in lhs
are broadcast dimensions.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
-
- encoding (QuantizationEncoding) – The quantization encoding to use.
- config (QuantizationConfig | None) – Optional quantization config; required for some encodings (for example, GPTQ).
- lhs (TensorValue) – The non-quantized, left-hand-side of the matmul.
- rhs (TensorValue) – The transposed and quantized right-hand-side tensor(s).
-
Returns:
-
The dequantized result (a floating point tensor).
-
Return type:
range()
max.experimental.functional.range(start, stop, step=1, out_dim=None, *, dtype=None, device=None)
Create a 1-D tensor with values from start to stop (exclusive) by step.
-
Parameters:
-
- start (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- stop (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- step (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- out_dim (int | str | Dim | integer[Any] | TypedAttr | None)
- dtype (DType | None)
- device (Device | DeviceMapping | DeviceRef | None)
-
Return type:
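For illustration, in the style of the other examples on this page:
import max.experimental.functional as F
# Values 0, 2, 4, 6, 8 (stop is exclusive)
x = F.range(0, 10, 2)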
rebind()
max.experimental.functional.rebind(x, shape, message='', layout=None)
Rebinds a symbolic tensor to a specified set of dimensions.
This does not mutate the symbolic tensor passed in, but instead adds a
runtime assert that the input's symbolic shape is equivalent to the
specified shape. For example, if the input tensor shape has
dynamic/unknown sizes, this asserts the fixed sizes that may be required
for a subsequent operation.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to rebind.
- shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The symbolic shape to assert for x, as a list of Dim values.
- message (str) – The message printed if the rebind fails at runtime.
- layout (FilterLayout | None) – A layout of the weights used by some operations like conv.
-
Returns:
-
A symbolic tensor with the same elements and shape as the given tensor, but with the symbolic shape asserted to shape.
-
Return type:
reduce_scatter()
max.experimental.functional.reduce_scatter(t, scatter_axis=0, mesh_axis=0)
Reduce-scatter: Partial → Sharded along mesh_axis.
Decomposed into allreduce + local split until native reduce-scatter
supports sub-group calls in the same graph. A native
ops.reducescatter.sum would move half the bytes.
relu()
max.experimental.functional.relu(x)
Computes the elementwise ReLU (Rectified Linear Unit) of a symbolic tensor.
Creates a new op node to compute the elementwise ReLU of a symbolic tensor
and adds it to the graph, returning the symbolic result. ReLU is defined as
relu(x) = max(0, x), setting all negative values to zero while leaving
positive values unchanged.
ReLU is one of the most common activation functions in neural networks due to its computational efficiency and effectiveness in addressing the vanishing gradient problem.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create input with negative and positive values
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
# Apply ReLU activation
result = F.relu(x)
print(result)
# Output: [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
# Negative values become 0, positive values unchanged
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The symbolic tensor to use as the input to the relu computation.
-
Returns:
-
A new symbolic tensor value representing the output of the relu value computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
repeat_interleave()
max.experimental.functional.repeat_interleave(x, repeats, axis=None, out_dim=None)
Repeats elements of a tensor along the given dimension.
Modeled after torch.repeat_interleave.
For example, given repeats=2 and the following input:
# Input tensor with shape (2, 2)
input = TensorValue(x)  # Contains [[1.0, 2.0], [3.0, 4.0]]
repeat_interleave with axis=0:
# Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
repeat_interleave with axis=1:
# Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
# Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
repeat_interleave with axis=None (the default):
# Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2)  # axis = None
# Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
repeat_interleave with repeats=[2, 3] and axis=0:
repeat_value = TensorValue([2, 3])
# Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor.
- repeats (int | TensorValue) – The number of repetitions for each element.
- axis (int | None) – The dimension along which to repeat values. If axis is not specified or None (the default), flatten the input array and repeat the flattened values.
- out_dim (int | str | Dim | integer[Any] | TypedAttr | None) – Optional symbolic dimension for the output size (for graph validation).
-
Returns:
-
A symbolic tensor with the elements interleaved.
-
Raises:
-
ValueError – If repeats is non-positive or if axis is out of range.
-
Return type:
reshape()
max.experimental.functional.reshape(x, shape)
Reshapes a symbolic tensor.
The number and order of the elements in the tensor is unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same.
If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified dimensions. Its length becomes the number of elements in the original tensor divided by the product of elements of the reshape.
-
Parameters:
-
Returns:
-
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.
-
Raises:
-
ValueError – If the input and target shapes’ numbers of elements mismatch.
-
Return type:
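A minimal sketch of -1 inference, in the style of the other examples on this page:
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # 6 elements
a = F.reshape(x, (2, 3))   # shape (2, 3)
b = F.reshape(x, (3, -1))  # -1 is inferred as 6 / 3 = 2, giving shape (3, 2)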
resize()
max.experimental.functional.resize(input, shape, interpolation=InterpolationMode.BILINEAR)
Resize the input tensor to the given shape.
This function resizes a tensor using the specified interpolation method. The tensor is expected to have NCHW format (batch, channels, height, width).
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to resize. Must have rank 4 in NCHW format.
- shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape of length 4 corresponding to (N, C, H, W).
- interpolation (InterpolationMode) – Desired interpolation enum defined by InterpolationMode. Defaults to InterpolationMode.BILINEAR.
-
Returns:
-
A resized tensor with the shape specified by the shape argument.
-
Raises:
-
ValueError – If the input doesn’t have rank 4, shape has wrong number of elements, or unsupported interpolation mode is specified.
-
Return type:
resize_bicubic()
max.experimental.functional.resize_bicubic(input, size)
Resize a tensor using bicubic interpolation.
Produces an output tensor whose dimensions are given by size using
a 4x4-pixel Catmull-Rom (a=-0.75) cubic convolution filter with
half_pixel coordinate mapping. Input must be rank-4 NCHW.
-
Parameters:
-
Returns:
-
A new symbolic tensor with shape size and the same dtype as input.
-
Raises:
-
ValueError – If input doesn’t have rank 4 or size has a different length.
-
Return type:
resize_linear()
max.experimental.functional.resize_linear(input, size, coordinate_transform_mode=0, antialias=False)
Resize a tensor using linear (bilinear) interpolation.
Produces an output tensor whose spatial dimensions are given by size
using separable 1-D linear filters. The operation maps output coordinates
back to input coordinates according to coordinate_transform_mode.
-
Parameters:
-
-
input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to resize.
-
size (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape. Must have the same rank as input.
-
coordinate_transform_mode (int) –
How to map an output coordinate to an input coordinate. Allowed values:
- 0 – half_pixel (default): shifts by 0.5 before scaling, consistent with most deep-learning frameworks.
- 1 – align_corners: aligns the corner pixels of input and output so that the first and last coordinates are preserved exactly.
- 2 – asymmetric: no shift; equivalent to floor-dividing coordinates by the scale factor.
- 3 – half_pixel_1D: like half_pixel but only applied to the last spatial dimension.
-
antialias (bool) – When True, applies an antialiasing filter when the output is smaller than the input (i.e. when downscaling), which reduces aliasing artifacts by widening the tent filter support by 1 / scale. Has no effect when upscaling.
-
-
Returns:
-
A new symbolic tensor with shape size and the same dtype as input.
-
Raises:
-
ValueError – If coordinate_transform_mode is not 0-3, or if size has a different rank than input.
-
Return type:
resize_nearest()
max.experimental.functional.resize_nearest(input, size, coordinate_transform_mode=0, round_mode=0)
Resize a tensor using nearest-neighbor interpolation.
Produces an output tensor whose dimensions are given by size by
selecting the nearest input sample for each output coordinate.
-
Parameters:
-
-
input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to resize.
-
size (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – Desired output shape. Must have the same rank as input.
-
coordinate_transform_mode (int) –
How to map an output coordinate to an input coordinate. Allowed values:
- 0 – half_pixel (default).
- 1 – align_corners.
- 2 – asymmetric.
- 3 – half_pixel_1D.
-
round_mode (int) –
How to round the mapped coordinate to select the nearest input sample. Allowed values:
- 0 – HalfDown (default): ceil(x - 0.5).
- 1 – HalfUp: floor(x + 0.5).
- 2 – Floor: floor(x).
- 3 – Ceil: ceil(x).
-
-
Returns:
-
A new symbolic tensor with shape size and the same dtype as input.
-
Raises:
-
ValueError – If coordinate_transform_mode is not 0-3, round_mode is not 0-3, or size has a different rank than input.
-
Return type:
rms_norm()
max.experimental.functional.rms_norm(input, weight, epsilon, weight_offset=0.0, multiply_before_cast=False)
Performs Root Mean Square layer normalization.
Computes output = input / rms(input) * weight where
rms(x) = sqrt(mean(x^2) + epsilon).
When multiply_before_cast is False (Llama-style), the input is
cast to the output dtype before multiplication by the weight. When
True (Gemma-style), the multiplication is performed before the cast.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to normalize.
- weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The weight tensor whose shape must match the last dimension of input.
- epsilon (float) – A small value added to the denominator for numerical stability.
- weight_offset (float) – A value added to the weight before normalization. Typically 1 for Gemma-like normalization and 0 otherwise.
- multiply_before_cast (bool) – Whether to multiply before casting to the output dtype.
-
Returns:
-
A normalized tensor with the same shape and dtype as input.
-
Raises:
-
ValueError – If weight shape doesn’t match the last dimension of input.
-
Return type:
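A minimal sketch under the formula above (hedged: the weight here is simply all ones, so the result is just each row divided by its RMS):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[1.0, 2.0, 3.0, 4.0]])
weight = Tensor.constant([1.0, 1.0, 1.0, 1.0])  # must match the last dimension of x
y = F.rms_norm(x, weight, epsilon=1e-6)
# Each row is divided by sqrt(mean(row**2) + epsilon), then scaled by weight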
roi_align()
max.experimental.functional.roi_align(input, rois, output_height, output_width, spatial_scale=1.0, sampling_ratio=0.0, aligned=False, mode='AVG')
Perform ROI Align pooling on the input tensor.
Extracts fixed-size feature maps from regions of interest (ROIs) in the input tensor using bilinear interpolation. The input is expected in NHWC layout.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor with shape [N, H, W, C].
- rois (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – Regions of interest with shape [M, 5], where each row is [batch_index, x1, y1, x2, y2].
- output_height (int) – Height of each output feature map.
- output_width (int) – Width of each output feature map.
- spatial_scale (float) – Multiplicative factor mapping ROI coordinates to input spatial coordinates. Defaults to 1.0.
- sampling_ratio (float) – Number of sampling points per bin in each direction. 0 means adaptive (ceil(bin_size)). Defaults to 0.0.
- aligned (bool) – If True, applies a half-pixel offset to ROI coordinates for more precise alignment. Defaults to False.
- mode (str) – Pooling mode, either "AVG" or "MAX". Defaults to "AVG".
-
Returns:
-
A symbolic tensor with shape [M, output_height, output_width, C].
-
Raises:
-
ValueError – If input is not rank 4, rois is not rank 2 with 5 columns, or mode is invalid.
-
Return type:
round()
max.experimental.functional.round(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
rsqrt()
max.experimental.functional.rsqrt(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
scatter()
max.experimental.functional.scatter(input, updates, indices, axis=-1)
Creates a new symbolic tensor where the updates are written to input according to indices.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to write elements to.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of elements to write to input.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The positions in input to update.
- axis (int) – The axis along which indices indexes into.
-
Returns:
-
A new symbolic tensor representing the result of the scatter operation.
-
Raises:
-
ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
scatter_add()
max.experimental.functional.scatter_add(input, updates, indices, axis=-1)
Creates a new symbolic tensor by accumulating updates into input at indices.
Produces an output tensor by scattering elements from updates into input
according to indices, summing values at duplicate indices. For a 2-D
input with axis=0 the update rule is:
output[indices[i][j]][j] += updates[i][j]
and with axis=1:
output[i][indices[i][j]] += updates[i][j]
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to accumulate into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to add.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The positions in input to update.
- axis (int) – The axis along which indices indexes into.
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Raises:
-
ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
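A minimal 1-D sketch of duplicate-index accumulation (hedged: assumes integer constants produce an int dtype accepted for indices; expected values shown in comments, not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
input = Tensor.constant([0.0, 0.0, 0.0, 0.0])
updates = Tensor.constant([1.0, 2.0, 3.0])
indices = Tensor.constant([0, 2, 0])
result = F.scatter_add(input, updates, indices, axis=0)
# Duplicate index 0 accumulates 1.0 + 3.0: expected [4.0, 0.0, 2.0, 0.0]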
scatter_max()
max.experimental.functional.scatter_max(input, updates, indices, axis=-1)
Creates a new symbolic tensor by scattering the maximum of updates into input.
Produces an output tensor by scattering elements from updates into input
according to indices, keeping the maximum at duplicate indices. For a 2-D
input with axis=0 the update rule is:
output[indices[i][j]][j] = max(output[indices[i][j]][j], updates[i][j])
and with axis=1:
output[i][indices[i][j]] = max(output[i][indices[i][j]], updates[i][j])
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to compare.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The positions in input to update.
- axis (int) – The axis along which indices indexes into.
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Raises:
-
ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
scatter_min()
max.experimental.functional.scatter_min(input, updates, indices, axis=-1)
Creates a new symbolic tensor by scattering the minimum of updates into input.
Produces an output tensor by scattering elements from updates into input
according to indices, keeping the minimum at duplicate indices. For a 2-D
input with axis=0 the update rule is:
output[indices[i][j]][j] = min(output[indices[i][j]][j], updates[i][j])
and with axis=1:
output[i][indices[i][j]] = min(output[i][indices[i][j]], updates[i][j])
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to compare.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The positions in input to update.
- axis (int) – The axis along which indices indexes into.
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Raises:
-
ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
scatter_mul()
max.experimental.functional.scatter_mul(input, updates, indices, axis=-1)
Creates a new symbolic tensor by scattering the product of updates into input.
Produces an output tensor by scattering elements from updates into input
according to indices, multiplying values at duplicate indices. For a 2-D
input with axis=0 the update rule is:
output[indices[i][j]][j] *= updates[i][j]
and with axis=1:
output[i][indices[i][j]] *= updates[i][j]
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to multiply.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The positions in input to update.
- axis (int) – The axis along which indices indexes into.
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Raises:
-
ValueError – If axis is out of range, if dtypes mismatch, if indices dtype is not int32/int64, or if any input is on a non-CPU device and strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
scatter_nd()
max.experimental.functional.scatter_nd(input, updates, indices)
Creates a new symbolic tensor where the updates are scattered into input at specified indices.
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to write elements to.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of elements to write to input.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A tensor of indices specifying where to write updates. Shape should be [num_updates, rank] for full indexing or [num_updates, k] for partial indexing where k < rank.
-
Returns:
-
A new symbolic tensor representing the result of the scatter_nd operation.
-
Return type:
scatter_nd_add()
max.experimental.functional.scatter_nd_add(input, updates, indices)
Creates a new symbolic tensor by accumulating updates into input at N-D indices.
Produces an output tensor by scattering slices from updates into a copy
of input according to N-dimensional index vectors, summing values at
duplicate index positions. Each index vector is the last dimension of
indices and selects a slice (or scalar) in input.
Example for input.shape = [4, 2], indices.shape = [3, 1]
(1-D partial indexing, writes whole rows):
output[indices[i, 0], :] += updates[i, :]
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to accumulate into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to add.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An index tensor whose last dimension is the index vector length k (k <= input.rank).
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Return type:
scatter_nd_max()
max.experimental.functional.scatter_nd_max(input, updates, indices)
Creates a new symbolic tensor by scattering the maximum of updates into input at N-D indices.
Produces an output tensor by scattering slices from updates into a copy
of input according to N-dimensional index vectors, keeping the maximum
at duplicate index positions. Each index vector is the last dimension of
indices and selects a slice (or scalar) in input.
Example for input.shape = [4, 2], indices.shape = [3, 1]
(1-D partial indexing, writes whole rows):
output[indices[i, 0], :] = max(output[indices[i, 0], :], updates[i, :])
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to compare.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An index tensor whose last dimension is the index vector length k (k <= input.rank).
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Return type:
scatter_nd_min()
max.experimental.functional.scatter_nd_min(input, updates, indices)
Creates a new symbolic tensor by scattering the minimum of updates into input at N-D indices.
Produces an output tensor by scattering slices from updates into a copy
of input according to N-dimensional index vectors, keeping the minimum
at duplicate index positions. Each index vector is the last dimension of
indices and selects a slice (or scalar) in input.
Example for input.shape = [4, 2], indices.shape = [3, 1]
(1-D partial indexing, writes whole rows):
output[indices[i, 0], :] = min(output[indices[i, 0], :], updates[i, :])
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to compare.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An index tensor whose last dimension is the index vector length k (k <= input.rank).
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Return type:
scatter_nd_mul()
max.experimental.functional.scatter_nd_mul(input, updates, indices)
Creates a new symbolic tensor by scattering the product of updates into input at N-D indices.
Produces an output tensor by scattering slices from updates into a copy
of input according to N-dimensional index vectors, multiplying values
at duplicate index positions. Each index vector is the last dimension of
indices and selects a slice (or scalar) in input.
Example for input.shape = [4, 2], indices.shape = [3, 1]
(1-D partial indexing, writes whole rows):
output[indices[i, 0], :] *= updates[i, :]
-
Parameters:
-
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to scatter into.
- updates (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – A symbolic tensor of values to multiply.
- indices (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – An index tensor whose last dimension is the index vector length k (k <= input.rank).
-
Returns:
-
A new symbolic tensor with the same shape and dtype as input.
-
Return type:
sigmoid()
max.experimental.functional.sigmoid(x)
Computes the elementwise sigmoid activation of a symbolic tensor.
Creates a new op node to compute the elementwise sigmoid of a symbolic
tensor and adds it to the graph, returning the symbolic result. Sigmoid
is defined as sigmoid(x) = 1 / (1 + exp(-x)), mapping all input values
to the range (0, 1).
The sigmoid function is commonly used for binary classification tasks and as an activation function in neural networks, particularly in output layers for probability prediction.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
# Apply sigmoid activation
result = F.sigmoid(x)
print(result)
# Output: [[0.119, 0.269, 0.5], [0.731, 0.881, 0.953]]
# All values mapped to range (0, 1)
-
Parameters:
-
x (TensorValue) – The symbolic tensor to use as the input to the sigmoid computation.
-
Returns:
-
A new symbolic tensor value representing the output of the sigmoid computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
silu()
max.experimental.functional.silu(x)
Computes the elementwise silu of a symbolic tensor.
Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result.
silu is defined as silu(x) = x * sigmoid(x).
-
Parameters:
-
x (TensorValue) – The symbolic tensor to use as the input to the silu computation.
-
Returns:
-
A new symbolic tensor value representing the output of the silu computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
sin()
max.experimental.functional.sin(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
slice_tensor()
max.experimental.functional.slice_tensor(x, indices)
Slices out a subtensor view of the input tensor based on indices.
The semantics of slice_tensor() follow NumPy slicing semantics with the
following restrictions:
- Slice indices must not index out of [-dim - 1, dim - 1] for negative step, or [-dim, dim] for positive step.
# Reverse a tensor.
slice_tensor(x, [slice(None, None, -1)])
# Unsqueeze the second last dimension of a tensor.
slice_tensor(x, [..., None, slice(None)])
-
Returns:
-
The sliced subtensor of x.
-
Parameters:
-
- x (TensorValue)
- indices (SliceIndices)
-
Return type:
softmax()
max.experimental.functional.softmax(value, axis=-1)
-
Parameters:
-
Return type:
split()
max.experimental.functional.split(x, split_size_or_sections, axis=0)
Split a tensor into chunks along an axis.
When split_size_or_sections is an int, splits into equal chunks
(last chunk may be smaller). When it is a list of ints, splits
into chunks with exactly those sizes.
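For illustration, in the style of the other examples on this page:
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([1.0, 2.0, 3.0, 4.0, 5.0])
# Equal chunks of size 2; the last chunk is smaller: [1, 2], [3, 4], [5]
chunks = F.split(x, 2)
# Explicit sizes: [1, 2, 3] and [4, 5]
parts = F.split(x, [3, 2])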
spmd_dispatch()
max.experimental.functional.spmd_dispatch(graph_op, args, output_mappings)
Per-shard graph op dispatch for distributed tensors.
Runs graph_op once per shard, extracting per-shard TensorValues
from each distributed Tensor arg. Reassembles the per-shard results
into distributed Tensors.
Also used by custom op dispatch (_custom_dispatch.py).
sqrt()
max.experimental.functional.sqrt(x)
Computes the elementwise square root of a symbolic tensor.
Creates a new op node to compute the elementwise square root of a symbolic tensor and adds it to the graph, returning the symbolic result. Square root is commonly used in normalization operations, distance calculations, and implementing mathematical operations like standard deviation.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create tensor with positive values
x = Tensor.constant([1.0, 4.0, 9.0, 16.0])
# Compute square root
result = F.sqrt(x)
print(result)
# Output: [1.0, 2.0, 3.0, 4.0]
# Note: sqrt requires non-negative values
# For tensors with negative values, use abs first:
y = Tensor.constant([1.0, -4.0, 9.0, -16.0])
result2 = F.sqrt(F.abs(y))
print(result2)
# Output: [1.0, 2.0, 3.0, 4.0]
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The symbolic tensor to use as the input to the sqrt computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
-
A new symbolic tensor value representing the output of the sqrt value computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
squeeze()
max.experimental.functional.squeeze(x, axis)
Removes a size-1 dimension from a symbolic tensor.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to squeeze.
- axis (int) – The dimension to remove from the input’s shape. If negative, this indexes from the end of the tensor. For example, squeeze(v, -1) squeezes the last dimension.
-
Returns:
-
A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.
-
Return type:
stack()
max.experimental.functional.stack(values, axis=0)
Stacks a list of tensors along a new axis.
-
Parameters:
-
- values (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray]) – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension.
- axis (int) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape plus 1. For instance, stack(vs, -1) will create and stack along a new axis as the last dimension, and stack(vs, -2) will create and stack along a new dimension which is inserted immediately before the last dimension.
-
Returns:
-
A new symbolic tensor representing the result of the stack. It will have rank n+1 where n is the rank of each input tensor. Its size on each dimension other than axis will be the same as each input tensor’s, with the new axis inserted. Along the new dimension it will have size len(values).
-
Return type:
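A minimal sketch of the new-axis placement (expected shapes noted in comments, not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
a = Tensor.constant([1.0, 2.0])
b = Tensor.constant([3.0, 4.0])
# New leading axis: shape (2, 2), rows are a then b
s0 = F.stack([a, b], axis=0)
# New trailing axis: shape (2, 2), column 0 is a and column 1 is b
s1 = F.stack([a, b], axis=-1)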
sub()
max.experimental.functional.sub(lhs, rhs)
sum()
max.experimental.functional.sum(x, axis=-1)
tanh()
max.experimental.functional.tanh(x)
Computes the elementwise tanh (hyperbolic tangent) of a symbolic tensor.
Creates a new op node to compute the elementwise tanh of a symbolic tensor
and adds it to the graph, returning the symbolic result. Tanh is defined as
tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)), mapping all input
values to the range (-1, 1).
The tanh function is commonly used as an activation function in recurrent neural networks (RNNs) and as a hidden layer activation in feedforward networks. Unlike sigmoid which maps to (0, 1), tanh is zero-centered, which can help with gradient flow during training.
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Create input tensor
x = Tensor.constant([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]])
# Apply tanh activation
result = F.tanh(x)
print(result)
# Output: [[-0.964, -0.762, 0.0], [0.762, 0.964, 0.995]]
# All values mapped to range (-1, 1)
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The symbolic tensor to use as the input to the tanh computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
-
A new symbolic tensor value representing the output of the tanh value computation.
-
Raises:
-
Error – If the symbol doesn’t represent a tensor value.
-
Return type:
tensor_to_layout()
max.experimental.functional.tensor_to_layout(t)
Convert a Tensor to a TensorLayout for sharding rule evaluation.
-
Parameters:
-
t (Tensor)
-
Return type:
-
TensorLayout
tile()
max.experimental.functional.tile(x, repeats)
Returns a new tensor by tiling the input along each dimension.
The input is copied N_i times on the i-th dimension, where
N_i = repeats[i]. The i-th dimension of the output shape is the
i-th dimension of the input shape multiplied by N_i.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to tile.
- repeats (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – An iterable of repeat counts, one per dimension of x. All values must be positive. The length must equal the rank of x.
-
Returns:
-
A symbolic tensor whose i-th dimension size equals x.shape[i] * repeats[i].
-
Raises:
-
ValueError – If the length of repeats does not match the rank of x, or if any repeat value is not positive. Also raised for GPU inputs when strict_device_placement=DevicePlacementPolicy.Error.
-
Return type:
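For illustration, in the style of the other examples on this page:
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
y = F.tile(x, [2, 3])
# Expected shape (4, 6): the input repeated 2x along dim 0 and 3x along dim 1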
top_k()
max.experimental.functional.top_k(input, k, axis=-1)
Returns tensor with only top K values along given axis.
-
Parameters:
-
Returns:
-
Top K values, Top K indices
-
Return type:
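A minimal sketch (hedged: assumes the two results unpack as a (values, indices) pair, per the Returns note above; expected values shown in comments, not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([1.0, 5.0, 3.0, 2.0])
values, indices = F.top_k(x, k=2)
# Expected values [5.0, 3.0] and indices [1, 2]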
transfer_to()
max.experimental.functional.transfer_to(t, target)
Move or transfer a tensor to a target device or mapping.
This is the single entry point for ALL placement transitions:
- Non-distributed → distributed: scatters across the target mesh.
- Same mesh: uses collectives (allreduce, allgather, reduce-scatter, local split) to transition between placements.
- Cross mesh: gathers to Replicated, transfers to target mesh, then scatters.
- Any → single device: pass a
Device.
transpose()
max.experimental.functional.transpose(x, axis_1, axis_2)
Transposes two axes of a symbolic tensor.
For more information, see transpose().
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to transpose.
- axis_1 (int) – One of the two axes to transpose. If negative, this indexes from the end of the tensor. For example, transpose(v, -1, -2) transposes the last two axes.
- axis_2 (int) – The other axis to transpose. May also be negative to index from the end of the tensor.
-
Returns:
-
A new symbolic tensor with the two specified axes transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.
-
Return type:
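For illustration, using a creation op from earlier on this page:
import max.experimental.functional as F
x = F.ones((2, 3, 4))
y = F.transpose(x, -1, -2)  # swap the last two axes
# Expected shape (2, 4, 3)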
trunc()
max.experimental.functional.trunc(x)
-
Parameters:
-
x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
-
Return type:
uniform()
max.experimental.functional.uniform(shape=(), range=(0, 1), *, dtype=None, device=None)
Sample from a uniform distribution over [range[0], range[1]).
uniform_like()
max.experimental.functional.uniform_like(like, range=(0, 1))
Sample uniform values matching the shape and dtype of like.
unsqueeze()
max.experimental.functional.unsqueeze(x, axis)
Inserts a size-1 dimension into a symbolic tensor.
-
Parameters:
-
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input symbolic tensor to unsqueeze.
- axis (int) – The index at which to insert a new dimension into the input’s shape. Elements at that index or higher are shifted back. If negative, it indexes relative to 1 plus the rank of the tensor. For example, unsqueeze(v, -1) adds a new dimension at the end, and unsqueeze(v, -2) inserts the dimension immediately before the last dimension.
-
Returns:
-
A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result’s shape at the axis dimension is a static dimension of size 1.
-
Return type:
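A minimal sketch showing squeeze and unsqueeze together (expected shapes noted in comments):
import max.experimental.functional as F
x = F.ones((3, 1, 4))
a = F.squeeze(x, 1)      # shape (3, 4): the size-1 dimension is removed
b = F.unsqueeze(a, -1)   # shape (3, 4, 1): a new size-1 dimension at the end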
where()
max.experimental.functional.where(condition, x, y)
Returns element-wise condition ? x : y for input tensors condition, x, and y.
-
Parameters:
-
- condition (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The condition tensor to use for selecting elementwise values. This tensor must have a boolean dtype.
- x (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – If the condition is true at a position, the value from the same position in this tensor will be selected.
- y (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – If the condition is false at a position, the value from the same position in this tensor will be selected.
-
Returns:
-
A new symbolic tensor holding values from either x or y, based on the elements in condition.
-
Return type:
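A minimal sketch (hedged: the mask here is built with not_equal, documented above, and assumes scalar broadcasting; expected values shown in comments, not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
x = Tensor.constant([1.0, 2.0, 3.0])
y = Tensor.constant([10.0, 20.0, 30.0])
cond = F.not_equal(x, 2.0)  # boolean mask [True, False, True]
result = F.where(cond, x, y)
# Expected [1.0, 20.0, 3.0]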
while_loop()
max.experimental.functional.while_loop(initial_values, predicate, body)
Wrap predicate/body so Tensor returns are unwrapped to TensorValue.
ops.while_loop expects predicate and body functions that return
TensorValue, but our high-level API lets users return
Tensor. This wrapper calls __tensorvalue__() on each
returned Tensor before passing results to the graph op.
-
Parameters:
-
- initial_values (Iterable[Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray)
- predicate (Callable[[...], Tensor])
- body (Callable[[...], Tensor | list[Tensor]])
-
Return type:
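A minimal counter sketch (hedged: assumes a scalar Tensor.constant and that F.add and F.not_equal, shown elsewhere on this page, are usable as the body and predicate; not verified output):
import max.experimental.functional as F
from max.experimental.tensor import Tensor
# Loop-carried value starts at 0 and increments until it reaches 10
i = Tensor.constant(0)
final = F.while_loop(
    i,
    lambda i: F.not_equal(i, 10),  # predicate: keep looping while i != 10
    lambda i: F.add(i, 1),         # body: increment the counter
)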
zeros()
max.experimental.functional.zeros(shape, *, dtype=None, device=None)
Create an all-zeros tensor, optionally distributed across devices.
zeros_like()
max.experimental.functional.zeros_like(like)
Create an all-zeros tensor matching the shape and dtype of like.
-
Parameters:
-
like (Tensor | TensorType | DistributedTensorType)
-
Return type: