Mojo struct
TensorDict
struct TensorDict
A collection of keyed Tensor values used with checkpoint files.
This is the type accepted by save() and returned by load().
For example:
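```mojo
from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    # Store a 1x2x2 int32 tensor under the key "x".
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    # Write every entry in the dictionary to a checkpoint file.
    save(tensors, "/path/to/checkpoint.maxckpt")

def read_from_disk():
    # load() returns a TensorDict with the entries read from the file.
    tensors = load("/path/to/checkpoint.maxckpt")
    # Retrieve "x" as an int32 tensor (get() returns a copy).
    x = tensors.get[DType.int32]("x")
```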
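A minimal round-trip sketch that uses only the calls documented on this page: it saves two tensors (one added with set(), one with the bracket accessor), loads the file back, and reports how many entries came back. The function name round_trip and the path /tmp/tensordict_example.maxckpt are hypothetical, and the sketch assumes the process can write to that location.

```mojo
from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

# Hypothetical scratch location used only for this sketch.
alias CHECKPOINT_PATH = "/tmp/tensordict_example.maxckpt"


def round_trip() -> Int:
    tensors = TensorDict()
    # Add one entry with set() and one with the bracket accessor.
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors["y"] = Tensor[DType.float32](TensorShape(2), -1.0, 1.0)
    save(tensors, CHECKPOINT_PATH)

    # load() returns a new TensorDict read from the checkpoint file.
    loaded = load(CHECKPOINT_PATH)
    # TensorDict implements Sized, so len() reports the number of entries.
    return len(loaded)
```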
Implemented traits
AnyType, Sized, UnknownDestructibility
Methods
__init__
__init__(out self)
Creates an empty dictionary.
__copyinit__
__copyinit__(out self, existing: Self)
Copies a dictionary.
Args:
- existing (Self): The existing dict.
__moveinit__
__moveinit__(out self, owned existing: Self)
Moves data of an existing dictionary into a new one.
Args:
- existing (Self): The existing dict.
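A sketch of when the two constructors above run, assuming ordinary Mojo value semantics: assigning from a variable that is still used afterward calls __copyinit__, and the ^ transfer operator calls __moveinit__. The helper name copy_and_move and the keys "a" and "b" are hypothetical.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def copy_and_move() -> List[Int]:
    tensors = TensorDict()
    tensors.set("a", Tensor[DType.int32](TensorShape(1), 7))

    # `tensors` is used again below, so this assignment copies (__copyinit__).
    dict_copy = tensors
    # Adding a key to the copy leaves the original's entries unchanged.
    dict_copy.set("b", Tensor[DType.int32](TensorShape(1), 8))
    original_len = len(tensors)

    # The ^ operator transfers ownership (__moveinit__);
    # `tensors` must not be used after this line.
    moved = tensors^
    return List[Int](original_len, len(dict_copy), len(moved))
```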
__setitem__
__setitem__[T: DType](mut self, key: String, value: Tensor[T])
Supports setting items with the bracket accessor.
For example:
```mojo
tensors = TensorDict()
tensors["x"] = Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4)
```
Args:
- key (String): The key to associate with the specified value.
- value (Tensor[T]): The data to store in the dictionary.
set
set[T: DType](mut self, key: String, value: Tensor[T])
Adds or updates a tensor in the dictionary.
Args:
- key (String): The name of the tensor.
- value (Tensor[T]): The tensor to add.
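Because set() both adds and updates, calling it twice with the same key should leave a single entry. This sketch checks that with len(); the helper name overwrite_entry, the key "bias", and the tensor values are illustrative only.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def overwrite_entry() -> Int:
    tensors = TensorDict()
    # The first call adds a new "bias" entry.
    tensors.set("bias", Tensor[DType.float32](TensorShape(2), 0.5, 0.5))
    # The second call reuses the key, so it replaces the stored tensor
    # instead of adding another entry.
    tensors.set("bias", Tensor[DType.float32](TensorShape(2), 1.0, 2.0))
    return len(tensors)
```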
get
get[type: DType](self, key: String) -> Tensor[type]
Gets a tensor from the dictionary.
Currently, this returns a copy of the tensor. For better performance, use pop(), which moves the tensor out of the dictionary instead of copying it.
This method may change in the future to return an immutable reference instead of a mutable tensor copy.
Args:
- key (String): The name of the tensor.
Returns:
A copy of the tensor.
pop
pop[type: DType](mut self, key: String) -> Tensor[type]
Removes a tensor from the dictionary.
This function moves the Tensor pointer out of the dictionary and returns it to the caller.
Args:
- key (String): The name of the tensor.
Returns:
The tensor.
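A short sketch of the difference between get() and pop(), based only on the behavior described above: get() leaves the entry in the dictionary and returns a copy, while pop() moves the tensor out and removes the entry. The helper name get_then_pop and the key "w" are hypothetical; the entry counts come from len(), which TensorDict supports through Sized.

```mojo
from max.graph.checkpoint import TensorDict
from max.tensor import Tensor, TensorShape


def get_then_pop() -> List[Int]:
    tensors = TensorDict()
    tensors.set("w", Tensor[DType.int32](TensorShape(2), 1, 2))

    # get() copies the tensor; the "w" entry stays in the dictionary.
    w_copy = tensors.get[DType.int32]("w")
    after_get = len(tensors)

    # pop() moves the tensor out and removes the "w" entry,
    # avoiding the copy that get() makes.
    w = tensors.pop[DType.int32]("w")
    after_pop = len(tensors)

    # End the lifetimes of the retrieved tensors explicitly.
    _ = w_copy^
    _ = w^
    return List[Int](after_get, after_pop)
```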
__len__
__len__(self) -> Int
Returns the number of tensors in the dictionary.
items
items(ref self) -> _DictEntryIter[String, _CheckpointTensor, self_is_origin._items]
Gets an iterable view of all elements in the dictionary.
keys
keys(ref self) -> _DictKeyIter[String, _CheckpointTensor, self_is_origin._items]
Gets an iterable view of all keys in the dictionary.
__iter__
__iter__(ref self) -> _DictKeyIter[String, _CheckpointTensor, self_is_origin._items]
Iterates over the keys in the dictionary.
__str__
__str__(self) -> String
Returns a string representation of the dictionary.