struct TensorDict

A collection of keyed Tensor values used with checkpoint files.

This is the type accepted by save() and returned by load().

For example:

from max.graph.checkpoint import load, save, TensorDict
from tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    tensors = load("/path/to/checkpoint.maxckpt")
    x = tensors.get[DType.int32]("x")

Implemented traits

AnyType, Sized

Methods

__init__

__init__(inout self: Self)

Creates an empty dictionary.

__copyinit__

__copyinit__(inout self: Self, existing: Self)

Copies a dictionary.

Args:

  • existing (Self): The existing dict.
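A minimal sketch of when __copyinit__ runs, assuming the Mojo release this page documents: assigning an existing TensorDict to a new variable copies it. The helper name copy_has_same_entries is hypothetical, and the sketch checks only the entry count (through the Sized trait), because this page does not say whether the underlying tensor data is shared or duplicated.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def copy_has_same_entries() -> Bool:
    # Hypothetical helper: build a dictionary holding one tensor.
    var original = TensorDict()
    original.set("x", Tensor[DType.int32](TensorShape(2, 2), 1, 2, 3, 4))
    # Assigning to a new variable invokes __copyinit__.
    var duplicate = original
    # Both dictionaries report the same number of entries via len().
    return len(duplicate) == len(original)
```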

__moveinit__

__moveinit__(inout self: Self, owned existing: Self)

Moves the data of an existing dictionary into a new one.

Args:

  • existing (Self): The existing dict.
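A sketch of triggering __moveinit__ with Mojo's ^ transfer operator, assuming the Mojo release this page documents. The helper name moved_entry_count is hypothetical. After the transfer, the source variable must not be used again.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def moved_entry_count() -> Int:
    # Hypothetical helper: fill a dictionary with two tensors.
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.int32](TensorShape(2), 5, 6))
    # The ^ operator ends the lifetime of `tensors` and invokes __moveinit__.
    var moved = tensors^
    # Both entries now live in `moved`.
    return len(moved)
```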

__setitem__

__setitem__[T: DType](inout self: Self, key: String, value: Tensor[T])

Supports setting items with the bracket accessor.

For example:

tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)

Args:

  • key (String): The key to associate with the specified value.
  • value (Tensor[T]): The data to store in the dictionary.

set

set[T: DType](inout self: Self, key: String, value: Tensor[T])

Adds or updates a tensor in the dictionary.

Args:

  • key (String): The name of the tensor.
  • value (Tensor[T]): The tensor to add.
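Because set() both adds and updates, calling it twice with the same key leaves a single entry. The sketch below, with a hypothetical helper name entry_count_after_update, also uses the bracket form, which calls __setitem__, to add a second key. It assumes the Mojo release this page documents.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def entry_count_after_update() -> Int:
    var tensors = TensorDict()
    tensors.set("w", Tensor[DType.int32](TensorShape(2), 1, 2))
    # Same key again: the stored tensor is replaced, not duplicated.
    tensors.set("w", Tensor[DType.int32](TensorShape(3), 7, 8, 9))
    # Bracket assignment calls __setitem__ and adds a new key.
    tensors["b"] = Tensor[DType.int32](TensorShape(1), 0)
    # Two keys remain: "w" and "b".
    return len(tensors)
```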

get

get[type: DType](self: Self, key: String) -> Tensor[type]

Gets a tensor from the dictionary.

Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it.

This method may change in the future to return an immutable reference instead of a mutable tensor copy.

Args:

  • key (String): The name of the tensor.

Returns:

A copy of the tensor.
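A short sketch showing that get() leaves the entry in place. It assumes the Mojo release this page documents and that the requested DType parameter should match the type the tensor was stored with. The helper name get_leaves_entry is hypothetical, and num_elements() is the standard Mojo Tensor method.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def get_leaves_entry() -> Bool:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(2, 2), 1, 2, 3, 4))
    # Request the same DType the tensor was stored with (assumed requirement).
    var x = tensors.get[DType.int32]("x")
    # get() returns a copy, so "x" is still in the dictionary.
    return len(tensors) == 1 and x.num_elements() == 4
```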

pop

pop[type: DType](inout self: Self, key: String) -> Tensor[type]

Removes a tensor from the dictionary.

This method moves the tensor out of the dictionary and returns it to the caller without copying its data. The key is no longer present afterward.

Args:

  • key (String): The name of the tensor.

Returns:

The tensor.
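A sketch contrasting pop() with get(): after pop(), the key is gone and the entry count drops. It assumes the Mojo release this page documents. The helper name count_after_pop is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def count_after_pop() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(2), 0.5, 1.5))
    # Move "x" out of the dictionary; no copy of its data is made.
    var x = tensors.pop[DType.int32]("x")
    if x.num_elements() != 4:
        raise Error("unexpected tensor size")
    # Only "y" remains.
    return len(tensors)
```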

items

items(self: Reference[TensorDict, is_mutable, lifetime, 0]) -> _DictEntryIter[String, _CheckpointTensor, $0, $1, 1]

Gets an iterable view of all elements in the dictionary.
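A sketch of iterating over items(), assuming the Mojo release this page documents, in which each iteration yields a Reference to a dictionary entry that is dereferenced with entry[]. The sketch reads only entry[].key, because the stored value type (_CheckpointTensor) is private; tensors themselves are retrieved with get() or pop(). The helper name entry_names is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def entry_names() -> String:
    var tensors = TensorDict()
    tensors["x"] = Tensor[DType.int32](TensorShape(2), 1, 2)
    var names = String("")
    for entry in tensors.items():
        # entry is a Reference; [] dereferences it (assumed syntax for this release).
        names += entry[].key
    return names
```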

keys

keys(self: Reference[TensorDict, is_mutable, lifetime, 0]) -> _DictKeyIter[String, _CheckpointTensor, $0, $1, 1]

Gets an iterable view of all keys in the dictionary.
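A sketch of listing tensor names with keys(), assuming, as with items(), that each iteration yields a Reference to a String that is dereferenced with key[]. The helper name key_count is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


def key_count() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(2), 1, 2))
    tensors.set("y", Tensor[DType.int32](TensorShape(1), 3))
    var count = 0
    for key in tensors.keys():
        # key[] yields the tensor name as a String (assumed syntax for this release).
        print(key[])
        count += 1
    return count
```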