struct TensorDict
A collection of keyed Tensor values used with checkpoint files. This is the type accepted by save() and returned by load().
For example:
```mojo
from max.graph.checkpoint import load, save, TensorDict
from tensor import Tensor, TensorShape


def write_to_disk():
    tensors = TensorDict()
    # A 1x2x2 int32 tensor holding the values 1, 2, 3, 4.
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.maxckpt")


def read_from_disk():
    tensors = load("/path/to/checkpoint.maxckpt")
    # The type parameter selects the dtype of the returned tensor.
    x = tensors.get[DType.int32]("x")
```
Implemented traits
AnyType, Sized
Methods
__init__
__init__(inout self: Self)
Creates an empty dictionary.
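The sketch below shows that a newly constructed TensorDict starts out empty and that len() counts its entries, which works because the struct implements Sized. It assumes the Mojo release these signatures come from and the same imports as the example at the top of this page; the function names are hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn new_dict_size() raises -> Int:
    # TensorDict() calls __init__, which takes no arguments.
    var tensors = TensorDict()
    # len() is available because TensorDict implements Sized.
    return len(tensors)


fn size_after_one_insert() raises -> Int:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    return len(tensors)
```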
__copyinit__
__copyinit__(inout self: Self, existing: Self)
Copies a dictionary.
Args:
- existing (Self): The existing dict.
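In Mojo, initializing a new variable from an existing one that is still used afterward invokes __copyinit__. This minimal sketch, assuming the imports from the example at the top of this page, copies a one-entry dictionary and compares sizes; the function name is hypothetical, and it deliberately does not mutate either dictionary after the copy.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn copy_sizes_match() raises -> Bool:
    var original = TensorDict()
    original.set("w", Tensor[DType.float32](TensorShape(3), 0.5, 1.5, 2.5))
    # original is used again below, so this line copies via __copyinit__.
    var duplicate = original
    # The copy holds the same number of entries as the original.
    return len(duplicate) == len(original)
```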
__moveinit__
__moveinit__(inout self: Self, owned existing: Self)
Moves data of an existing dictionary into a new one.
Args:
- existing (Self): The existing dict.
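The transfer operator ^ hands ownership of a value to a new variable, which invokes __moveinit__ instead of copying. A minimal sketch, assuming the imports from the example at the top of this page and a hypothetical function name:

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn moved_size() raises -> Int:
    var tensors = TensorDict()
    tensors.set("w", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    # ^ ends the lifetime of tensors; it must not be used after this line.
    var moved = tensors^
    return len(moved)
```

Moving avoids duplicating the stored entries, which matters when a dictionary holds large checkpoint tensors.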
__setitem__
__setitem__[T: DType](inout self: Self, key: String, value: Tensor[T])
Supports setting items with the bracket accessor.
For example:
```mojo
tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)
```
Args:
- key (String): The key to associate with the specified value.
- value (Tensor[T]): The data to store in the dictionary.
set
set[T: DType](inout self: Self, key: String, value: Tensor[T])
Adds or updates a tensor in the dictionary.
Args:
- key (String): The name of the tensor.
- value (Tensor[T]): The tensor to add.
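Because set() both adds and updates, calling it twice with the same key replaces the stored tensor rather than adding a second entry. This sketch, assuming the imports from the example at the top of this page and a hypothetical function name, checks the entry count and the element count of the replacement tensor.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn overwrite_replaces_entry() raises -> Bool:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    # Same key again: the existing entry is updated, not duplicated.
    tensors.set("x", Tensor[DType.float32](TensorShape(3), 1.0, 2.0, 3.0))
    var x = tensors.get[DType.float32]("x")
    # One entry, holding the second (3-element) tensor.
    return len(tensors) == 1 and x.num_elements() == 3
```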
get
get[type: DType](self: Self, key: String) -> Tensor[$0]
Gets a tensor from the dictionary.
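Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it, if you no longer need the tensor stored in the dictionary.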
This method may change in the future to return an immutable reference instead of a mutable tensor copy.
Args:
- key (String): The name of the tensor.
Returns:
A copy of the tensor.
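Since get() returns a copy, the entry stays in the dictionary after the call. A minimal sketch, assuming the imports from the example at the top of this page and a hypothetical function name; it requests the same dtype the tensor was stored with.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn get_leaves_entry() raises -> Bool:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    # The type parameter selects Tensor[DType.int32] as the return type.
    var x = tensors.get[DType.int32]("x")
    # The copy has all 4 elements, and the dictionary still has its entry.
    return x.num_elements() == 4 and len(tensors) == 1
```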
pop
pop[type: DType](inout self: Self, key: String) -> Tensor[$0]
Removes a tensor from the dictionary.
This function moves the Tensor pointer out of the dictionary and returns it to the caller.
Args:
- key (String): The name of the tensor.
Returns:
The tensor.
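Unlike get(), pop() removes the entry, so the dictionary shrinks by one. This sketch assumes the imports from the example at the top of this page; the function name is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn pop_removes_entry() raises -> Bool:
    var tensors = TensorDict()
    tensors.set("x", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    tensors.set("y", Tensor[DType.float32](TensorShape(3), 1.0, 2.0, 3.0))
    # pop() moves the tensor out instead of copying it.
    var x = tensors.pop[DType.float32]("x")
    # The caller now owns x, and only "y" remains in the dictionary.
    return x.num_elements() == 2 and len(tensors) == 1
```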
items
items(self: Reference[TensorDict, is_mutable, lifetime, 0]) -> _DictEntryIter[String, _CheckpointTensor, $0, $1, 1]
Gets an iterable view of all elements in the dictionary.
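A sketch of iterating over entries, assuming the imports from the example at the top of this page and that, as in the Mojo release these signatures come from, the iterator yields references that are dereferenced with []. Only the key is read, because the value type _CheckpointTensor is private; the function name is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn count_entries() raises -> Int:
    var tensors = TensorDict()
    tensors.set("a", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    tensors.set("b", Tensor[DType.float32](TensorShape(1), 3.0))
    var count = 0
    for entry in tensors.items():
        # Assumption: entry is a reference; [] yields the entry with .key.
        print(entry[].key)
        count += 1
    return count
```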
keys
keys(self: Reference[TensorDict, is_mutable, lifetime, 0]) -> _DictKeyIter[String, _CheckpointTensor, $0, $1, 1]
Gets an iterable view of all keys in the dictionary.
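Pairing keys() with get() reads every stored tensor by name. This sketch assumes the imports from the example at the top of this page, that all stored tensors are float32, and that the key iterator yields references dereferenced with [] (as in the Mojo release these signatures come from); the function name is hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from tensor import Tensor, TensorShape


fn total_elements() raises -> Int:
    var tensors = TensorDict()
    tensors.set("a", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    tensors.set("b", Tensor[DType.float32](TensorShape(3), 1.0, 2.0, 3.0))
    var total = 0
    for key in tensors.keys():
        # Assumption: key[] dereferences to the String key.
        total += tensors.get[DType.float32](key[]).num_elements()
    # 2 elements from "a" plus 3 from "b".
    return total
```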