Embedding
class max.nn.Embedding(vocab_size, *, dim=None, dims=None)
A vector embedding.
An embedding can be thought of as a lookup table for vectors by index. Given an input tensor of indices into the embedding, the lookup returns a tensor whose leading dimensions match the shape of the input, with each index replaced by the vector stored at that position in the embedding table. The vector's own dimensions are appended at the end.
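To make the lookup-table view concrete, here is a minimal pure-Python sketch. It does not use MAX: nested lists stand in for tensors, and the table values and the embed helper are hypothetical. Each leaf index is replaced by the row it selects, so the output keeps the nesting of the input and appends the vector's own dimension.

```python
# Pure-Python sketch of an embedding lookup; not the MAX implementation.
# Nested lists stand in for tensors, and `table` is a hypothetical
# embedding table with vocab_size=4 and 2-element vectors.

def embed(table, indices):
    """Replace every integer index in `indices` with the row table[index]."""
    if isinstance(indices, list):
        # Recurse so the output keeps the nesting (the batch shape) of the input.
        return [embed(table, item) for item in indices]
    # Leaf: no explicit [0, vocab_size) check here; Python lists would even
    # accept negative indices, which a real embedding treats as illegal.
    return table[indices]


table = [
    [0.0, 0.1],  # vector for index 0
    [1.0, 1.1],  # vector for index 1
    [2.0, 2.1],  # vector for index 2
    [3.0, 3.1],  # vector for index 3
]

# Indices of shape [2, 2] produce a result of shape [2, 2, 2].
result = embed(table, [[3, 0], [1, 1]])
print(result)
# [[[3.0, 3.1], [0.0, 0.1]], [[1.0, 1.1], [1.0, 1.1]]]
```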
The common case for embeddings is a 1-dimensional embedding:
from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding
embedding = Embedding(vocab_size=1000, dim=128)
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 128]
However, embeddings support multi-dimensional vectors just as easily:
from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding
embedding = Embedding(vocab_size=1000, dims=[16, 128])
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 16, 128]
Creates a randomly initialized embedding of the specified size.
Parameters:
- vocab_size (DimLike) – The number of elements in the lookup table. Indices outside the range [0, vocab_size) are illegal in the resulting embedding operation.
- dim (DimLike | None) – The embedding dimension, for a 1-dimensional embedding. Equivalent to dims=[dim].
- dims (ShapeLike | None) – The shape of each vector in the embedding. Use this to specify multi-dimensional embeddings.
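The relationship between dim and dims can be sketched as a small normalization step. The normalize_dims and table_shape helpers below are hypothetical, not part of max.nn, and raising ValueError when both or neither argument is given is an assumption this page does not document. The table shape (vocab_size, *dims) follows the shape stated for forward().

```python
# Hypothetical helpers, not part of max.nn. They model the documented rule
# that dim=d is equivalent to dims=[d]. Raising ValueError when both or
# neither argument is given is an assumption of this sketch.

def normalize_dims(dim=None, dims=None):
    """Return the shape of each embedding vector as a tuple."""
    if dim is not None and dims is not None:
        raise ValueError("pass either dim or dims, not both")
    if dim is not None:
        return (dim,)
    if dims is not None:
        return tuple(dims)
    raise ValueError("one of dim or dims is required")


def table_shape(vocab_size, dim=None, dims=None):
    """Shape of the lookup table: (vocab_size, *dims)."""
    return (vocab_size, *normalize_dims(dim=dim, dims=dims))


print(table_shape(1000, dim=128))         # (1000, 128)
print(table_shape(1000, dims=[16, 128]))  # (1000, 16, 128)
```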
dim
property dim: Dim
The dimension of the vectors in the embedding (for a 1-dimensional embedding).
Raises: If the embedding is 0-dimensional or has more than one dimension.
dims
The dimensions of the vectors in the embedding.
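The dim property is a convenience accessor over dims. The stand-in class below is hypothetical and models only that relationship. The page says dim raises for 0- or multi-dimensional embeddings but does not name the exception type, so ValueError here is an assumption.

```python
# Hypothetical stand-in, not max.nn.Embedding: it models only how the
# documented `dim` property relates to `dims`. ValueError is an assumed
# exception type; the docs do not specify one.

class EmbeddingShape:
    def __init__(self, vocab_size, dims):
        self.vocab_size = vocab_size
        self.dims = tuple(dims)  # shape of each vector in the table

    @property
    def dim(self):
        # Only a 1-dimensional embedding has a single embedding dimension.
        if len(self.dims) != 1:
            raise ValueError(f"dim requires exactly one dimension, got {self.dims}")
        return self.dims[0]


print(EmbeddingShape(1000, [128]).dim)  # 128
try:
    EmbeddingShape(1000, [16, 128]).dim
except ValueError as err:
    print(err)  # dim requires exactly one dimension, got (16, 128)
```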
forward()
forward(indices)
Applies the vector embedding to the input tensor of indices.
Parameters:
indices (Tensor) – An integer-valued tensor. Values must be in the range [0, vocab_size) for the embedding.
Returns:
A dense tensor made by looking up each index in the vector embedding. For an input of shape (*batch, indices) and an embedding of shape (vocab_size, *dims), the result has shape (*batch, indices, *dims).
Return type:
Tensor
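The shape rule in the Returns description can be checked with a few lines of plain Python. forward_shape is a hypothetical helper that only computes shapes; it performs no lookup and does not call MAX.

```python
# Hypothetical helper, not a MAX API: computes the output shape of
# forward() from the documented rule
#   indices (*batch, indices) + table (vocab_size, *dims) -> (*batch, indices, *dims)

def forward_shape(indices_shape, table_shape):
    """Return the output shape of an embedding lookup."""
    _vocab_size, *dims = table_shape  # the leading axis indexes the table
    return (*indices_shape, *dims)


# 1-D embedding, 10 tokens (matches the first example on this page).
print(forward_shape([10], [1000, 128]))         # (10, 128)
# Multi-dimensional embedding with an extra batch axis of 4.
print(forward_shape([4, 10], [1000, 16, 128]))  # (4, 10, 16, 128)
```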
vocab_size
property vocab_size: Dim
The vocab size of the embedding.
Indices outside the range [0, vocab_size) are illegal.
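The valid index range can be expressed as a simple check. validate_indices is hypothetical; the page calls out-of-range indices illegal but does not say how MAX reports them, so raising IndexError is this sketch's own choice.

```python
# Hypothetical validation helper, not part of MAX. Raising IndexError is
# an assumption; the docs only state that such indices are illegal.

def validate_indices(indices, vocab_size):
    """Raise IndexError unless every index lies in [0, vocab_size)."""
    for position, index in enumerate(indices):
        if not 0 <= index < vocab_size:
            raise IndexError(
                f"index {index} at position {position} is outside [0, {vocab_size})"
            )


validate_indices([0, 5, 999], vocab_size=1000)  # passes silently
try:
    validate_indices([0, 1000], vocab_size=1000)  # 1000 is one past the end
except IndexError as err:
    print(err)  # index 1000 at position 1 is outside [0, 1000)
```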
weight
weight: Tensor