VocabParallelEmbedding
class max.nn.VocabParallelEmbedding(vocab_size, hidden_dim, dtype, devices, quantization_encoding=None, name=None)
Bases: Module
A lookup table for embedding integer indices into dense vectors.
This layer works like Embedding, except that the embedding table is sharded along the vocabulary dimension across all devices. When called, VocabParallelEmbedding accepts a TensorValueLike of integer indices along with signal buffers for cross-device communication, and returns a list of TensorValue tensors (one per device) containing the corresponding embedding vectors.
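Conceptually, each device owns a contiguous slice of the table's rows: it maps each global index into its local slice, masks out indices it does not own, gathers locally, and the per-device partial results are summed across devices (the all-reduce step that the signal buffers coordinate). The NumPy sketch below is a minimal illustration of that idea; it assumes an even vocabulary split, uses hypothetical names such as shard_lookup, and is not the MAX implementation.

import numpy as np

vocab_size, hidden_dim, num_devices = 1000, 256, 2
rows_per_shard = vocab_size // num_devices

# Full table, split along the vocabulary dimension (axis 0).
table = np.random.rand(vocab_size, hidden_dim).astype(np.float32)
shards = np.split(table, num_devices, axis=0)

def shard_lookup(shard, shard_id, indices):
    # Map global indices into this shard's local row range.
    local = indices - shard_id * rows_per_shard
    owned = (local >= 0) & (local < rows_per_shard)
    # Unowned indices gather row 0, then are masked to zero.
    gathered = shard[np.where(owned, local, 0)]
    return gathered * owned[..., None]

indices = np.array([[3, 512, 999]])
# Summing the per-shard partials recovers the full-table lookup.
result = sum(shard_lookup(s, i, indices) for i, s in enumerate(shards))
assert np.allclose(result, table[indices])

With that picture in mind, the layer itself is constructed and called like any other Module: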
from max.dtype import DType  # imports assume the standard MAX package layout
from max.graph import DeviceRef
from max.nn import VocabParallelEmbedding

embedding_layer = VocabParallelEmbedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    name="embeddings",
)
# token_indices: a TensorValueLike of integer indices
embeddings = embedding_layer(token_indices)

Initializes the vocab-parallel embedding layer.
Parameters:
- vocab_size (int) – The number of unique items in the vocabulary. Indices must be in the range [0, vocab_size).
- hidden_dim (int) – The dimensionality of each embedding vector.
- dtype (DType) – The data type of the embedding weights.
- devices (list[DeviceRef]) – The devices on which embedding lookups are executed. Model initialization transfers the initially CPU-resident weights to these devices.
- quantization_encoding (QuantizationEncoding | None) – Optional quantization encoding for the weights.
- name (str | None) – The name identifier for the embedding weight matrix.
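Since indices must lie in [0, vocab_size), a cheap host-side check before building the graph can catch bad token ids early. A minimal sketch (raw_ids is a hypothetical batch of token ids, not part of this API):

import numpy as np

vocab_size = 1000
raw_ids = np.array([[3, 512, 999]])
assert raw_ids.min() >= 0 and raw_ids.max() < vocab_size, \
    "token ids must lie in [0, vocab_size)"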