
Mojo struct

TensorDict

A collection of keyed Tensor values used with checkpoint files.

This is the type accepted by save() and returned by load().

For example:

from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    tensors = load("/path/to/checkpoint.maxckpt")
    x = tensors.get[DType.int32]("x")

Implemented traits

AnyType, Sized

Methods

__init__

__init__(inout self: Self)

__copyinit__

__copyinit__(inout self: Self, existing: Self)

Copies a dictionary.

Args:

  • existing (Self): The existing dict.

__moveinit__

__moveinit__(inout self: Self, owned existing: Self)

Moves data of an existing dictionary into a new one.

Args:

  • existing (Self): The existing dict.

__setitem__

__setitem__[T: DType](inout self: Self, key: String, value: Tensor[T])

Supports setting items with the bracket accessor.

For example:

tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)

Args:

  • key (String): The key to associate with the specified value.
  • value (Tensor[T]): The data to store in the dictionary.

set

set[T: DType](inout self: Self, key: String, value: Tensor[T])

Adds or updates a tensor in the dictionary.

Args:

  • key (String): The name of the tensor.
  • value (Tensor[T]): The tensor to add.
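
The "adds or updates" behavior can be sketched as follows. This is a minimal illustration that uses only the calls documented on this page, and it assumes that calling set() a second time with the same key replaces the stored tensor rather than adding a second entry. The main() wrapper and the tensor values are illustrative.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape

def main():
    tensors = TensorDict()

    # The first call adds the key "x".
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))

    # A second call with the same key updates the stored tensor
    # (assumption: the old value is replaced, not kept alongside).
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 5, 6, 7, 8))

    # TensorDict implements Sized, so len() reports the number of entries.
    print(len(tensors))
```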

get

get[type: DType](self: Self, key: String) -> Tensor[type]

Gets a tensor from the dictionary.

Currently, this returns a copy of the tensor. For better performance, use TensorDict.pop(), which moves the tensor out of the dictionary instead of copying it.

This method may change in the future to return an immutable reference instead of a mutable tensor copy.

Args:

  • key (String): The name of the tensor.

Returns:

A copy of the tensor.

pop

pop[type: DType](inout self: Self, key: String) -> Tensor[type]

Removes a tensor from the dictionary.

This function moves the Tensor pointer out of the dictionary and returns it to the caller.

Args:

  • key (String): The name of the tensor.

Returns:

The tensor.
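
The difference between get() and pop() can be sketched as below. It is a hedged example built only from the signatures on this page: get() is described as returning a copy, so the entry should stay in the dictionary, while pop() removes it. The main() wrapper and the printed lengths are there only to make the effect visible.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape

def main():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))

    # get() returns a copy of the tensor; the entry stays in the dictionary.
    x_copy = tensors.get[DType.int32]("x")
    print(len(tensors))

    # pop() moves the tensor out of the dictionary and removes the key,
    # so the dictionary should hold one entry fewer afterward.
    x = tensors.pop[DType.int32]("x")
    print(len(tensors))
```

Prefer pop() when the dictionary is no longer needed after reading, for example when unpacking a freshly loaded checkpoint, since it avoids the copy that get() makes.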

__len__

__len__(self: Self) -> Int

items

items(ref [self_is_lifetime] self: Self) -> _DictEntryIter[$0, String, _CheckpointTensor, $1._items, 1]

Gets an iterable view of all elements in the dictionary.

keys

keys(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]

Gets an iterable view of all keys in the dictionary.
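
Iterating over the keys might look like the sketch below. It assumes the iterator yields references that must be dereferenced with `[]`, as the standard library Dict iterators did in the Mojo release this signature appears to come from. In other releases the dereference may be unnecessary or invalid.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape

def main():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))

    # keys() returns an iterable view over the String keys.
    # Assumption: each element is a reference, so `key[]` reads the String.
    for key in tensors.keys():
        print(key[])
```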

__iter__

__iter__(ref [self_is_lifetime] self: Self) -> _DictKeyIter[$0, String, _CheckpointTensor, $1._items, 1]

__str__

__str__(self: Self) -> String
