Python class

Embedding

class max.nn.Embedding(vocab_size, *, dim=None, dims=None)

A vector embedding.

An embedding can be thought of as a lookup table for vectors by index. Given an input tensor of indices into the embedding, the lookup returns a tensor in which each index is replaced by the vector stored at that location in the embedding table. The result therefore has the shape of the input followed by the shape of a single embedding vector.
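The lookup-table view can be illustrated without the library. The following is a minimal sketch that uses NumPy integer-array indexing as a stand-in; the `table` and `indices` names are hypothetical and are not part of `max.nn`.

```python
import numpy as np

# Hypothetical stand-in for an embedding table:
# 5 rows (the vocab size), each row a 3-element vector.
table = np.arange(15, dtype=np.float32).reshape(5, 3)

# A 2x2 tensor of indices into the table.
indices = np.array([[0, 4], [2, 2]])

# Integer-array indexing replaces each index with the matching table row.
looked_up = table[indices]

print(looked_up.shape)  # (2, 2, 3)
print(looked_up[0, 1])  # row 4 of the table: [12. 13. 14.]
```

The output keeps the 2x2 layout of the indices and gains one trailing axis for the vector stored at each index.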

The common case for embeddings is a 1-dimensional embedding:

from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding

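# Lookup table with 1000 entries, each a 128-element vector.
embedding = Embedding(vocab_size=1000, dim=128)
# Ten indices, all equal to 1.
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
# Each index is replaced by a 128-element vector.
assert embedded.shape == [10, 128]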

Embeddings support multi-dimensional vectors just as easily:

from max.dtype import DType
from max.tensor import Tensor
from max.nn import Embedding

# Lookup table with 1000 entries, each a 16x128 block.
embedding = Embedding(vocab_size=1000, dims=[16, 128])
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
# Each index is replaced by a 16x128 block.
assert embedded.shape == [10, 16, 128]

Creates a randomly initialized embedding of the specified size.

Parameters:

  • vocab_size (DimLike) – The number of elements in the lookup table. Indices outside the range [0, vocab_size) are illegal in the resulting embedding operation.
  • dim (DimLike | None) – The embedding dimension if there is exactly one. Equivalent to dims=[dim].
  • dims (ShapeLike | None) – The shape of the vectors in the embedding. Use this to specify multi-dimensional embeddings.

dim

property dim: Dim

The dimension of the vectors in the embedding (for a 1d embedding).

Raises: An error if the embedding is 0-dimensional or has more than one dimension.

dims

property dims: Sequence[Dim]

The dimensions of the vectors in the embedding.
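The relationship between `dim` and `dims` can be modeled in plain Python. This is only a sketch of the documented behavior, not the library's implementation: the `EmbeddingShape` class is hypothetical, the `ValueError` type is an assumption (the documentation does not name the exception), and treating a missing `dim` and `dims` as an empty shape is also an assumption.

```python
from typing import List, Optional, Sequence


class EmbeddingShape:
    """Hypothetical model of the dim/dims relationship; not the max.nn class."""

    def __init__(
        self,
        vocab_size: int,
        *,
        dim: Optional[int] = None,
        dims: Optional[Sequence[int]] = None,
    ) -> None:
        self.vocab_size = vocab_size
        # dim=d is shorthand for dims=[d], as the parameter docs state.
        # Assumption: with neither argument given, the shape is empty.
        self.dims: List[int] = [dim] if dim is not None else list(dims or [])

    @property
    def dim(self) -> int:
        # Only meaningful when the vectors have exactly one dimension.
        if len(self.dims) != 1:
            # Assumption: the real exception type is not documented.
            raise ValueError(f"dim requires a 1d embedding, got dims={self.dims}")
        return self.dims[0]


print(EmbeddingShape(1000, dim=128).dims)  # [128]
print(EmbeddingShape(1000, dim=128).dim)   # 128
```

Reading `dim` on `EmbeddingShape(1000, dims=[16, 128])` raises in this sketch, mirroring the documented error for embeddings that are not 1-dimensional.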

forward()

forward(indices)

Applies the vector embedding to the input tensor of indices.

Parameters:

indices (Tensor) – An integer-valued tensor. Values must be in the range [0, vocab_size) for the embedding.

Returns:

A dense tensor made by looking up each index in the vector embedding. For an input of shape (*batch, indices) and an embedding of shape (vocab_size, *dims), the result will have shape (*batch, indices, *dims).

Return type:

Tensor
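The shape rule for the return value can be written as a small helper. This is a sketch in plain Python; `embedded_shape` is a hypothetical name and not part of the library.

```python
from typing import List, Sequence


def embedded_shape(
    input_shape: Sequence[int], embedding_shape: Sequence[int]
) -> List[int]:
    """Return the result shape of looking up `input_shape` indices.

    `embedding_shape` is (vocab_size, *dims). The vocab_size axis is consumed
    by the lookup, and the remaining dims are appended to the input shape.
    """
    _vocab_size, *dims = embedding_shape
    return [*input_shape, *dims]


print(embedded_shape([10], [1000, 128]))      # [10, 128]
print(embedded_shape([10], [1000, 16, 128]))  # [10, 16, 128]
print(embedded_shape([4, 10], [1000, 128]))   # [4, 10, 128]
```

The first two calls reproduce the shapes asserted in the 1-dimensional and multi-dimensional examples at the top of this page.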

vocab_size

property vocab_size: Dim

The vocab size of the embedding.

Indices outside the range [0, vocab_size) are illegal.
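Because the documentation only says that out-of-range indices are illegal, and not how they are handled, a caller may want to validate indices before the lookup. The following is a sketch of such a caller-side guard; `check_indices` is a hypothetical helper and the choice of `IndexError` is an assumption.

```python
from typing import Iterable


def check_indices(indices: Iterable[int], vocab_size: int) -> None:
    """Raise if any index falls outside the half-open range [0, vocab_size)."""
    bad = [i for i in indices if not 0 <= i < vocab_size]
    if bad:
        # Assumption: IndexError is this sketch's choice, not the library's.
        raise IndexError(f"indices {bad} are outside [0, {vocab_size})")


check_indices([0, 1, 999], vocab_size=1000)  # passes: 999 is the last valid index

try:
    check_indices([5, 1000, -1], vocab_size=1000)
except IndexError as err:
    print(err)  # indices [1000, -1] are outside [0, 1000)
```

Note that the range is half-open: `vocab_size` itself is not a valid index.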

weight

weight: Tensor
