Embedding
class max.nn.Embedding(vocab_size, hidden_dim, dtype, device, quantization_encoding=None, name=None)
Bases: Module
A lookup table for embedding integer indices into dense vectors.
When called, Embedding maps each integer index to a dense vector of fixed size. It accepts a TensorValueLike of integer indices with shape (batch, ..., num_indices) and returns a TensorValue of shape (batch, ..., num_indices, hidden_dim) containing the corresponding embedding vectors.
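To make the shape contract concrete, the lookup can be modeled in plain NumPy. This is a reference sketch only, not the MAX implementation: it assumes the weight matrix has shape (vocab_size, hidden_dim) and uses NumPy integer-array indexing, which gathers one row per index so that the output shape is the index shape with hidden_dim appended. The function name embedding_lookup is hypothetical.

```python
import numpy as np

def embedding_lookup(weight: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Hypothetical reference model of an embedding lookup (not the MAX code).

    weight:  assumed (vocab_size, hidden_dim) matrix, one row per vocabulary item.
    indices: integer array of any shape, e.g. (batch, ..., num_indices).
    Returns an array of shape indices.shape + (hidden_dim,).
    """
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError("indices must be integers")
    # Integer-array indexing gathers weight[i] for every index i,
    # preserving the shape of `indices` and appending hidden_dim.
    return weight[indices]

vocab_size, hidden_dim = 10, 4
weight = np.arange(vocab_size * hidden_dim, dtype=np.float32).reshape(vocab_size, hidden_dim)
token_indices = np.array([[1, 3, 3], [0, 9, 2]])  # shape (batch=2, num_indices=3)

embeddings = embedding_lookup(weight, token_indices)
print(embeddings.shape)  # (2, 3, 4)
```

Repeated indices (the two 3s above) simply select the same row twice.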
Embedding weights are stored on the CPU but are moved to the specified device during model initialization.
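```python
from max.dtype import DType
from max.graph import DeviceRef, TensorValueLike
from max.nn import Embedding

# A 1000-entry vocabulary mapped to 256-dimensional float32 vectors, looked up on the GPU.
embedding_layer = Embedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="embeddings",
)

# Placeholder for an integer tensor of shape (batch, ..., num_indices), e.g. a graph input.
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)  # shape: (batch, ..., num_indices, 256)
```

Initializes the embedding layer with the given arguments.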
Parameters:
- vocab_size (int) – The number of unique items in the vocabulary. Indices must be in the range [0, vocab_size).
- hidden_dim (int) – The dimensionality of each embedding vector.
- dtype (DType) – The data type of the embedding weights.
- device (DeviceRef) – The device where embedding lookups are executed. The weights are initially stored on the CPU and are transferred to this device during model initialization.
- quantization_encoding (QuantizationEncoding | None) – Optional quantization encoding for the weights.
- name (str | None) – The name identifier for the embedding weight matrix.
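Because valid indices lie in the half-open range [0, vocab_size), it can help to validate token IDs before they reach the layer. The helper below, check_token_indices, is hypothetical (it is not part of max.nn) and operates on NumPy arrays rather than graph values. An explicit check is worthwhile because plain NumPy indexing would silently wrap a negative index such as -1 to the last row.

```python
import numpy as np

def check_token_indices(indices: np.ndarray, vocab_size: int) -> None:
    """Hypothetical pre-check: every index must satisfy 0 <= i < vocab_size."""
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError(f"expected integer indices, got {indices.dtype}")
    # The upper bound is exclusive, so vocab_size itself is out of range.
    bad = (indices < 0) | (indices >= vocab_size)
    if bad.any():
        raise ValueError(
            f"indices out of range [0, {vocab_size}): {indices[bad].tolist()}"
        )

check_token_indices(np.array([0, 5, 999]), vocab_size=1000)  # passes silently

try:
    check_token_indices(np.array([0, 1000, -1]), vocab_size=1000)
except ValueError as err:
    print(err)  # indices out of range [0, 1000): [1000, -1]
```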
device
device: DeviceRef
The device on which embedding lookup is performed.
weight
weight: Weight
The embedding weight matrix, stored on the CPU.
Model initialization moves the weights to the device specified by device.
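Since the weights are moved from the CPU to the target device at initialization, it can be useful to estimate how much data that transfer involves. This back-of-the-envelope sketch assumes an unquantized matrix of shape (vocab_size, hidden_dim) (that is, quantization_encoding=None) and uses NumPy's itemsize for the element width; the function name embedding_weight_nbytes is hypothetical. With the constructor values from the example above, the estimate is 1000 × 256 × 4 bytes.

```python
import numpy as np

def embedding_weight_nbytes(vocab_size: int, hidden_dim: int, dtype) -> int:
    """Hypothetical estimate: bytes in an unquantized (vocab_size, hidden_dim) matrix."""
    return vocab_size * hidden_dim * np.dtype(dtype).itemsize

# Values from the constructor example; float32 elements are 4 bytes each.
size = embedding_weight_nbytes(1000, 256, np.float32)
print(size)  # 1024000
```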