Python module

tensor

Provides experimental tensor operations with eager execution capabilities.

This module provides the Tensor class, which supports eager execution of tensor operations and complements the graph-based execution model provided by graph. Tensor operations are compiled and executed automatically by the MAX runtime.

Tensor is designed to be a high-performance NumPy replacement for programming accelerators. It implements numerics through high-performance Mojo kernels JIT-compiled for the hardware available, including state-of-the-art optimizations like automatic kernel fusion. It is intended as a building block for programming large, heterogeneous clusters of accelerators.

Key Features:

  • Eager semantics: Operations give immediate results for quick iteration and feedback.
  • High performance: All operations use high-performance Mojo implementations compiled specifically for the available hardware.
  • Automatic compilation: Tensors are compiled and optimized automatically. Operations may be easily fused into larger graphs to take advantage of the graph compiler’s automatic fusions.
  • Lazy evaluation: Tensors may be computed lazily until their values are needed.
  • NumPy compatibility: Supports common NumPy-like operations and indexing.
  • Device management: Tensors can be allocated on a chosen device and moved between CPU and accelerators.

Create and manipulate tensors with automatic compilation and optimization:

from max.experimental import Tensor
from max.driver import CPU
from max.dtype import DType

x = Tensor.ones((2, 3), dtype=DType.float32, device=CPU())
y = Tensor.zeros_like(x)
result = x + y  # Eager execution with automatic compilation

Operations may be combined into a single execution graph to take advantage of automatic kernel fusion:

from max.experimental import functional as F

@F.functional
def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
    return x @ weight.T + bias

# Create and operate on tensors
x = Tensor.ones([2, 3])
weight = Tensor.ones([6, 3])
bias = Tensor.ones([6])

# Eager execution with a single fused graph
result = linear(x, weight, bias)

Users may opt in to lazy execution. This is primarily useful for:

  1. Operations which may never execute, for instance creating modules with randomly initialized weights before loading weights
  2. Combining many operations into a single execution

from max.nn.module_v3 import Linear

with F.lazy():
    model = Linear(2, 3)

print(model)  # Lazy weights not initialized

# Load pretrained weights
weights = {
    "weight": Tensor.zeros([3, 2]),
    "bias": Tensor.zeros([3]),
}
model.load_state_dict(weights)

# Or compile directly without ever initializing weights
from max.graph import TensorType
input_type = TensorType(DType.float32, ["batch", 2], CPU())
model = model.compile(input_type, weights=weights)

RealizationContext

class max.experimental.tensor.RealizationContext(*args, **kwargs)

Implements a way to realize unrealized tensors.

Most users should never have to think about the existence of this type. It exists to facilitate optimizations around where and when tensor operations are executed.

  • Each tensor is either real or associated with a RealizationContext.
  • If a tensor is not real, i.e. “unrealized”, then it is backed by some symbolic computation.
  • The RealizationContext is responsible for tracking this symbolic computation and “realizing” the tensor (executing the computation and backing the tensor with real data) if and when it is asked to do so.
  • A RealizationContext can only realize tensors associated with it.

RealizationContext abstracts over various semantics of tensor construction.

“Eager” execution: tensors are realized as soon as the realization context exits. This is the default behavior.

This has a concrete advantage over eagerly executing one operation at a time: by controlling where the eager context starts and ends, advanced users can set fine-grained bounds for automatic fusion.

In practice the easiest way to do this is to mark a function as F.functional. This function is then assumed to be “atomic” for the purposes of eager execution. All ops within the function execute as part of the same graph, meaning the compiler is free to fuse operations and generate fused kernels within this region.

“Lazy” execution: tensors are realized only when code later tries to use them.

This enables a class of interface design common in the ML world, in which layers are constructed with randomized weights which are never used. Lazy execution neatly allows constructing entire models, only performing the weight initialization and allocating memory for them if and when those weights are actually used.
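
The deferred-initialization pattern described above can be modelled without MAX at all. The following is a toy pure-Python sketch, not the MAX implementation: LazyValue, its realize() method, and init_weights are hypothetical names that stand in for an unrealized tensor whose backing computation runs only on first use.

```python
from typing import Callable, Optional


class LazyValue:
    """Toy stand-in for an unrealized tensor (hypothetical, not a MAX type)."""

    def __init__(self, compute: Callable[[], list]) -> None:
        self._compute = compute            # deferred computation backing the value
        self._data: Optional[list] = None  # concrete data, filled in on realization

    @property
    def real(self) -> bool:
        # True once the deferred computation has produced concrete data.
        return self._data is not None

    def realize(self) -> list:
        # Run the deferred computation at most once, then reuse the result.
        if self._data is None:
            self._data = self._compute()
        return self._data


calls = []


def init_weights() -> list:
    calls.append("init")  # record that initialization actually ran
    return [0.0] * 6


weights = LazyValue(init_weights)
was_real_before = weights.real    # nothing has been computed yet
values = weights.realize()        # first use triggers the computation
values_again = weights.realize()  # cached: init_weights is not called again
```

If realize() is never called, init_weights never runs and no data is allocated, which is the property the lazy mode relies on when weights are replaced before use.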

Graph compilation: tensors may never be realized.

This allows tensor operations to be composed with direct usage of the Graph API, for instance Module.compile, or using F.* operations in another Graph API usage.

Async execution: Tensors are realized as async functions, allowing clean integration in async systems like web services.

add_source()

add_source(tensor)

Adds a realized tensor as a “source” of the realization state, i.e. one on whose values unrealized tensors depend.

Parameters:

tensor (Tensor) – The realized tensor to add as a source to the computation.

Returns:

A realization state for the tensor. This may be used to compute downstream unrealized values. If it is used in any mutating operations, it should be assigned to tensor.state to mark the tensor as having been mutated.

Return type:

RealizationState

create_unrealized()

create_unrealized(value)

Registers an unrealized graph value with the realization context and returns it as an unrealized tensor.

Parameters:

value (BufferValue | TensorValue) – The graph value representing the result of a computation.

Returns:

A new tensor associated with the unrealized value.

Return type:

Tensor

graph

graph: Graph

The graph used by the realization context.

realize_all()

async realize_all()

Realizes all unrealized tensors associated with this context.

Return type:

list[Tensor]

RealizationState

class max.experimental.tensor.RealizationState(value, ctx)

State for an unrealized tensor.

See RealizationContext.

ctx

ctx: RealizationContext

The realization context used to create this tensor. This context is responsible for realizing the tensor to a real value.

value

value: BufferValue | TensorValue

The symbolic value representing the computation backing this tensor.

Tensor

class max.experimental.tensor.Tensor(*, storage=None, state=None)

A multi-dimensional array with eager execution and automatic compilation.

The Tensor class provides a high-level interface for numerical computations with automatic compilation and optimization via the MAX runtime. Operations on tensors execute eagerly while benefiting from lazy evaluation and graph-based optimizations behind the scenes.

Key Features:

  • Eager execution: Operations execute immediately with automatic compilation.
  • Lazy evaluation: Computation may be deferred until results are needed.
  • High performance: Uses the Mojo compiler and optimized kernels.
  • NumPy-like API: Supports familiar array operations and indexing.
  • Device flexibility: Works seamlessly across CPU and accelerators.

Creating Tensors:

Create tensors using factory methods like ones(), zeros(), constant(), arange(), or from other array libraries via from_dlpack().

from max.experimental import tensor
from max.dtype import DType

# Create tensors with factory methods
x = tensor.Tensor.ones((2, 3), dtype=DType.float32)
y = tensor.Tensor.zeros((2, 3), dtype=DType.float32)

# Perform operations
result = x + y  # Eager execution with automatic compilation

# Access values
print(result.shape)  # (2, 3)
print(result.dtype)  # DType.float32

Implementation Notes:

Tensors use lazy evaluation internally, so they don’t always hold concrete data in memory. A tensor may be “unrealized” (not yet computed) until its value is actually needed (e.g., when converting to other formats or calling item()). This allows the runtime to optimize sequences of operations efficiently.

Operations on tensors build a computation graph behind the scenes, which is compiled and executed when needed. All illegal operations fail immediately with clear error messages, ensuring a smooth development experience.

Interoperability:

Tensors support the DLPack protocol for zero-copy data exchange with NumPy, PyTorch, JAX, and other array libraries. Use from_dlpack() to import arrays and standard DLPack conversion for export.

T

property T: Tensor

Gets the transposed tensor.

Returns a tensor with the last two dimensions transposed. This is equivalent to calling transpose(-1, -2) and is commonly used for matrix operations.

Returns:

A tensor with the last two dimensions swapped.

Return type:

Tensor

arange()

classmethod arange(start=0, stop=None, step=1, *, dtype=None, device=None)

Creates a tensor with evenly spaced values within a given interval.

Returns a new 1D tensor containing a sequence of values starting from start (inclusive) and ending before stop (exclusive), with values spaced by step. This is similar to Python’s built-in range() function and NumPy’s arange().

from max.experimental import tensor
from max.dtype import DType

# Create a range from 0 to 10 (exclusive)
x = tensor.Tensor.arange(10)
# Result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Create a range from 5 to 15 with step 2
y = tensor.Tensor.arange(5, 15, 2)
# Result: [5, 7, 9, 11, 13]

# Use a specific dtype
z = tensor.Tensor.arange(0, 5, dtype=DType.float32)
# Result: [0.0, 1.0, 2.0, 3.0, 4.0]

# Create a range with float step (like numpy/pytorch)
w = tensor.Tensor.arange(0.0, 1.0, 0.2, dtype=DType.float32)
# Result: [0.0, 0.2, 0.4, 0.6, 0.8]

# Create a descending range with negative step
v = tensor.Tensor.arange(5, 0, -1, dtype=DType.float32)
# Result: [5.0, 4.0, 3.0, 2.0, 1.0]
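
The length of the resulting sequence follows from start, stop, and step. The sketch below is plain Python and assumes the same half-open convention as Python's range() and NumPy's arange(), which this method is described as resembling; arange_length is a hypothetical helper for illustration, not part of the MAX API.

```python
import math


def arange_length(start, stop=None, step=1):
    """Number of values in the half-open interval [start, stop) spaced by step."""
    if stop is None:
        # Single-argument form: the argument is the stop value, start is 0.
        start, stop = 0, start
    if step == 0:
        raise ValueError("step must be non-zero")
    # An empty interval (or a step pointing the wrong way) yields no values.
    return max(0, math.ceil((stop - start) / step))


n_simple = arange_length(10)       # values 0..9
n_stepped = arange_length(5, 15, 2)  # values 5, 7, 9, 11, 13
n_down = arange_length(5, 0, -1)   # values 5, 4, 3, 2, 1
n_empty = arange_length(5, 0, 1)   # positive step cannot reach a smaller stop
```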

Parameters:

  • start (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The starting value of the sequence. If stop is not provided, this becomes the stop value and start defaults to 0.
  • stop (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – The end value of the sequence (exclusive). If not specified, the sequence ends at start and begins at 0.
  • step (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The spacing between values in the sequence. Must be non-zero.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A 1D tensor containing the evenly spaced values.

Return type:

Tensor

argmax()

argmax(axis=-1)

Finds the indices of the maximum values along an axis.

Returns a tensor containing the indices of the maximum values along the specified axis. This is useful for finding the position of the largest element, such as determining predicted classes in classification.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x4 tensor
x = tensor.Tensor.constant(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32
)

# Find argmax along last axis (within each row)
indices = x.argmax(axis=-1)
# Result: [1, 2] (index 1 in first row, index 2 in second row)

Parameters:

axis (int) – The axis along which to find the maximum indices. Defaults to -1 (the last axis).

Returns:

A tensor containing the indices of the maximum values.

Return type:

Tensor

broadcast_to()

broadcast_to(shape)

Broadcasts the tensor to the specified shape.

Returns a tensor broadcast to the target shape, following NumPy broadcasting semantics. Dimensions of size 1 in the input can be expanded to match larger dimensions in the target shape.

This is equivalent to PyTorch’s torch.broadcast_to() and torch.Tensor.expand().

from max.experimental import tensor
from max.dtype import DType

# Create a tensor with shape (3, 1)
x = tensor.Tensor.ones([3, 1], dtype=DType.float32)

# Broadcast to (3, 4) - expands the second dimension
y = x.broadcast_to([3, 4])
print(y.shape)  # (3, 4)

# Add a new leading dimension
w = x.broadcast_to([2, 3, 1])
print(w.shape)  # (2, 3, 1)

Parameters:

shape (Iterable[int | str | Dim | integer[Any]]) – The target shape. Each dimension must either match the input dimension or be broadcastable from size 1.
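
The compatibility rule can be checked on shapes alone. This is a minimal pure-Python sketch of NumPy-style broadcasting for static integer shapes; can_broadcast is a hypothetical helper, not part of the MAX API, and it does not handle symbolic (str or Dim) dimensions.

```python
def can_broadcast(src: list[int], dst: list[int]) -> bool:
    """Return True if shape `src` can be broadcast to shape `dst`."""
    if len(src) > len(dst):
        return False  # broadcasting never removes dimensions
    # Align shapes at the trailing dimension; missing leading dims act as size 1.
    padded = [1] * (len(dst) - len(src)) + list(src)
    # Each dimension must match the target or be expandable from size 1.
    return all(s == d or s == 1 for s, d in zip(padded, dst))


expand_last = can_broadcast([3, 1], [3, 4])      # size-1 dim expands to 4
add_leading = can_broadcast([3, 1], [2, 3, 1])   # new leading dimension
mismatch = can_broadcast([3, 2], [3, 4])         # 2 is neither 4 nor 1
```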

Returns:

A tensor broadcast to the specified shape.

Return type:

Tensor

cast()

cast(dtype)

Casts the tensor to a different data type.

Returns a new tensor with the same values but a different data type. This is useful for type conversions between different numeric types, such as converting float32 to int32 for indexing operations or float32 to bfloat16 for memory-efficient computations.

from max.experimental import tensor
from max.dtype import DType

# Create a float32 tensor
x = tensor.Tensor.constant([1.7, 2.3, 3.9], dtype=DType.float32)
print(x.dtype)  # DType.float32

# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(y.dtype)  # DType.int32
# Values: [1, 2, 3]

Parameters:

dtype (DType) – The target data type for the tensor.

Returns:

A new tensor with the specified data type.

Return type:

Tensor

clip()

clip(*, min=None, max=None)

Clips values outside a range to the boundaries of the range.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor.constant(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]]
)

# Clip values above 3.0 down to 3.0
clipped_above = x.clip(max=3.)
# Result: [[1.2, 3., 2.1, 0.8], [2.3, 1.9, 3., 3.]]

# Raise values below 3.0 up to 3.0
clipped_below = x.clip(min=3.)
# Result: [[3., 3.5, 3., 3.], [3., 3., 4.2, 3.1]]

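Parameters:

  • min – The lower bound. Values below it are set to min. Defaults to None, which applies no lower bound.
  • max – The upper bound. Values above it are set to max. Defaults to None, which applies no upper bound.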

Returns:

A tensor containing the values clipped to the specified range.

Return type:

Tensor

constant()

classmethod constant(value, *, dtype=None, device=None)

Creates a constant tensor from a scalar, array, or nested list.

Constructs a tensor with constant values that can be a scalar, a nested Python list, or a DLPack-compatible array. The shape is automatically inferred from the input data structure.

from max.experimental import tensor
from max.dtype import DType

# Create from scalar
x = tensor.Tensor.constant(42, dtype=DType.int32)

# Create from nested list
y = tensor.Tensor.constant([[1.0, 2.0], [3.0, 4.0]])

# Create from NumPy array
import numpy as np

z = tensor.Tensor.constant(np.array([1, 2, 3]))

Parameters:

  • value (DLPackArray | Sequence[float | number[Any] | Sequence[Number | NestedArray]] | float | number[Any]) – The constant value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor containing the constant value(s).

Return type:

Tensor

device

property device: Device

Gets the device where the tensor is stored.

Returns the device (CPU or accelerator) where the tensor’s data is located.

Returns:

The device where the tensor is stored.

Return type:

Device

driver_tensor

property driver_tensor: Tensor

A pointer to the underlying memory.

Raises if the tensor is unrealized.

dtype

property dtype: DType

Gets the data type of the tensor elements.

Returns the data type (dtype) of the elements stored in the tensor, such as float32, int32, or bfloat16.

Returns:

The data type of the tensor elements.

Return type:

DType

from_dlpack()

classmethod from_dlpack(array)

Creates a tensor from a DLPack array.

Constructs a tensor by importing data from any object that supports the DLPack protocol (such as NumPy arrays and PyTorch tensors). This enables zero-copy interoperability with other array libraries.

import numpy as np
from max.experimental import tensor

# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Convert to MAX tensor via DLPack
x = tensor.Tensor.from_dlpack(np_array)

Parameters:

array (DLPackArray) – Any object supporting the DLPack protocol, such as NumPy arrays, PyTorch tensors, or JAX arrays.

Returns:

A new tensor containing the data from the DLPack array.

Return type:

Tensor

from_graph_value()

classmethod from_graph_value(value)

Creates a tensor from a graph value.

Constructs a tensor from an existing graph value, which can be either a TensorValue or BufferValue. This is used for converting graph level values into tensor objects. The new tensor is registered as unrealized, backed by the current realization context.

Parameters:

value (Value) – The graph value to wrap. Can be either a TensorValue or BufferValue from the MAX graph API.

Returns:

A new tensor backed by the provided graph value.

Return type:

Tensor

full()

classmethod full(shape, value, *, dtype=None, device=None)

Creates a tensor filled with a specified value.

Returns a new tensor with the given shape where all elements are initialized to the specified value. This is useful for creating tensors with uniform values other than zero or one.

from max.experimental import tensor
from max.dtype import DType

# Create a 3x3 tensor filled with 7
x = tensor.Tensor.full((3, 3), value=7, dtype=DType.int32)

# Create a 2x4 tensor filled with pi
y = tensor.Tensor.full((2, 4), value=3.14159)

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
  • value (float | number[Any]) – The scalar value to fill the tensor with.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor with the specified shape filled with the given value.

Return type:

Tensor

full_like()

classmethod full_like(input, value)

Creates a tensor filled with a value, matching a given tensor’s properties.

Returns a new tensor filled with the specified value that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s full_like and PyTorch’s full_like.

from max.experimental import tensor
from max.dtype import DType

# Create a reference tensor
ref = tensor.Tensor.ones([2, 3], dtype=DType.float32)

# Create tensor filled with 5.0 matching the reference tensor
x = tensor.Tensor.full_like(ref, value=5.0)

Parameters:

  • input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
  • value (float | number[Any]) – The scalar value to fill the tensor with.

Returns:

A new tensor filled with the specified value, matching the properties of the input.

Return type:

Tensor

item()

item()

Gets the scalar value from a single-element tensor.

Extracts and returns the scalar value from a tensor containing exactly one element. The tensor is realized if needed and transferred to CPU before extracting the value.

Returns:

The scalar value from the tensor. The return type matches the tensor’s dtype (e.g., float for float32, int for int32).

Raises:

TypeError – If the tensor contains more than one element.

max()

max(axis=-1)

Computes the maximum values along an axis.

Returns a tensor containing the maximum values along the specified axis. This is useful for reduction operations and finding peak values in data.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x4 tensor
x = tensor.Tensor.constant(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]], dtype=DType.float32
)

# Find max along last axis (within each row)
row_max = x.max(axis=-1)
# Result: [3.5, 4.2]

# Find max along first axis (within each column)
col_max = x.max(axis=0)
# Result: [2.3, 3.5, 4.2, 3.1]

Parameters:

axis (int) – The axis along which to compute the maximum. Defaults to -1 (the last axis).

Returns:

A tensor containing the maximum values along the specified axis.

Return type:

Tensor

mean()

mean(axis=-1)

Computes the mean values along an axis.

Returns a tensor containing the arithmetic mean of values along the specified axis. This is useful for computing averages, normalizing data, or aggregating statistics.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x4 tensor
x = tensor.Tensor.constant(
    [[2.0, 4.0, 6.0, 8.0], [1.0, 3.0, 5.0, 7.0]], dtype=DType.float32
)

# Compute mean along last axis (within each row)
row_mean = x.mean(axis=-1)
# Result: [5.0, 4.0] (mean of each row)

# Compute mean along first axis (within each column)
col_mean = x.mean(axis=0)
# Result: [1.5, 3.5, 5.5, 7.5] (mean of each column)

Parameters:

axis (int) – The axis along which to compute the mean. Defaults to -1 (the last axis).

Returns:

A tensor containing the mean values along the specified axis.

Return type:

Tensor

num_elements()

num_elements()

Gets the total number of elements in the tensor.

Computes the product of all dimensions in the tensor’s shape to determine the total number of elements.
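
The same computation on a plain shape tuple, as a minimal sketch in standard-library Python. The free function num_elements here is a hypothetical stand-in that takes a static shape and is not the MAX method itself.

```python
import math


def num_elements(shape: tuple[int, ...]) -> int:
    """Product of all dimensions; an empty shape (a scalar) has one element."""
    return math.prod(shape)


matrix_count = num_elements((2, 3))
scalar_count = num_elements(())
```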

Returns:

The total number of elements in the tensor.

Return type:

int

ones()

classmethod ones(shape, *, dtype=None, device=None)

Creates a tensor filled with ones.

Returns a new tensor with the specified shape where all elements are initialized to one. The tensor is created with eager execution and automatic compilation.

from max.experimental import tensor
from max.driver import CPU
from max.dtype import DType

# Create a 2x3 tensor of ones
x = tensor.Tensor.ones((2, 3), dtype=DType.float32, device=CPU())
# Result: [[1.0, 1.0, 1.0],
#          [1.0, 1.0, 1.0]]

# Create a 1D tensor using default dtype and device
y = tensor.Tensor.ones((5,))

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor with the specified shape filled with ones.

Return type:

Tensor

ones_like()

classmethod ones_like(input)

Creates a tensor of ones matching a given tensor’s properties.

Returns a new tensor filled with ones that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s ones_like and PyTorch’s ones_like.

from max.experimental import tensor
from max.dtype import DType

# Create a reference tensor
ref = tensor.Tensor.zeros([3, 4], dtype=DType.float32)

# Create ones tensor matching the reference tensor
x = tensor.Tensor.ones_like(ref)
# Result: 3x4 tensor of ones with dtype float32

Parameters:

input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.

Returns:

A new tensor filled with ones matching the properties of the input.

Return type:

Tensor

permute()

permute(dims)

Permutes the dimensions of the tensor.

Returns a tensor with its dimensions reordered according to the specified permutation. This is useful for changing the layout of multi-dimensional data.
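
The effect of a permutation on the shape can be sketched in plain Python. permuted_shape is a hypothetical helper for illustration, not part of the MAX API, and it assumes non-negative dimension indices.

```python
def permuted_shape(shape: list[int], dims: list[int]) -> list[int]:
    """Shape after reordering: output position i takes input dimension dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of 0..rank-1")
    return [shape[d] for d in dims]


# Dimension 2 moves to position 0, dimension 0 to 1, dimension 1 to 2.
rotated = permuted_shape([4, 5, 6], [2, 0, 1])
# The identity permutation leaves the shape unchanged.
unchanged = permuted_shape([4, 5, 6], [0, 1, 2])
```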

Parameters:

dims (list[int]) – A list specifying the new order of dimensions. For example, [2, 0, 1] moves dimension 2 to position 0, dimension 0 to position 1, and dimension 1 to position 2.

Returns:

A tensor with permuted dimensions.

Return type:

Tensor

range_like()

classmethod range_like(type)

Creates a range tensor matching a given type’s properties.

Returns a new tensor containing sequential indices along the last dimension, broadcasted to match the shape of the specified tensor type. Each row (along the last dimension) contains values from 0 to the dimension size minus one. This is useful for creating position indices or coordinate tensors.

from max.experimental import tensor
from max.graph import TensorType
from max.driver import CPU
from max.dtype import DType

# Create a reference tensor type with shape (2, 4)
ref_type = TensorType(DType.int32, (2, 4), device=CPU())

# Create range tensor matching the reference type
x = tensor.Tensor.range_like(ref_type)
# Result: [[0, 1, 2, 3],
#          [0, 1, 2, 3]]

Parameters:

type (TensorType) – The tensor type to match. The returned tensor will have the same shape, dtype, and device as this type, with values representing indices along the last dimension.

Returns:

A new tensor with sequential indices broadcasted to match the input type’s shape.

Return type:

Tensor

rank

property rank: int

Gets the number of dimensions in the tensor.

Returns the rank (number of dimensions) of the tensor. For example, a scalar has rank 0, a vector has rank 1, and a matrix has rank 2.

Returns:

The number of dimensions in the tensor.

Return type:

int

real

property real: bool

realize

property realize: Tensor

Force the tensor to realize if it is not already.

reshape()

reshape(shape)

Reshapes the tensor to a new shape.

Returns a tensor with the same data but a different shape. The total number of elements must remain the same. This is useful for changing tensor dimensions for different operations, such as flattening a multi-dimensional tensor or converting a 1D tensor into a matrix.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x3 tensor
x = tensor.Tensor.constant([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(x.shape)  # (2, 3)

# Flatten to 1D
y = x.reshape((6,))
print(y.shape)  # (6,)
# Values: [1, 2, 3, 4, 5, 6]

Parameters:

shape (Iterable[int | str | Dim | integer[Any]]) – The desired output shape. Can be a tuple or list of integers. The total number of elements must equal the original tensor’s element count.

Returns:

A reshaped tensor with the specified shape.

Return type:

Tensor

shape

property shape: Shape

Gets the shape of the tensor.

Returns the dimensions of the tensor as a shape object.

Returns:

The shape of the tensor.

Return type:

Shape

split()

split(split_size_or_sections, axis=0)

Splits the tensor into multiple tensors along a given dimension.

This method supports two modes, matching PyTorch’s behavior:

  • If split_size_or_sections is an int, splits into chunks of that size (the last chunk may be smaller if not evenly divisible).
  • If split_size_or_sections is a list of ints, splits into chunks with exactly those sizes (must sum to the dimension size).

from max.experimental import tensor
from max.dtype import DType

# Create a 10x4 tensor
x = tensor.Tensor.ones([10, 4], dtype=DType.float32)

# Split into chunks of size 3 (last chunk is size 1)
chunks = x.split(3, axis=0)
# Result: 4 tensors with shapes [3,4], [3,4], [3,4], [1,4]

# Split into exact sizes
chunks = x.split([2, 3, 5], axis=0)
# Result: 3 tensors with shapes [2,4], [3,4], [5,4]
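
The two modes differ only in how the chunk sizes along the split axis are determined. The sketch below computes those sizes in plain Python under the rules listed above; split_sizes is a hypothetical helper, not part of the MAX API.

```python
def split_sizes(dim_size: int, split_size_or_sections) -> list[int]:
    """Chunk sizes along the split axis for either split mode."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("chunk size must be positive")
        # Full chunks first, then one smaller chunk for any remainder.
        full, remainder = divmod(dim_size, size)
        return [size] * full + ([remainder] if remainder else [])
    sections = list(split_size_or_sections)
    if sum(sections) != dim_size:
        raise ValueError("section sizes must sum to the dimension size")
    return sections


by_chunk = split_sizes(10, 3)             # last chunk holds the remainder
by_sections = split_sizes(10, [2, 3, 5])  # exact sizes are used as given
```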

Parameters:

  • split_size_or_sections (int | list[int]) – Either an int (chunk size) or a list of ints (exact sizes for each output tensor).
  • axis (int) – The dimension along which to split. Defaults to 0.

Returns:

A list of tensors resulting from the split.

Return type:

list[Tensor]

squeeze()

squeeze(axis)

Removes a size-1 dimension from the tensor.

Returns a tensor with the specified size-1 dimension removed. This is useful for removing singleton dimensions from tensors after operations that may have added them.

from max.experimental import tensor
from max.dtype import DType

# Create a tensor with a size-1 dimension
x = tensor.Tensor.ones([4, 1, 6], dtype=DType.float32)
print(x.shape)  # (4, 1, 6)

# Squeeze out the size-1 dimension
y = x.squeeze(axis=1)
print(y.shape)  # (4, 6)

Parameters:

axis (int) – The dimension to remove from the tensor’s shape. If negative, this indexes from the end of the tensor. The dimension at this axis must have size 1.

Returns:

A tensor with the specified dimension removed.

Return type:

Tensor

Raises:

ValueError – If the dimension at the specified axis is not size 1.

state

state: RealizationState | None

State for realizing an unrealized tensor.

storage

storage: Tensor | None

Underlying memory for a realized tensor. If the tensor is used in any mutating operations that have not been realized, this holds the state before any updates.

sum()

sum(axis=-1)

Computes the sum of values along an axis.

Returns a tensor containing the sum of values along the specified axis. This is a fundamental reduction operation used for aggregating data, computing totals, and implementing other operations like mean.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x3 tensor
x = tensor.Tensor.constant(
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=DType.float32
)

# Sum along last axis (within each row)
row_sum = x.sum(axis=-1)
# Result: [6.0, 15.0] (sum of each row)

# Sum along first axis (within each column)
col_sum = x.sum(axis=0)
# Result: [5.0, 7.0, 9.0] (sum of each column)

Parameters:

axis (int) – The axis along which to compute the sum. Defaults to -1 (the last axis).

Returns:

A tensor containing the sum along the specified axis.

Return type:

Tensor

to()

to(device)

Transfers the tensor to a different device.

Creates a new tensor with the same data on the specified device. This allows moving tensors between CPU and accelerators or between different accelerator devices.

from max.experimental import tensor
from max.driver import CPU, Accelerator

# Create a tensor on CPU
x = tensor.Tensor.ones((2, 3), device=CPU())
print(x.device)  # CPU

# Transfer to accelerator
y = x.to(Accelerator())
print(y.device)  # Accelerator(0)

Parameters:

device (Device) – The target device for the tensor.

Returns:

A new tensor with the same data on the specified device.

Return type:

Tensor

transpose()

transpose(dim1, dim2)

Transposes two dimensions of the tensor.

Returns a tensor with the specified dimensions swapped. This is a special case of permutation that swaps exactly two dimensions.
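
On shapes, the swap looks like this. transposed_shape is a hypothetical pure-Python helper for illustration, not part of the MAX API; negative indices count from the end, as in Python list indexing.

```python
def transposed_shape(shape: list[int], dim1: int, dim2: int) -> list[int]:
    """Shape after swapping two dimensions."""
    out = list(shape)
    out[dim1], out[dim2] = out[dim2], out[dim1]
    return out


# Swapping the last two dimensions, which is what the T property is documented to do.
last_two = transposed_shape([2, 3, 4], -1, -2)
# Swapping the first and last dimensions.
outer = transposed_shape([2, 3, 4], 0, 2)
```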

Parameters:

  • dim1 (int) – The first dimension to swap.
  • dim2 (int) – The second dimension to swap.

Returns:

A tensor with the specified dimensions transposed.

Return type:

Tensor

type

property type: TensorType

Gets the tensor type information.

Returns the type information for the tensor, including shape, dtype, and device. If the underlying value is a buffer type, it’s converted to a tensor type.

Returns:

The type information for the tensor.

Return type:

TensorType

unsqueeze()

unsqueeze(axis)

Inserts a size-1 dimension into the tensor.

Returns a tensor with a new size-1 dimension inserted at the specified position. This is the inverse of squeeze() and is useful for adding dimensions needed for broadcasting or matrix operations.

from max.experimental import tensor
from max.dtype import DType

# Create a 1D tensor
x = tensor.Tensor.constant([1.0, 2.0, 3.0], dtype=DType.float32)
print(x.shape)  # (3,)

# Add dimension at the end
y = x.unsqueeze(axis=-1)
print(y.shape)  # (3, 1)

# Add dimension at the beginning
z = x.unsqueeze(axis=0)
print(z.shape)  # (1, 3)

Parameters:

axis (int) – The index at which to insert the new dimension. If negative, indexes relative to 1 plus the rank of the tensor. For example, axis=-1 adds a dimension at the end.
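
The negative-axis rule can be made concrete on shapes. This is a minimal pure-Python sketch of the normalization described above; unsqueezed_shape is a hypothetical helper, not part of the MAX API.

```python
def unsqueezed_shape(shape: list[int], axis: int) -> list[int]:
    """Shape after inserting a size-1 dimension at `axis`."""
    rank = len(shape)
    if axis < 0:
        # Negative axes are relative to rank + 1, so -1 means "after the last dim".
        axis += rank + 1
    if not 0 <= axis <= rank:
        raise IndexError("axis out of range")
    return list(shape[:axis]) + [1] + list(shape[axis:])


at_end = unsqueezed_shape([3], -1)
at_start = unsqueezed_shape([3], 0)
in_middle = unsqueezed_shape([4, 6], -2)
```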

Returns:

A tensor with an additional size-1 dimension.

Return type:

Tensor
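
The negative-axis rule can be sketched in plain Python on shape tuples. This is an illustration of the arithmetic described for the axis parameter, not the library's implementation; `unsqueezed_shape` is a hypothetical name, and the bounds check is an assumption about reasonable behavior.

```python
def unsqueezed_shape(shape, axis):
    """Return the shape after inserting a size-1 dimension at `axis`.

    Hypothetical helper for illustration only; not part of max.experimental.
    """
    rank = len(shape)
    if axis < 0:
        # Negative axes count back from rank + 1, the rank of the result.
        axis += rank + 1
    if not 0 <= axis <= rank:
        raise IndexError(f"axis out of range for rank {rank}")
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


print(unsqueezed_shape((3,), -1))    # (3, 1)
print(unsqueezed_shape((3,), 0))     # (1, 3)
print(unsqueezed_shape((2, 3), -2))  # (2, 1, 3)
```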

zeros()

classmethod zeros(shape, *, dtype=None, device=None)

Creates a tensor filled with zeros.

Returns a new tensor with the specified shape where all elements are initialized to zero. The tensor is created with eager execution and automatic compilation.

from max.experimental import tensor
from max.driver import CPU
from max.dtype import DType

# Create a 2x3 tensor of zeros
x = tensor.Tensor.zeros((2, 3), dtype=DType.float32, device=CPU())
# Result: [[0.0, 0.0, 0.0],
#          [0.0, 0.0, 0.0]]

# Create a 1D tensor using default dtype and device
y = tensor.Tensor.zeros((5,))

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any]]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor with the specified shape filled with zeros.

Return type:

Tensor

zeros_like()

classmethod zeros_like(input)

Creates a tensor of zeros matching a given tensor’s properties.

Returns a new tensor filled with zeros that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s zeros_like and PyTorch’s zeros_like.

from max.experimental import tensor
from max.dtype import DType

# Create a reference tensor
ref = tensor.Tensor.ones([3, 4], dtype=DType.float32)

# Create zeros tensor matching the reference tensor
x = tensor.Tensor.zeros_like(ref)
# Result: 3x4 tensor of zeros with dtype float32

Parameters:

input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.

Returns:

A new tensor filled with zeros matching the properties of the input.

Return type:

Tensor

current_realization_context()

max.experimental.tensor.current_realization_context()

Return a value for the context variable for the current context.

If there is no value for the variable in the current context, the method will:

  • return the value of the default argument of the method, if provided; or
  • return the default value for the context variable, if it was created with one; or
  • raise a LookupError.
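
The lookup order above matches the `get()` method of `contextvars.ContextVar` in the Python standard library. The sketch below uses two throwaway variables (`with_default` and `without_default`, both hypothetical names) to show each case. It assumes, but does not confirm, that the realization context is stored in such a variable.

```python
import contextvars

# One variable created with a default, one without.
with_default = contextvars.ContextVar("with_default", default="fallback")
without_default = contextvars.ContextVar("without_default")

# Case 1: the argument passed to get() takes priority when no value is set.
print(without_default.get("arg"))  # arg
print(with_default.get("arg"))     # arg

# Case 2: otherwise the variable's own default is used, if it has one.
print(with_default.get())          # fallback

# Case 3: with neither, get() raises LookupError.
try:
    without_default.get()
    raised = False
except LookupError:
    raised = True
print(raised)                      # True
```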

default_device()

max.experimental.tensor.default_device(device)

Context manager for setting the default device for tensor creation.

Sets the default device used for tensor creation within the context. All tensors created inside the context block without an explicit device parameter will use this device.

from max.experimental import tensor
from max.driver import CPU

# Use CPU as default device in this context
with tensor.default_device(CPU()):
    x = tensor.Tensor.ones((2, 3))  # Created on CPU
    y = tensor.Tensor.zeros((2, 3))  # Also on CPU

Parameters:

device (Device | DeviceRef) – The device to use as the default for tensor creation within the context.

Returns:

A context manager that sets the default device.

default_dtype()

max.experimental.tensor.default_dtype(dtype)

Context manager for setting the default dtype for tensor creation.

Sets the default data type used for tensor creation within the context. All tensors created inside the context block without an explicit dtype parameter will use this data type.

from max.experimental import tensor
from max.dtype import DType

# Use int32 as default dtype in this context
with tensor.default_dtype(DType.int32):
    x = tensor.Tensor.ones((2, 3))  # Created with int32
    y = tensor.Tensor.zeros((2, 3))  # Also int32

Parameters:

dtype (DType) – The data type to use as the default for tensor creation within the context.

Returns:

A context manager that sets the default dtype.

defaults()

max.experimental.tensor.defaults(dtype=None, device=None)

Gets the default dtype and device for tensor creation.

Returns a tuple containing the dtype and device to use for tensor creation, applying defaults when values are not specified. If no dtype is provided, defaults to DType.float32 for CPU and DType.bfloat16 for accelerators. If no device is provided, defaults to an accelerator if available, otherwise CPU.

Parameters:

  • dtype (DType | None) – The data type to use. If not specified, a default dtype based on the device is returned.
  • device (Device | None) – The device to use. If not specified, defaults to an available accelerator or CPU.

Returns:

A tuple containing the resolved dtype and device.

Return type:

tuple[DType, Device]
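
The resolution rule described above can be sketched without the library. In this illustration, devices and dtypes are plain strings, `resolve_defaults` is a hypothetical stand-in for `defaults()`, and the `accelerator_available` flag replaces the real hardware check; none of these names come from the MAX API.

```python
def resolve_defaults(dtype=None, device=None, accelerator_available=False):
    """Mimic the documented fallback order using plain strings.

    Hypothetical helper for illustration only; not part of max.experimental.
    """
    if device is None:
        # Prefer an accelerator when one exists, otherwise fall back to CPU.
        device = "accelerator" if accelerator_available else "cpu"
    if dtype is None:
        # The default dtype depends on the resolved device.
        dtype = "float32" if device == "cpu" else "bfloat16"
    return dtype, device


print(resolve_defaults())                            # ('float32', 'cpu')
print(resolve_defaults(accelerator_available=True))  # ('bfloat16', 'accelerator')
print(resolve_defaults(dtype="int32", device="cpu")) # ('int32', 'cpu')
```

Resolving the device first matters because the default dtype is chosen from the device, so an explicit device changes the dtype default even when no dtype is given.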

defaults_like()

max.experimental.tensor.defaults_like(like)

Context manager setting the default dtype and device for tensor creation.

Sets the default data type and device used for tensor creation within the context to those of like. All tensors created inside the context block without an explicit dtype or device will use the corresponding value from like.

from max.experimental import tensor
from max.driver import CPU
from max.dtype import DType

x = tensor.Tensor.zeros([1], dtype=DType.int32, device=CPU())
# Use x's dtype (int32) and device (CPU) as defaults in this context
with tensor.defaults_like(x):
    y = tensor.Tensor.zeros((2, 3))  # int32, cpu
    z = tensor.Tensor.zeros((2, 3), dtype=DType.float32)  # float32, cpu

Parameters:

like (Tensor | TensorType) – A tensor or tensor type whose dtype and device are used as the defaults for the context.

Returns:

A context manager that sets the default dtype and device.

Return type:

Generator[None]

realization_context()

max.experimental.tensor.realization_context(ctx)

Context manager that sets the current realization context.

New tensors created within this block will use the given realization context to execute.

See RealizationContext.

Parameters:

ctx (RealizationContext) – The realization context to set as the current context.

Returns:

A context manager. When entered, it sets ctx as the current realization context. When exited, it resets the current realization context to its previous value.

Return type:

AbstractContextManager[RealizationContext]
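
The set-on-enter, restore-on-exit behavior described above is commonly built from `contextvars` and `contextlib` in the standard library. The sketch below shows that pattern with plain strings in place of RealizationContext objects. The names `_current` and `use_context` are hypothetical, and this is an assumption about the general technique, not the library's actual implementation.

```python
import contextlib
import contextvars

# Hypothetical variable standing in for the current realization context.
_current = contextvars.ContextVar("current_context")


@contextlib.contextmanager
def use_context(ctx):
    """Set ctx as the current context, restoring the previous one on exit."""
    token = _current.set(ctx)
    try:
        yield ctx
    finally:
        # reset() restores whatever value (or absence of one) preceded set().
        _current.reset(token)


with use_context("outer"):
    with use_context("inner"):
        inner_seen = _current.get()
    outer_seen = _current.get()  # "outer" is restored after the inner block

after_exit = _current.get(None)  # no value is set once all blocks exit

print(inner_seen, outer_seen, after_exit)  # inner outer None
```

Restoring through the token rather than assigning a fixed value is what lets nested blocks unwind correctly, including when the body raises.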
