Embedding

class max.nn.Embedding(vocab_size, hidden_dim, dtype, device, quantization_encoding=None, name=None)

Bases: Module

A lookup table for embedding integer indices into dense vectors.

When called, Embedding maps each integer index to a dense vector of fixed size. It accepts a TensorValueLike of integer indices with shape (batch, ..., num_indices) and returns a TensorValue of shape (batch, ..., num_indices, hidden_dim) containing the corresponding embedding vectors.

Embedding weights are stored on the CPU but are moved to the specified device during model initialization.
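The lookup semantics described above can be sketched in plain Python (illustrative only, not the MAX API; `table` and `lookup` are hypothetical names):

```python
# Illustrative embedding table: vocab_size rows, each a hidden_dim-length
# vector, standing in for the layer's weight matrix.
vocab_size, hidden_dim = 4, 3
table = [[float(i * hidden_dim + j) for j in range(hidden_dim)]
         for i in range(vocab_size)]

def lookup(indices):
    # Each integer index in [0, vocab_size) selects one row, so a
    # (num_indices,) input becomes a (num_indices, hidden_dim) output.
    return [table[i] for i in indices]

embeddings = lookup([0, 2, 2])
# Three indices in, three hidden_dim-length vectors out; repeated
# indices map to identical vectors.
```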

from max.dtype import DType
from max.graph import DeviceRef, TensorValueLike
from max.nn import Embedding

embedding_layer = Embedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="embeddings",
)

# token_indices is a TensorValueLike of integer indices,
# e.g. shape (batch, seq_len).
token_indices: TensorValueLike
# Returns a TensorValue of shape (batch, seq_len, 256).
embeddings = embedding_layer(token_indices)

Initializes the embedding layer with the given arguments.

Parameters:

  • vocab_size (int) – The number of unique items in the vocabulary. Indices must be in the range [0, vocab_size).
  • hidden_dim (int) – The dimensionality of each embedding vector.
  • dtype (DType) – The data type of the embedding weights.
  • device (DeviceRef) – The device where embedding lookups are executed. Model init transfers the initially CPU-resident weights to this device.
  • quantization_encoding (QuantizationEncoding | None) – Optional quantization encoding for the weights.
  • name (str | None) – The name identifier for the embedding weight matrix.

device

device: DeviceRef

The device on which embedding lookup is performed.

weight

weight: Weight

The embedding weight matrix, stored on the CPU. Model initialization moves it to the device specified by device.