
Mojo struct

TensorDict

struct TensorDict

A collection of keyed Tensor values used with checkpoint files.

This is the type accepted by save() and returned by load().

For example:

from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    var tensors = load("/path/to/checkpoint.maxckpt")
    var x = tensors.get[DType.int32]("x")

Implemented traits

AnyType, Sized, UnknownDestructibility

Methods

__init__

__init__(out self)

Creates an empty dictionary.

__copyinit__

__copyinit__(out self, existing: Self)

Copies a dictionary.

Args:

  • existing (Self): The existing dict.
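A minimal sketch of copying a TensorDict. It assumes the copy produced by __copyinit__ is an independent dictionary, so adding a key to the copy leaves the original's length unchanged. The function name copy_is_independent is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def copy_is_independent() -> Bool:
    var original = TensorDict()
    original.set("x", Tensor[DType.int32](TensorShape(2), 1, 2))

    # `original` is still used below, so this assignment invokes __copyinit__.
    var duplicate = original
    # Assumption: the copy has its own entries, so this does not touch `original`.
    duplicate.set("y", Tensor[DType.float32](TensorShape(2), 0.5))

    return len(original) == 1 and len(duplicate) == 2


def main():
    print(copy_is_independent())
```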

__moveinit__

__moveinit__(out self, owned existing: Self)

Moves the data of an existing dictionary into a new one.

Args:

  • existing (Self): The existing dict.
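A short sketch of moving a TensorDict with Mojo's transfer sigil (^), which invokes __moveinit__ instead of __copyinit__ and ends the lifetime of the source variable. The function name build_and_move is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def build_and_move() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))

    # `tensors^` transfers ownership, so __moveinit__ runs and `tensors`
    # can no longer be used after this line.
    var moved = tensors^
    return len(moved)


def main():
    print(build_and_move())
```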

__setitem__

__setitem__[T: DType](mut self, key: String, value: Tensor[T])

Supports setting items with the bracket accessor.

For example:

tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)

Args:

  • key (String): The key to associate with the specified value.
  • value (Tensor[T]): The data to store in the dictionary.

set

set[T: DType](mut self, key: String, value: Tensor[T])

Adds or updates a tensor in the dictionary.

Args:

  • key (String): The name of the tensor.
  • value (Tensor[T]): The tensor to add.
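A minimal sketch of the "adds or updates" behavior: calling set() again with an existing key replaces the stored tensor rather than adding a second entry. The function name set_overwrites is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def set_overwrites() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(2), 1, 2))

    # Same key again: the previous tensor is replaced, not duplicated.
    tensors.set("x", Tensor[DType.int32](TensorShape(3), 7, 8, 9))
    return len(tensors)


def main():
    print(set_overwrites())
```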

get

get[type: DType](self, key: String) -> Tensor[type]

Gets a tensor from the dictionary.

Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it.

This method may change in the future to return an immutable reference instead of a mutable tensor copy.

Args:

  • key (String): The name of the tensor.

Returns:

A copy of the tensor.
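A sketch showing that get() leaves the entry in place, so the same key can be read more than once. It assumes the DType parameter passed to get() must match the type the tensor was stored with. The function name get_keeps_entry is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def get_keeps_entry() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))

    # Each call returns a separate copy; the dictionary is not modified.
    var first = tensors.get[DType.int32]("x")
    var second = tensors.get[DType.int32]("x")
    _ = first
    _ = second
    return len(tensors)


def main():
    print(get_keeps_entry())
```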

pop

pop[type: DType](mut self, key: String) -> Tensor[type]

Removes a tensor from the dictionary.

This method moves the tensor's data pointer out of the dictionary, rather than copying the data, and returns the tensor to the caller.

Args:

  • key (String): The name of the tensor.

Returns:

The tensor.
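A minimal sketch contrasting pop() with get(): after pop(), the key is no longer in the dictionary. As with get(), it assumes the DType parameter matches the stored tensor's type. The function name pop_removes_entry is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def pop_removes_entry() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))

    # Moves "y" out of the dictionary; only "x" remains afterwards.
    var y = tensors.pop[DType.float32]("y")
    _ = y
    return len(tensors)


def main():
    print(pop_removes_entry())
```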

__len__

__len__(self) -> Int

Returns the number of tensors in the dictionary.

items

items(ref self) -> _DictEntryIter[String, _CheckpointTensor, self_is_origin._items]

Gets an iterable view of all key-value entries in the dictionary.

keys

keys(ref self) -> _DictKeyIter[String, _CheckpointTensor, self_is_origin._items]

Gets an iterable view of all keys in the dictionary.
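A sketch of iterating over the keys. How each yielded key is dereferenced has varied between Mojo releases, so, as a conservative assumption, this example only counts the keys instead of reading them. The function name count_keys is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def count_keys() -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))

    # Visit every key once; the key values themselves are discarded.
    var count = 0
    for _ in tensors.keys():
        count += 1
    return count


def main():
    print(count_keys())
```

Because __iter__ returns the same key iterator type, looping over the dictionary directly (for _ in tensors) should visit the same keys.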

__iter__

__iter__(ref self) -> _DictKeyIter[String, _CheckpointTensor, self_is_origin._items]

Iterates over the keys of the dictionary, the same as keys().

__str__

__str__(self) -> String

Returns a string representation of the dictionary.
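A short sketch of getting a printable description. Stringable does not appear among the implemented traits listed above, so, as a cautious assumption, this example calls __str__() directly rather than passing the dictionary to str(). The exact output format is not documented here. The function name describe is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def describe() -> String:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))

    # Call the method explicitly instead of relying on a Stringable conformance.
    return tensors.__str__()


def main():
    print(describe())
```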