Python module
embedding
The embedding module provides classes for mapping integer indices (like token IDs) to dense vector representations. These embedding operations are fundamental building blocks for natural language processing, recommendation systems, and other tasks involving discrete tokens.
Embedding: Basic embedding lookup table for simple use cases.
EmbeddingV2: Enhanced embedding with device placement control and improved memory management.
VocabParallelEmbedding: Distributed embedding that shards the vocabulary across multiple devices for large embedding tables.
Here’s an example demonstrating how to use embeddings:
import max.nn as nn
from max.graph import Graph, ops, DeviceRef
from max.dtype import DType
import numpy as np
with Graph(name="embedding_example") as graph:
    # Define dimensions
    batch_size = 4
    seq_length = 16
    vocab_size = 10000
    hidden_dim = 256

    # Create input tensor of token indices
    input_data = np.random.randint(0, vocab_size, (batch_size, seq_length), dtype=np.int32)
    input_indices = ops.constant(input_data, dtype=DType.int32)

    # Create embedding layer
    embedding = nn.EmbeddingV2(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        dtype=DType.float32,
        device=DeviceRef.GPU(),
        name="token_embeddings"
    )

    # Look up embeddings for input indices
    embeddings = embedding(input_indices)
    print(f"Embedding output shape: {embeddings.shape}")
    # Embedding output shape: [Dim(4), Dim(16), Dim(256)]
Embedding
class max.nn.embedding.Embedding(weights: TensorValueLike)
weights
weights: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
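The basic Embedding layer wraps an existing weight tensor and is called with integer indices. Here is a minimal sketch, assuming Embedding is exported at max.nn alongside EmbeddingV2 and is invoked the same way as the other embedding layers; the sizes and names are illustrative only:
import max.nn as nn
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops

with Graph(name="basic_embedding_example") as graph:
    # Illustrative pre-initialized weight table: one 256-dim row per vocabulary entry.
    weight_data = np.random.rand(1000, 256).astype(np.float32)
    weights = ops.constant(weight_data, dtype=DType.float32)

    # Wrap the weight table in the basic Embedding layer.
    embedding = nn.Embedding(weights=weights)

    # Look up the rows for a small batch of token indices.
    indices = ops.constant(np.array([[1, 2, 3]], dtype=np.int32), dtype=DType.int32)
    embedded = embedding(indices)  # expected shape: [1, 3, 256]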
EmbeddingV2
class max.nn.embedding.EmbeddingV2(vocab_size: int, hidden_dim: int, dtype: DType, device: DeviceRef | None = None, quantization_encoding: QuantizationEncoding | None = None, name: str | None = None)
A lookup table for embedding integer indices into dense vectors.
This layer maps each integer index to a dense vector of fixed size. Embedding weights are stored on the CPU but are moved to the specified device during the model init phase.
Example:
embedding_layer = EmbeddingV2(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="embeddings",
)
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
Initializes the embedding layer with the given arguments.
Parameters:
- vocab_size – The number of unique items in the vocabulary. Indices must be in the range [0, vocab_size).
- hidden_dim – The dimensionality of each embedding vector.
- dtype – The data type of the embedding weights.
- device – The device where embedding lookups are executed. Model init transfers the initially CPU-resident weights to this device.
- name – The name identifier for the embedding weight matrix.
device
device: DeviceRef | None
The device on which embedding lookup is performed.
weight
weight: Weight
The embedding weight matrix stored on the CPU. Model init moves the weights to the device specified in device.
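As a small hedged sketch of reading these attributes (construction mirrors the example above; the exact repr of each attribute is not shown here):
embedding = EmbeddingV2(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="embeddings",
)

# `weight` is the CPU-resident Weight matrix until model init transfers it.
print(embedding.weight)
# `device` records where lookups run once the weights have been moved.
print(embedding.device)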
VocabParallelEmbedding
class max.nn.embedding.VocabParallelEmbedding(vocab_size: int, hidden_dim: int, dtype: DType, devices: list[max.graph.type.DeviceRef], quantization_encoding: QuantizationEncoding | None = None, name: str | None = None)
A lookup table for embedding integer indices into dense vectors.
This layer works like nn.Embedding except the embedding table is sharded on the vocabulary dimension across all devices.
Example:
embedding_layer = VocabParallelEmbedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    name="embeddings",
)
# Token indices of shape: [batch, ..., num_indices].
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
Parameters:
- vocab_size – The number of unique items in the vocabulary. Indices must be in the range [0, vocab_size).
- hidden_dim – The dimensionality of each embedding vector.
- dtype – The data type of the embedding weights.
- devices – The devices where embedding lookups are executed. Model init transfers the initially CPU-resident weights to these devices.
- name – The name identifier for the embedding weight matrix.
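To make the vocabulary sharding concrete, here is a conceptual NumPy sketch of a vocabulary-parallel lookup. This is not the MAX implementation; it only illustrates the idea that each device owns a contiguous slice of the embedding rows, that indices outside a shard's slice contribute zeros, and that the per-device partial results are summed, which is what a cross-device all-reduce produces:
import numpy as np

vocab_size, hidden_dim, num_devices = 8, 4, 2
table = np.random.rand(vocab_size, hidden_dim).astype(np.float32)

# Shard the table on the vocabulary dimension: each "device" owns a contiguous slice.
shards = np.split(table, num_devices, axis=0)
rows_per_shard = vocab_size // num_devices

indices = np.array([0, 3, 5, 7])

partials = []
for rank, shard in enumerate(shards):
    local = indices - rank * rows_per_shard
    in_range = (local >= 0) & (local < rows_per_shard)
    # Rows for indices outside this shard's slice are zeroed out.
    gathered = shard[np.clip(local, 0, rows_per_shard - 1)]
    partials.append(np.where(in_range[:, None], gathered, 0.0))

# Summing the partial results stands in for the cross-device all-reduce.
result = sum(partials)
assert np.allclose(result, table[indices])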