
Mojo struct

TensorDict

A collection of keyed Tensor values used with checkpoint files.

This is the type accepted by save() and returned by load().

For example:

from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    # Store tensors of different element types under string keys.
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    tensors = load("/path/to/checkpoint.maxckpt")
    # The element type is passed as a parameter when reading a tensor back.
    x = tensors.get[DType.int32]("x")

Implemented traits

AnyType, Sized

Methods

__init__

__init__(inout self: Self)

Creates an empty dictionary.

__copyinit__

__copyinit__(inout self: Self, existing: Self)

Copies a dictionary.

Args:

  • existing (Self): The dictionary to copy.

__moveinit__

__moveinit__(inout self: Self, owned existing: Self)

Moves the data of an existing dictionary into a new one.

Args:

  • existing (Self): The dictionary to move from.

__setitem__

__setitem__[T: DType](inout self: Self, key: String, value: Tensor[T])

Sets a tensor in the dictionary using bracket syntax (tensors[key] = value).

For example:

tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)

Args:

  • key (String): The key to associate with the specified value.
  • value (Tensor[T]): The data to store in the dictionary.

set

set[T: DType](inout self: Self, key: String, value: Tensor[T])

Adds or updates a tensor in the dictionary.

Args:

  • key (String): The name of the tensor.
  • value (Tensor[T]): The tensor to add.
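Because TensorDict itself needs Mojo and the MAX checkpoint module, the sketch below models its "adds or updates" behavior in plain Python instead. It is a hypothetical illustration, not the MAX API: `TensorDictModel` is a made-up name, flat Python lists stand in for tensors, and making `__setitem__` delegate to `set()` is an assumption about how the bracket form relates to `set()`.

```python
class TensorDictModel:
    """Hypothetical Python model of TensorDict.set() and bracket assignment.

    Flat lists of numbers stand in for tensors; this is not the MAX API.
    """

    def __init__(self):
        self._items = {}

    def set(self, key, value):
        # A new key adds an entry; an existing key has its value replaced.
        self._items[key] = value

    def __setitem__(self, key, value):
        # Assumption: bracket assignment behaves exactly like set().
        self.set(key, value)

    def __len__(self):
        return len(self._items)


tensors = TensorDictModel()
tensors.set("x", [1, 2, 3, 4])
tensors["x"] = [5, 6, 7, 8]  # replaces the value stored under "x"
print(len(tensors))  # 1
```

Updating an existing key replaces its value rather than adding a second entry, so the count stays at one.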

get

get[type: DType](self: Self, key: String) -> Tensor[type]

Gets a tensor from the dictionary.

Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it (and removes the entry).

This method may change in the future to return an immutable reference instead of a mutable copy of the tensor.

Args:

  • key (String): The name of the tensor.

Returns:

A copy of the tensor.

pop

pop[type: DType](inout self: Self, key: String) -> Tensor[type]

Removes a tensor from the dictionary and returns it.

This moves the tensor's pointer out of the dictionary and returns it to the caller, so the tensor data is not copied.

Args:

  • key (String): The name of the tensor.

Returns:

The removed tensor.
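The difference between get() (copy) and pop() (move and remove) can be modeled in plain Python. This is a hypothetical sketch, not the MAX API: lists stand in for tensors, the `[type]` parameter is omitted, and because Python has no move semantics, the model approximates a move by returning the exact object that was stored while dropping the dictionary's reference to it.

```python
import copy


class TensorDictModel:
    """Hypothetical Python model of TensorDict.get() versus TensorDict.pop()."""

    def __init__(self):
        self._items = {}

    def set(self, key, value):
        self._items[key] = value

    def get(self, key):
        # Like get(): return an independent copy; the stored value is untouched.
        return copy.deepcopy(self._items[key])

    def pop(self, key):
        # Like pop(): remove the entry and hand back the stored object, uncopied.
        return self._items.pop(key)

    def __len__(self):
        return len(self._items)


original = [1, 2, 3, 4]
tensors = TensorDictModel()
tensors.set("x", original)

x = tensors.get("x")
x[0] = 100                    # modifies only the copy
print(tensors.get("x")[0])    # 1

moved = tensors.pop("x")
print(moved is original, len(tensors))  # True 0
```

In the model, every get() call pays for a full copy, while pop() hands back the stored object and leaves the dictionary without that key, which mirrors the trade-off the reference describes.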

__len__

__len__(self: Self) -> Int

Returns the number of tensors in the dictionary.

items

items(ref [self_is_lifetime] self: Self) -> _DictEntryIter[$0, String, _CheckpointTensor, $1._items, 1]

Gets an iterable view of all entries (key-value pairs) in the dictionary.

keys

keys(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]

Gets an iterable view of all keys in the dictionary.

__iter__

__iter__(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]

Iterates over the keys in the dictionary.
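The three iteration methods can be modeled in plain Python as well. In this hypothetical sketch (not the MAX API), lists stand in for tensors and `items()` yields `(key, value)` tuples, whereas the Mojo iterator yields dictionary entries. The sketch also makes no claim about the order in which TensorDict visits its keys, so the usage sorts them.

```python
class TensorDictModel:
    """Hypothetical Python model of TensorDict's keys(), items(), and __iter__()."""

    def __init__(self):
        self._items = {}

    def set(self, key, value):
        self._items[key] = value

    def keys(self):
        # An iterable view of the tensor names.
        return iter(self._items.keys())

    def items(self):
        # (name, tensor) pairs; the Mojo version yields dictionary entries instead.
        return iter(self._items.items())

    def __iter__(self):
        # Iterating the dictionary itself walks its keys.
        return self.keys()


tensors = TensorDictModel()
tensors.set("x", [1, 2, 3, 4])
tensors.set("y", [-1.23] * 50)

print(sorted(tensors))  # ['x', 'y']
sizes = {name: len(values) for name, values in tensors.items()}
print(sizes["y"])  # 50
```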

__str__

__str__(self: Self) -> String

Returns a string representation of the dictionary.
