Python class

Tensor

class max.experimental.tensor.Tensor(data=None, *, dtype=None, device=None, storage=None, state=None)

source

Bases: DLPackArray, HasTensorValue

A multi-dimensional array with eager execution and automatic compilation.

The Tensor class provides a high-level interface for numerical computations with automatic compilation and optimization via the MAX runtime. Operations on tensors execute eagerly while benefiting from lazy evaluation and graph-based optimizations behind the scenes.

Key Features:

  • Eager execution: Operations execute immediately with automatic compilation.
  • Lazy evaluation: Computation may be deferred until results are needed.
  • High performance: Uses the Mojo compiler and optimized kernels.
  • Familiar API: Supports common array operations and indexing.
  • Device flexibility: Works seamlessly across CPU and accelerators.

Creating Tensors:

Create tensors using the constructor, factory methods like ones(), zeros(), arange(), or from other array libraries via from_dlpack().

from max.experimental import tensor

x = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]])
y = tensor.Tensor.zeros((2, 3))

# Perform operations
result = x + y  # Eager execution with automatic compilation

# Access values
print(result.shape)  # (2, 3)
print(result.dtype)  # DType.float32

Implementation Notes:

Tensors use lazy evaluation internally - they don’t always hold concrete data in memory. A tensor may be “unrealized” (not yet computed) until its value is actually needed (e.g., when converting to other formats or calling item()). This allows the runtime to optimize sequences of operations efficiently.

Operations on tensors build a computation graph behind the scenes, which is compiled and executed when needed. Invalid operations (for example, shape or dtype mismatches) fail immediately with clear error messages rather than being deferred to realization time.

Interoperability:

Tensors support the DLPack protocol for zero-copy data exchange with NumPy, PyTorch, JAX, and other array libraries. Use from_dlpack() to import arrays and standard DLPack conversion for export.

Creates a tensor from data or from internal storage.

When called with data, constructs a tensor from a scalar, nested list, or DLPack-compatible array (matching PyTorch’s torch.tensor() semantics). When called without data, requires exactly one of storage or state for internal construction.

For DLPack-compatible arrays (NumPy, PyTorch, etc.) the array’s own dtype is preserved by default; no silent precision conversion happens. For Python scalars and nested lists, dtype defaults to DType.float32 on CPU and DType.bfloat16 on accelerators.

from max.experimental.tensor import Tensor
from max.dtype import DType

# Create from scalar
x = Tensor(42, dtype=DType.int32)

# Create from nested list
y = Tensor([[1.0, 2.0], [3.0, 4.0]])

# Create from NumPy array; dtype is inherited from the array
import numpy as np
z = Tensor(np.array([1, 2, 3], dtype=np.int16))  # stays int16

Parameters:

  • data (DLPackArray | NestedArray | Number | None) – The value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array (NumPy, PyTorch, etc.). If not provided, exactly one of storage or state must be supplied.
  • dtype (DType | None) – The data type for the tensor elements. For DLPack arrays this defaults to the array’s own dtype; passing a conflicting value raises ValueError. For Python scalars/lists this defaults to DType.float32 on CPU and DType.bfloat16 on accelerators.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU. Only valid when data is provided.
  • storage (driver.Buffer | None) – Internal backing buffer for a realized tensor. Mutually exclusive with data.
  • state (RealizationState | None) – Internal realization state for an unrealized tensor. Mutually exclusive with data.

Return type:

Tensor

T

property T: Tensor

source

Returns a tensor with the last two dimensions transposed.

This is equivalent to calling transpose(-1, -2), which swaps the last two dimensions of the tensor. For a 2D matrix, this produces the standard matrix transpose.

from max.experimental.tensor import Tensor
from max.dtype import DType

# Create a 2x3 matrix
x = Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]

# Use .T property (equivalent to transpose(-1, -2))
y = x.T
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)

Returns:

A tensor with the last two dimensions transposed.

arange()

classmethod arange(start=0, stop=None, step=1, out_dim=None, *, dtype=None, device=None)

source

Creates a tensor with evenly spaced values within a given interval.

Returns a new 1D tensor containing a sequence of values starting from start (inclusive) and ending before stop (exclusive), with values spaced by step. This is similar to Python’s built-in range() function and NumPy’s arange().

from max.experimental import tensor
from max.dtype import DType

# Create a range from 0 to 10 (exclusive)
x = tensor.Tensor.arange(10)
# Result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Create a range from 5 to 15 with step 2
y = tensor.Tensor.arange(5, 15, 2)
# Result: [5, 7, 9, 11, 13]

# Use a specific dtype
z = tensor.Tensor.arange(0, 5, dtype=DType.float32)
# Result: [0.0, 1.0, 2.0, 3.0, 4.0]

# Create a range with float step (like numpy/pytorch)
w = tensor.Tensor.arange(0.0, 1.0, 0.2)
# Result: [0.0, 0.2, 0.4, 0.6, 0.8]

# Create a descending range with negative step
v = tensor.Tensor.arange(5, 0, -1, dtype=DType.float32)
# Result: [5.0, 4.0, 3.0, 2.0, 1.0]

Parameters:

  • start (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The starting value of the sequence. If stop is not provided, this becomes the stop value and start defaults to 0.
  • stop (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – The end value of the sequence (exclusive). If not specified, the sequence ends at start and begins at 0.
  • step (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The spacing between values in the sequence. Must be non-zero.
  • out_dim (int | str | Dim | integer[Any] | TypedAttr | None) – The expected output dimension. Required when start, stop, or step are tensors rather than scalar literals. If not specified, the output dimension is computed from the scalar values of the inputs.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | DeviceMapping | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A 1D tensor containing the evenly spaced values.

Return type:

Tensor

argmax()

argmax(axis=-1)

source

Finds the indices of the maximum values along an axis.

Returns a tensor containing the indices of the maximum values along the specified axis. This is useful for finding the position of the largest element, such as determining predicted classes in classification.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)

# Find argmax along last axis (within each row)
indices = x.argmax(axis=-1)
# Result: [1, 2] (index 1 in first row, index 2 in second row)

# Find argmax over all elements
index = x.argmax(axis=None)
# Result: 6 (flattened index of maximum value 4.2)

Parameters:

axis (int | None) – The axis along which to find the maximum indices. Defaults to -1 (the last axis). If None, finds the index of the maximum value across all elements.

Returns:

A tensor containing the indices of the maximum values.

Return type:

Tensor

broadcast_to()

broadcast_to(shape)

source

Broadcasts the tensor to the specified shape.

Returns a tensor broadcast to the target shape, following NumPy broadcasting semantics. Dimensions of size 1 in the input can be expanded to match larger dimensions in the target shape.

This is equivalent to PyTorch’s torch.broadcast_to() and torch.Tensor.expand().

from max.experimental import tensor

# Create a tensor with shape (3, 1)
x = tensor.Tensor.ones([3, 1])

# Broadcast to (3, 4) - expands the second dimension
y = x.broadcast_to([3, 4])
print(y.shape)  # (3, 4)

# Add a new leading dimension
w = x.broadcast_to([2, 3, 1])
print(w.shape)  # (2, 3, 1)

Parameters:

shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The target shape. Each dimension must either match the input dimension or be broadcastable from size 1.

Returns:

A tensor broadcast to the specified shape.

Return type:

Tensor

buffers

property buffers: tuple[Buffer, ...]

source

The underlying per-shard driver buffers.

Returns one buffer for non-distributed tensors, N buffers for a distributed tensor with N shards.

Raises:

TypeError – If the tensor is unrealized (lazy/symbolic).

cast()

cast(dtype)

source

Casts the tensor to a different data type.

Returns a new tensor with the same values but a different data type. This is useful for type conversions between different numeric types, such as converting float32 to int32 for indexing operations or float32 to bfloat16 for memory-efficient computations.

from max.experimental import tensor
from max.dtype import DType

# Create a float32 tensor
x = tensor.Tensor([1.7, 2.3, 3.9], dtype=DType.float32)
print(x.dtype)  # DType.float32

# Cast to int32 (truncates decimal values)
y = x.cast(DType.int32)
print(y.dtype)  # DType.int32
# Values: [1, 2, 3]

Parameters:

dtype (DType) – The target data type for the tensor.

Returns:

A new tensor with the specified data type, or self if the tensor already has the target dtype.

Return type:

Tensor

clip()

clip(*, min=None, max=None)

source

Clips values outside a range to the boundaries of the range.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)

# Clip values above 3.0
clipped_above = x.clip(max=3.0)
# Result: [[1.2, 3.0, 2.1, 0.8], [2.3, 1.9, 3.0, 3.0]]

# Clip values below 3.0
clipped_below = x.clip(min=3.0)
# Result: [[3.0, 3.5, 3.0, 3.0], [3.0, 3.0, 4.2, 3.1]]

Parameters:

  • min – The lower bound. Values below min are set to min. If None, no lower clipping is applied.
  • max – The upper bound. Values above max are set to max. If None, no upper clipping is applied.

Returns:

A tensor containing the values clipped to the specified range.

Return type:

Tensor

constant()

classmethod constant(value, *, dtype=None, device=None)

source

Creates a tensor from a scalar, array, or nested list.

Parameters:

  • value (DLPackArray | Sequence[float | number[Any] | Sequence[Number | NestedArray]] | float | number[Any]) – The constant value for the tensor. Can be a scalar number, a nested Python list, or any DLPack-compatible array.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | None) – The device where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor containing the constant value(s).

Return type:

Tensor

device

property device: Device

source

Gets the device where the tensor is stored.

Returns the device (CPU or accelerator) where the tensor’s data is located. Raises for distributed tensors that span multiple devices.

Returns:

The device where the tensor is stored.

Return type:

Device

driver_tensor

property driver_tensor: Buffer

source

A pointer to the underlying memory.

Raises if the tensor is unrealized or sharded.

dtype

property dtype: DType

source

Gets the data type of the tensor elements.

Returns:

The data type of the tensor elements.

Return type:

DType

from_dlpack()

classmethod from_dlpack(array)

source

Creates a tensor from a DLPack array.

Constructs a tensor by importing data from any object that supports the DLPack protocol (such as NumPy arrays and PyTorch tensors). This enables zero-copy interoperability with other array libraries.

import numpy as np
from max.experimental import tensor

# Create a NumPy array
np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Convert to MAX tensor via DLPack
x = tensor.Tensor.from_dlpack(np_array)

Parameters:

array (DLPackArray) – Any object supporting the DLPack protocol, such as NumPy arrays, PyTorch tensors, or JAX arrays.

Returns:

A new tensor containing the data from the DLPack array.

Return type:

Tensor

from_graph_value()

classmethod from_graph_value(value)

source

Creates a tensor from a graph value.

Constructs a tensor from an existing graph value, which can be either a TensorValue or BufferValue. This is used for converting graph-level values into tensor objects. The new tensor is registered as unrealized, backed by the current realization context.

Parameters:

value (Value[Any]) – The graph value to wrap. Can be either a TensorValue or BufferValue from the MAX graph API.

Returns:

A new tensor backed by the provided graph value.

Return type:

Tensor

from_shard_values()

classmethod from_shard_values(shard_values, mapping=None)

source

Creates a tensor from one or more per-shard graph values.

For a single shard value with no mapping, behaves like from_graph_value(). For multiple shard values, a DeviceMapping is required and the result is a distributed tensor.

Parameters:

  • shard_values (Sequence[BufferValue | TensorValue]) – Per-device graph values (TensorValue or BufferValue). One per device in the mesh.
  • mapping (DeviceMapping | None) – Device mapping describing how shards map to mesh devices and their placements. Required when len(shard_values) > 1.

Returns:

A tensor backed by the provided shard values.

Raises:

  • ValueError – If multiple shard values are given without a mapping.
  • TypeError – If any shard value is not a graph value.

Return type:

Tensor

full()

classmethod full(shape, value, *, dtype=None, device=None)

source

Creates a tensor filled with a specified value.

Returns a new tensor with the given shape where all elements are initialized to the specified value. This is useful for creating tensors with uniform values other than zero or one.

from max.experimental import tensor
from max.dtype import DType

# Create a 3x3 tensor filled with 7
x = tensor.Tensor.full((3, 3), value=7, dtype=DType.int32)

# Create a 2x4 tensor filled with pi
y = tensor.Tensor.full((2, 4), value=3.14159)

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
  • value (float | number[Any]) – The scalar value to fill the tensor with.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU. Pass a DeviceMapping to create a distributed tensor.

Returns:

A new tensor with the specified shape filled with the given value.

Return type:

Tensor

full_like()

classmethod full_like(input, value)

source

Creates a tensor filled with a value, matching a given tensor’s properties.

Returns a new tensor filled with the specified value that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s full_like and PyTorch’s full_like.

from max.experimental import tensor

# Create a reference tensor
ref = tensor.Tensor.ones([2, 3])

# Create tensor filled with 5.0 matching the reference tensor
x = tensor.Tensor.full_like(ref, value=5.0)

Parameters:

  • input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.
  • value (float | number[Any]) – The scalar value to fill the tensor with.

Returns:

A new tensor filled with the specified value, matching the properties of the input.

Return type:

Tensor

graph_values

property graph_values: tuple[BufferValue | TensorValue, ...]

source

Returns per-shard graph values directly from the realization state.

For unrealized tensors (both distributed and single-device), returns the underlying graph values (TensorValue | BufferValue) without wrapping them in intermediate Tensor objects.

For realized tensors, creates graph values via __tensorvalue__() on each shard.

This is the primary way to access graph-level shard values for custom dispatch rules and SPMD loops.

is_distributed

property is_distributed: bool

source

Returns True if this tensor spans multiple devices.

item()

item()

source

Gets the scalar value from a single-element tensor.

Extracts and returns the scalar value from a tensor containing exactly one element. The tensor is realized if needed and transferred to CPU before extracting the value.

For replicated distributed tensors, the value is read from the first shard (all shards hold identical data).

Returns:

The scalar value from the tensor. The return type matches the tensor’s dtype (e.g., float for float32, int for int32).

Raises:

  • TypeError – If the tensor contains more than one element.
  • ValueError – If the tensor is distributed and not fully replicated.

Return type:

Any

local_shards

property local_shards: tuple[Tensor, ...]

source

Returns per-device shard views as independent unsharded Tensors.

Each returned Tensor is a lightweight, standalone, unsharded Tensor backed by a single shard’s storage or graph value. They can be passed directly to F.* ops or used as Module parameters.

For realized sharded tensors, each shard wraps one driver.Buffer. For unrealized sharded tensors, each shard wraps one GraphValue from the shared RealizationState. For unsharded tensors, returns a 1-tuple containing self.

mapping

property mapping: DeviceMapping

source

Returns the device mapping describing where this tensor lives.

materialize()

materialize()

source

Gathers a distributed tensor into a single local tensor.

Allreduces Partial axes, allgathers Sharded axes, and transfers the result to CPU. Returns self unchanged for non-distributed tensors.

Return type:

Tensor

max()

max(axis=-1)

source

Computes the maximum values along an axis.

Returns a tensor containing the maximum values along the specified axis. This is useful for reduction operations and finding peak values in data.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)

# Find max along last axis (within each row)
row_max = x.max(axis=-1)
# Result: [3.5, 4.2]

# Find max along first axis (within each column)
col_max = x.max(axis=0)
# Result: [2.3, 3.5, 4.2, 3.1]

# Find max over all elements
overall_max = x.max(axis=None)
# Result: 4.2 (maximum value across all elements)

Parameters:

axis (int | None) – The axis along which to compute the maximum. Defaults to -1 (the last axis). If None, computes the maximum across all elements.

Returns:

A tensor containing the maximum values along the specified axis.

Return type:

Tensor

mean()

mean(axis=-1)

source

Computes the mean values along an axis.

Returns a tensor containing the arithmetic mean of values along the specified axis. This is useful for computing averages, normalizing data, or aggregating statistics.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor(
    [[2.0, 4.0, 6.0, 8.0], [1.0, 3.0, 5.0, 7.0]],
)

# Compute mean along last axis (within each row)
row_mean = x.mean(axis=-1)
# Result: [5.0, 4.0] (mean of each row)

# Compute mean along first axis (within each column)
col_mean = x.mean(axis=0)
# Result: [1.5, 3.5, 5.5, 7.5] (mean of each column)

# Compute mean over all elements
overall_mean = x.mean(axis=None)
# Result: 4.5 (mean of all elements)

Parameters:

axis (int | None) – The axis along which to compute the mean. Defaults to -1 (the last axis). If None, computes the mean across all elements.

Returns:

A tensor containing the mean values along the specified axis.

Return type:

Tensor

mesh

property mesh: DeviceMesh

source

Returns the device mesh.

min()

min(axis=-1)

source

Computes the minimum values along an axis.

Returns a tensor containing the minimum values along the specified axis. This is useful for reduction operations and finding the smallest values in data.

from max.experimental import tensor

# Create a 2x4 tensor
x = tensor.Tensor(
    [[1.2, 3.5, 2.1, 0.8], [2.3, 1.9, 4.2, 3.1]],
)

# Find min along last axis (within each row)
row_min = x.min(axis=-1)
# Result: [0.8, 1.9]

# Find min along first axis (within each column)
col_min = x.min(axis=0)
# Result: [1.2, 1.9, 2.1, 0.8]

# Find min over all elements
overall_min = x.min(axis=None)
# Result: 0.8 (minimum value across all elements)

Parameters:

axis (int | None) – The axis along which to compute the minimum. Defaults to -1 (the last axis). If None, computes the minimum across all elements.

Returns:

A tensor containing the minimum values along the specified axis.

Return type:

Tensor

num_elements()

num_elements()

source

Gets the total number of elements in the tensor.

Computes the product of all dimensions in the tensor’s shape to determine the total number of elements.

Returns:

The total number of elements in the tensor.

Return type:

int

num_shards

property num_shards: int

source

Returns the number of shards (1 for an unsharded tensor).

ones()

classmethod ones(shape, *, dtype=None, device=None)

source

Creates a tensor filled with ones.

Returns a new tensor with the specified shape where all elements are initialized to one.

from max.experimental import tensor

# Create a 2x3 tensor of ones
x = tensor.Tensor.ones((2, 3))

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor with the specified shape filled with ones.

Return type:

Tensor

ones_like()

classmethod ones_like(input)

source

Creates a tensor of ones matching a given tensor’s properties.

Returns a new tensor filled with ones that matches the shape, data type, and device of the input tensor. This behaves like NumPy’s ones_like and PyTorch’s ones_like.

from max.experimental import tensor

# Create a reference tensor
ref = tensor.Tensor.zeros([3, 4])

# Create ones tensor matching the reference tensor
x = tensor.Tensor.ones_like(ref)
# Result: 3x4 tensor of ones with dtype float32

Parameters:

input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.

Returns:

A new tensor filled with ones matching the properties of the input.

Return type:

Tensor

permute()

permute(dims)

source

Permutes the dimensions of the tensor.

Returns a tensor with its dimensions reordered according to the specified permutation. This is useful for changing the layout of multi-dimensional data, such as converting between different tensor layout conventions (e.g., from [batch, channels, height, width] to [batch, height, width, channels]).

from max.experimental.tensor import Tensor
from max.dtype import DType

# Create a 3D tensor (batch_size=2, channels=3, length=4)
x = Tensor(
    [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
     [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],
    dtype=DType.int32,
)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3), Dim(4)]

# Rearrange to (batch, length, channels)
y = x.permute([0, 2, 1])
print(f"Permuted shape: {y.shape}")
# Output: Permuted shape: [Dim(2), Dim(4), Dim(3)]

Parameters:

dims (list[int]) – A list specifying the new order of dimensions. For example, [2, 0, 1] moves dimension 2 to position 0, dimension 0 to position 1, and dimension 1 to position 2.

Returns:

A tensor with permuted dimensions.

Return type:

Tensor

placements

property placements: tuple[Placement, ...]

source

Returns per-axis placement descriptors.

For NamedMapping, this converts to placements on the fly. Raises ConversionError if the spec contains compiler-only annotations.

prod()

prod(axis=-1)

source

Computes the product of values along an axis.

Parameters:

axis (int | None) – The axis along which to compute the product. Defaults to -1 (the last axis). If None, computes the product across all elements.

Returns:

A tensor containing the product along the specified axis.

Return type:

Tensor

range_like()

classmethod range_like(type)

source

Creates a range tensor matching a given type’s properties.

Returns a new tensor containing sequential indices along the last dimension, broadcasted to match the shape of the specified tensor type. Each row (along the last dimension) contains values from 0 to the dimension size minus one. This is useful for creating position indices or coordinate tensors.

from max.experimental import tensor
from max.graph import TensorType
from max.dtype import DType

# Create a reference tensor type with shape (2, 4)
ref_type = TensorType(DType.int32, (2, 4))

# Create range tensor matching the reference type
x = tensor.Tensor.range_like(ref_type)
# Result: [[0, 1, 2, 3],
#          [0, 1, 2, 3]]

Parameters:

type (TensorType) – The tensor type to match. The returned tensor will have the same shape, dtype, and device as this type, with values representing indices along the last dimension.

Returns:

A new tensor with sequential indices broadcasted to match the input type’s shape.

Return type:

Tensor

rank

property rank: int

source

Gets the number of dimensions in the tensor.

Returns the rank (number of dimensions) of the tensor. For example, a scalar has rank 0, a vector has rank 1, and a matrix has rank 2.

Returns:

The number of dimensions in the tensor.

Return type:

int

real

property real: bool

source

Returns True if this tensor is realized (has concrete storage).

For sharded tensors this is all-or-nothing: either every shard is realized (_state is None) or none are.

realize

property realize: Tensor

source

Forces the tensor to realize if it is not already realized, and returns it.

reshape()

reshape(shape)

source

Reshapes the tensor to a new shape.

Returns a tensor with the same data but a different shape. The total number of elements must remain the same. This is useful for changing tensor dimensions for different operations, such as flattening a multi-dimensional tensor or converting a 1D tensor into a matrix.

from max.experimental import tensor
from max.dtype import DType

# Create a 2x3 tensor
x = tensor.Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(x.shape)  # (2, 3)

# Flatten to 1D
y = x.reshape((6,))
print(y.shape)  # (6,)
# Values: [1, 2, 3, 4, 5, 6]

Parameters:

shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The desired output shape. Can be a tuple or list of integers. The total number of elements must equal the original tensor’s element count.

Returns:

A reshaped tensor with the specified shape.

Return type:

Tensor

shape

property shape: Shape

source

Gets the global shape of the tensor.

For sharded tensors this returns the logical global shape (not the per-shard shape). If no explicit global shape was set, it is derived from the first shard’s shape, placements, and mesh.

Returns:

The shape of the tensor.

Return type:

Shape

split()

split(split_size_or_sections, axis=0)

source

Splits the tensor into multiple tensors along a given dimension.

This method supports two modes, matching PyTorch’s behavior:

  • If split_size_or_sections is an int, splits into chunks of that size (the last chunk may be smaller if not evenly divisible).
  • If split_size_or_sections is a list of ints, splits into chunks with exactly those sizes (must sum to the dimension size).

from max.experimental import tensor

# Create a 10x4 tensor
x = tensor.Tensor.ones([10, 4])

# Split into chunks of size 3 (last chunk is size 1)
chunks = x.split(3, axis=0)
# Result: 4 tensors with shapes [3,4], [3,4], [3,4], [1,4]

# Split into exact sizes
chunks = x.split([2, 3, 5], axis=0)
# Result: 3 tensors with shapes [2,4], [3,4], [5,4]

Parameters:

  • split_size_or_sections (int | list[int]) – Either an int (chunk size) or a list of ints (exact sizes for each output tensor).
  • axis (int) – The dimension along which to split. Defaults to 0.

Returns:

A list of tensors resulting from the split.

Return type:

list[Tensor]

squeeze()

squeeze(axis)

source

Removes a size-1 dimension from the tensor.

Returns a tensor with the specified size-1 dimension removed. This is useful for removing singleton dimensions from tensors after operations that may have added them.

from max.experimental import tensor

# Create a tensor with a size-1 dimension
x = tensor.Tensor.ones([4, 1, 6])
print(x.shape)  # (4, 1, 6)

# Squeeze out the size-1 dimension
y = x.squeeze(axis=1)
print(y.shape)  # (4, 6)

Parameters:

axis (int) – The dimension to remove from the tensor’s shape. If negative, this indexes from the end of the tensor. The dimension at this axis must have size 1.

Returns:

A tensor with the specified dimension removed.

Return type:

Tensor

Raises:

ValueError – If the dimension at the specified axis is not size 1.

state

property state: RealizationState | None

source

Returns the realization state (unsharded tensors only).

storage

property storage: Buffer | None

source

Returns the single backing buffer (unsharded tensors only).

sum()

sum(axis=-1)

source

Computes the sum of values along an axis.

Returns a tensor containing the sum of values along the specified axis. This is a fundamental reduction operation used for aggregating data, computing totals, and implementing other operations like mean.

from max.experimental import tensor

# Create a 2x3 tensor
x = tensor.Tensor(
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
)

# Sum along last axis (within each row)
row_sum = x.sum(axis=-1)
# Result: [6.0, 15.0] (sum of each row)

# Sum along first axis (within each column)
col_sum = x.sum(axis=0)
# Result: [5.0, 7.0, 9.0] (sum of each column)

# Sum over all elements
total = x.sum(axis=None)
# Result: 21.0 (sum of all elements)

Parameters:

axis (int | None) – The axis along which to compute the sum. Defaults to -1 (the last axis). If None, computes the sum across all elements.

Returns:

A tensor containing the sum along the specified axis.

Return type:

Tensor

to()

to(target)

source

Transfers the tensor to a different device, mesh, or mapping.

This method supports three target types:

  1. Device: Transfers a single-device tensor to the target device. For realized tensors, performs a direct driver-level transfer via to(). For unrealized tensors, inserts a transfer_to() op into the computation graph.
  2. DeviceMapping: Reassigns the tensor’s device mesh and placements. For single-device mappings, equivalent to .to(device). For multi-device mappings on an unsharded tensor, distributes the tensor across the mesh using the shard collective.
  3. DeviceMesh: Replaces the device mesh while keeping existing placements. For unsharded tensors targeting a multi-device mesh, creates a fully replicated mapping. For distributed tensors, transfers shards to the new mesh devices.

from max.experimental import tensor
from max.driver import CPU, Accelerator

# Create a tensor on CPU
x = tensor.Tensor.ones((2, 3), device=CPU())
print(x.device)  # CPU

# Transfer to accelerator
y = x.to(Accelerator())
print(y.device)  # Accelerator(0)

# Same-device transfer is a no-op
z = y.to(y.device)
assert z is y

Parameters:

target (Device | DeviceMesh | DeviceMapping) –

The target for the tensor. Can be:

  • Device: Target device for transfer.
  • DeviceMesh: New mesh, keeping existing placements (or fully replicated for unsharded tensors).
  • DeviceMapping: New mesh and placements; triggers shard collective for multi-device.

Returns:

A tensor on the specified target. Returns self if no transfer is needed.

Return type:

Tensor

to_numpy()

to_numpy()

source

Converts this tensor to a NumPy array.

Materializes distributed tensors and transfers to CPU if needed.

Return type:

ndarray[Any, Any]

transpose()

transpose(dim1, dim2)

source

Returns a tensor that is a transposed version of the input.

The given dimensions dim1 and dim2 are swapped.

from max.experimental.tensor import Tensor
from max.dtype import DType

# Create a 2x3 matrix
x = Tensor([[1, 2, 3], [4, 5, 6]], dtype=DType.int32)
print(f"Original shape: {x.shape}")
# Output: Original shape: [Dim(2), Dim(3)]
print(x)

# Transpose dimensions 0 and 1 to get a 3x2 matrix
y = x.transpose(0, 1)
print(f"Transposed shape: {y.shape}")
# Output: Transposed shape: [Dim(3), Dim(2)]
print(y)

Parameters:

  • dim1 (int) – The first dimension to be transposed.
  • dim2 (int) – The second dimension to be transposed.

Returns:

A tensor with dimensions dim1 and dim2 swapped.

Return type:

Tensor
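Note that transpose(dim1, dim2) swaps exactly the two named dimensions, which matches the behavior of numpy.swapaxes. A NumPy sketch of the higher-rank case (not the MAX API itself):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# Swapping dims 0 and 2 is the analogue of x.transpose(0, 2)
y = np.swapaxes(x, 0, 2)
print(y.shape)  # (4, 3, 2)

# Only the named axes move; the middle axis stays put
assert y[1, 2, 0] == x[0, 2, 1]
```

Applying the same swap twice recovers the original tensor, which makes transpose its own inverse for a fixed pair of dimensions.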

type

property type: TensorType

source

Gets the tensor type information.

Returns:

The type information for the tensor.

Return type:

TensorType

Raises:

TypeError – If the tensor is distributed.

unsqueeze()

unsqueeze(axis)

source

Inserts a size-1 dimension into the tensor.

Returns a tensor with a new size-1 dimension inserted at the specified position. This is the inverse of squeeze() and is useful for adding dimensions needed for broadcasting or matrix operations.

from max.experimental import tensor

# Create a 1D tensor
x = tensor.Tensor([1.0, 2.0, 3.0])
print(x.shape)  # (3,)

# Add dimension at the end
y = x.unsqueeze(axis=-1)
print(y.shape)  # (3, 1)

# Add dimension at the beginning
z = x.unsqueeze(axis=0)
print(z.shape)  # (1, 3)

Parameters:

axis (int) – The index at which to insert the new dimension. If negative, indexes relative to 1 plus the rank of the tensor. For example, axis=-1 adds a dimension at the end.

Returns:

A tensor with an additional size-1 dimension.

Return type:

Tensor
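The axis rule matches numpy.expand_dims: for a rank-r input, a negative axis resolves to axis + r + 1, so axis=-1 inserts the new dimension at the end. A NumPy sketch of the equivalent semantics (not the MAX API itself):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])     # rank 1, shape (3,)

end = np.expand_dims(x, axis=-1)  # -1 resolves to -1 + 1 + 1 = 1
front = np.expand_dims(x, axis=0)
print(end.shape)    # (3, 1)
print(front.shape)  # (1, 3)

# Typical broadcasting use: column vector * row vector -> outer product
outer = end * front
print(outer.shape)  # (3, 3)
```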

zeros()

classmethod zeros(shape, *, dtype=None, device=None)

source

Creates a tensor filled with zeros.

Returns a new tensor with the specified shape where all elements are initialized to zero. The tensor is created with eager execution and automatic compilation.

from max.experimental import tensor

# Create a 2x3 tensor of zeros
x = tensor.Tensor.zeros((2, 3))
# Result: [[0.0, 0.0, 0.0],
#          [0.0, 0.0, 0.0]]

# Create a 1D tensor using default dtype and device
y = tensor.Tensor.zeros((5,))

Parameters:

  • shape (Iterable[int | str | Dim | integer[Any] | TypedAttr]) – The shape of the output tensor. Can be a tuple of integers, a list of integers, or any value that can be converted to a shape.
  • dtype (DType | None) – The data type for the tensor elements. If not specified, defaults to DType.float32 for CPU devices and DType.bfloat16 for accelerator devices.
  • device (Device | DeviceMapping | None) – The device or device mapping where the tensor will be allocated. If not specified, defaults to an accelerator if available, otherwise CPU.

Returns:

A new tensor with the specified shape filled with zeros.

Return type:

Tensor

zeros_like()

classmethod zeros_like(input)

source

Creates a tensor of zeros matching a given tensor’s properties.

Returns a new tensor filled with zeros that matches the shape, data type, and device of the input tensor. This behaves like zeros_like in NumPy and PyTorch.

from max.experimental import tensor

# Create a reference tensor
ref = tensor.Tensor.ones([3, 4])

# Create zeros tensor matching the reference tensor
x = tensor.Tensor.zeros_like(ref)
# Result: 3x4 tensor of zeros matching ref's dtype and device

Parameters:

input (Tensor | TensorType) – The tensor or tensor type to match. The returned tensor will have the same shape, dtype, and device as this input.

Returns:

A new tensor filled with zeros matching the properties of the input.

Return type:

Tensor