Python class

RotaryEmbedding

class max.experimental.nn.rope.RotaryEmbedding(weight)

Bases: Module

Applies Rotary Positional Embeddings (RoPE) to input tensors.

RoPE encodes positional information using complex-valued rotations applied to query and key vectors in attention. This encoding is relative, allowing models to better generalize to sequences longer than those seen during training.

See “RoFormer: Enhanced Transformer with Rotary Position Embedding” (https://arxiv.org/abs/2104.09864)
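The rotation can be sketched in plain NumPy. This is a hypothetical reference implementation for illustration, not the library's kernel; it interprets the head_dim axis as alternating (real, imaginary) pairs, as forward() describes, and rotates each pair by a position-dependent angle:

```python
import numpy as np

def rope_reference(x, start_pos=0, base=10000.0):
    """Illustrative RoPE: rotate alternating (real, imag) pairs along the
    last axis of x, shaped (batch, seq_len, n_heads, head_dim)."""
    batch, seq_len, n_heads, head_dim = x.shape
    half = head_dim // 2
    # One rotation rate per (real, imag) pair, decaying with pair index.
    inv_freq = base ** (-np.arange(half) * 2.0 / head_dim)
    pos = np.arange(start_pos, start_pos + seq_len)
    angles = pos[:, None] * inv_freq                    # (seq_len, half)
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    real, imag = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    # Complex multiplication (real + i*imag) * (cos + i*sin), pair-wise.
    out[..., 0::2] = real * cos - imag * sin
    out[..., 1::2] = real * sin + imag * cos
    return out

x = np.random.randn(2, 16, 4, 64)
y = rope_reference(x)
# Rotations are norm-preserving, so vector magnitudes are unchanged:
assert np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1))
```

Because the encoding depends only on relative angles between positions, attention scores between two rotated vectors depend on their positional offset rather than their absolute positions.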

from max.experimental.nn.rope import RotaryEmbedding, positional_embedding
from max.experimental.tensor import Tensor

rope = RotaryEmbedding(
    weight=positional_embedding(
        dim=128,
        base=10000.0,
        max_sequence_length=2048
    )
)

# Apply to query or key tensors in attention
# Shape: (batch, seq_len, num_heads, head_dim)
query = Tensor.randn([4, 128, 12, 128])
query_with_rope = rope(query, start_pos=0)

print(query_with_rope.shape)  # (4, 128, 12, 128)
# Positional information now encoded in the rotation of query vectors

Parameters:

weight (Tensor)

dim

property dim: int

Returns the embedding dimension.

forward()

forward(x, start_pos=0)

Applies rotary positional embeddings (RoPE) to x.

seq_len is inferred from the shape of x.

Parameters:

  • x (Tensor) – Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim). x is interpreted as a complex-valued tensor whose head_dim dimension holds alternating (real, imaginary) pairs.
  • start_pos (int | str | Dim | integer[Any] | TypedAttr) – Starting position of the input tensor within the full sequence; defaults to 0.

Returns:

The input activation tensor with rotary positional embeddings applied; same shape as x.

Return type:

Tensor
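The start_pos argument makes incremental decoding work with a KV cache: rotating only the newest token at its absolute position gives the same result as rotating the full sequence. A self-contained NumPy sketch (hypothetical, using the complex-number view of the alternating pairs) illustrates this property:

```python
import numpy as np

def rope_complex(x, start_pos=0, base=10000.0):
    """View alternating (real, imag) pairs as complex numbers and multiply
    by e^{i * pos * theta} -- an illustrative sketch, not the library code."""
    seq_len, head_dim = x.shape[1], x.shape[-1]
    z = x[..., 0::2] + 1j * x[..., 1::2]               # (b, s, h, head_dim/2)
    theta = base ** (-np.arange(head_dim // 2) * 2.0 / head_dim)
    pos = np.arange(start_pos, start_pos + seq_len)
    rot = np.exp(1j * pos[:, None] * theta)[None, :, None, :]
    z = z * rot
    out = np.empty_like(x)
    out[..., 0::2], out[..., 1::2] = z.real, z.imag
    return out

x = np.random.randn(1, 6, 2, 32)
# Prefill rotates positions 0..5; a decode step rotates only the newest
# token at start_pos=5. Both yield the same embedding for position 5:
assert np.allclose(rope_complex(x)[:, 5:], rope_complex(x[:, 5:], start_pos=5))
```

This is why cached key vectors never need re-rotation: each token's rotation is fixed by its absolute position, and only the relative angle matters at attention time.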

max_sequence_length

property max_sequence_length: int

source

Returns the maximum sequence length.

weight

weight: Tensor