Python class
RotaryEmbedding
class max.experimental.nn.rope.RotaryEmbedding(weight)
Bases: Module
Applies Rotary Positional Embeddings (RoPE) to input tensors.
RoPE encodes positional information using complex-valued rotations applied to query and key vectors in attention. This encoding is relative, allowing models to better generalize to sequences longer than those seen during training.
See “RoFormer: Enhanced Transformer with Rotary Position Embedding” (https://arxiv.org/abs/2104.09864)
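Conceptually, RoPE treats each consecutive (real, imaginary) pair of the head dimension as a complex number and multiplies it by a position-dependent unit rotation. A minimal pure-Python sketch of that idea (independent of MAX; `rope_rotate` and its frequency schedule are illustrative assumptions following the RoFormer paper):

```python
import cmath

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive (real, imaginary) pairs of `vec` by a
    position-dependent angle. A hypothetical illustration of the RoPE
    idea, not the MAX implementation."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        # Per-pair frequency, decaying with pair index as in RoFormer.
        theta = base ** (-i / dim)
        z = complex(vec[i], vec[i + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out
```

Because each rotation is by a unit complex number, position 0 leaves the vector unchanged and every position preserves the vector's norm.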
from max.experimental.nn.rope import RotaryEmbedding, positional_embedding
from max.experimental.tensor import Tensor
rope = RotaryEmbedding(
    weight=positional_embedding(
        dim=128,
        base=10000.0,
        max_sequence_length=2048,
    )
)
# Apply to query or key tensors in attention
# Shape: (batch, seq_len, num_heads, head_dim)
query = Tensor.randn([4, 128, 12, 128])
query_with_rope = rope(query, start_pos=0)
print(query_with_rope.shape) # (4, 128, 12, 128)
# Positional information is now encoded in the rotation of the query vectors.
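The relative-position property mentioned above can be checked directly: because each pair is rotated by an angle proportional to its position, the dot product of a rotated query and key depends only on their positional offset, not the absolute positions. A small pure-Python check (again a sketch, not the MAX implementation):

```python
import cmath

def rotate(vec, pos, base=10000.0):
    # Rotate each (real, imaginary) pair by a position-dependent angle.
    out = []
    for i in range(0, len(vec), 2):
        theta = base ** (-i / len(vec))
        z = complex(vec[i], vec[i + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.9, 0.2]
# Both pairs of positions are 5 apart, so the attention scores match.
s1 = dot(rotate(q, 0), rotate(k, 5))
s2 = dot(rotate(q, 10), rotate(k, 15))
assert abs(s1 - s2) < 1e-9
```

This is why RoPE generalizes to positions beyond those seen in training: the score between two tokens is a function of their distance alone.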
Parameters:
- weight (Tensor)
dim
property dim: int
Returns the embedding dimension.
forward()
forward(x, start_pos=0)
Applies rotary positional embeddings (RoPE) to x.
seq_len is inferred from the shape of x.
Parameters:
- x (Tensor) – Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim). x is interpreted as a complex-valued tensor where the head_dim dimension holds alternating (real, imaginary) pairs.
- start_pos (int | str | Dim | integer[Any] | TypedAttr) – Starting position of the input tensor. Defaults to 0.
Returns:
The input activation tensor with rotary positional embeddings applied, with the same shape as x.
Return type:
Tensor
max_sequence_length
property max_sequence_length: int
Returns the maximum sequence length.
weight
weight: Tensor