Python class

RotaryEmbedding

class max.experimental.nn.rope.RotaryEmbedding(weight)

Bases: Module

Applies Rotary Positional Embeddings (RoPE) to input tensors.

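RoPE encodes positional information by rotating pairs of components of the query and key vectors in attention. Each pair is treated as a complex number and rotated by an angle proportional to the token's position. Because the dot product between a rotated query and a rotated key then depends only on the offset between their positions, the encoding is relative, which can help models generalize to sequences longer than those seen during training.

See “RoFormer: Enhanced Transformer with Rotary Position Embedding” (https://arxiv.org/abs/2104.09864).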
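The relative-position property can be made concrete with the RoFormer formulation. The math below is a sketch under stated assumptions: adjacent components (x_{2j}, x_{2j+1}) are paired into complex numbers, the frequencies follow θ_j = b^{-2j/d} with b the base (10000.0 in the example that follows), and d is the head dimension. The exact pairing and table layout used internally by this class are not specified on this page.

```latex
% Pair adjacent components of x \in \mathbb{R}^d (d even) into d/2 complex numbers
z_j = x_{2j} + \mathrm{i}\,x_{2j+1}, \qquad \theta_j = b^{-2j/d}, \qquad j = 0, \dots, d/2 - 1

% Rotate each pair by an angle proportional to the token position m
\mathrm{RoPE}(x, m)_j = z_j \, e^{\mathrm{i} m \theta_j}

% Attention score between a query q at position m and a key k at position n,
% where q_j and k_j are the complex pairs of q and k
\langle \mathrm{RoPE}(q, m), \mathrm{RoPE}(k, n) \rangle
  = \operatorname{Re} \sum_{j=0}^{d/2 - 1} q_j \, \overline{k_j} \, e^{\mathrm{i}(m - n)\theta_j}
```

Only the offset m − n appears on the right-hand side, so shifting both positions by the same amount leaves the score unchanged.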

from max.experimental.nn.rope import RotaryEmbedding, positional_embedding
from max.experimental.tensor import Tensor

rope = RotaryEmbedding(
    weight=positional_embedding(
        dim=128,
        base=10000.0,
        max_sequence_length=2048
    )
)

# Apply to query or key tensors in attention
# Shape: (batch, seq_len, num_heads, head_dim)
query = Tensor.randn([4, 128, 12, 128])
query_with_rope = rope(query, start_pos=0)

print(query_with_rope.shape)  # (4, 128, 12, 128)
# Positional information now encoded in the rotation of query vectors
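The positional_embedding() call above builds the table of rotation angles that RotaryEmbedding reads from. The pure-Python sketch below shows how such a table is conventionally computed from dim, base, and max_sequence_length using the RoFormer frequencies θ_j = base^(-2j/dim). The helper name rope_table is hypothetical, and the real MAX weight tensor may use a different layout or dtype (for example, storing cosines and sines in a single tensor).

```python
import math


def rope_table(dim: int, base: float, max_sequence_length: int) -> list[list[tuple[float, float]]]:
    """Hypothetical stand-in for a RoPE frequency table (not the MAX implementation).

    Returns max_sequence_length rows, each holding dim // 2 (cos, sin) pairs:
    row m, column j stores the rotation for position m and frequency
    theta_j = base ** (-2j / dim).
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even: RoPE rotates components in pairs")
    thetas = [base ** (-2 * j / dim) for j in range(dim // 2)]
    return [
        [(math.cos(m * theta), math.sin(m * theta)) for theta in thetas]
        for m in range(max_sequence_length)
    ]


# Same arguments as the positional_embedding() call in the example above.
table = rope_table(dim=128, base=10000.0, max_sequence_length=2048)
print(len(table), len(table[0]))  # 2048 64
```

In this sketch a head dimension of 128 yields 64 frequencies because each rotation acts on a pair of components, and position 0 maps to the identity rotation (cos = 1, sin = 0).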

Parameters:

weight (Tensor) – Precomputed rotary embedding table, for example the tensor returned by positional_embedding().

dim

property dim: int

Returns the embedding dimension.

forward()

forward(x, start_pos=0)

Applies rotary positional embeddings to x.

The token at index i along the sequence axis of x is rotated according to its absolute position start_pos + i, so a later chunk of a sequence can be processed separately by passing the position of its first token.

Parameters:

  • x (Tensor) – Input tensor, such as a query or key projection in attention, with shape (batch, seq_len, num_heads, head_dim).
  • start_pos (int | str | Dim | integer | TypedAttr) – Absolute position of the first token in x. Defaults to 0.

Returns:

The input tensor with rotary positional embeddings applied. The output has the same shape as x.

Return type:

Tensor
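The start_pos argument is typically what makes incremental decoding work: tokens arrive one at a time, and each must be rotated by its absolute position rather than by its index within the current call. The pure-Python sketch below illustrates this for a single attention head, assuming adjacent-pair rotation with the RoFormer frequencies. The helpers rotate and apply_rope are hypothetical illustrations, not part of MAX. The sketch checks that rotating one token with start_pos=4 matches row 4 of rotating the whole sequence from position 0.

```python
import math


def rotate(row: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate each adjacent pair (row[2j], row[2j+1]) by pos * theta_j.

    Assumes the RoFormer frequencies theta_j = base ** (-2j / dim) and
    adjacent-pair layout; the layout used inside MAX may differ.
    """
    dim = len(row)
    out: list[float] = []
    for j in range(dim // 2):
        angle = pos * base ** (-2 * j / dim)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = row[2 * j], row[2 * j + 1]
        # Multiply the complex number (x0 + i*x1) by e^{i*angle}.
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


def apply_rope(x: list[list[float]], start_pos: int = 0) -> list[list[float]]:
    """Rotate row i of x (seq_len rows of head_dim floats) by position start_pos + i."""
    return [rotate(row, start_pos + i) for i, row in enumerate(x)]


# A toy sequence of 5 tokens with head_dim = 8.
seq = [[float(i + j) for j in range(8)] for i in range(5)]

full = apply_rope(seq, start_pos=0)      # whole sequence in one call
last = apply_rope(seq[4:], start_pos=4)  # only token 4, processed later

print(all(math.isclose(a, b) for a, b in zip(full[4], last[0])))  # True
```

Without the offset (start_pos=0 in the second call), token 4 would be rotated as if it were at position 0 and its relative distance to earlier keys would be lost.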

max_sequence_length

property max_sequence_length: int

Returns the maximum sequence length.

weight

weight: Tensor