RotaryEmbedding

class max.nn.RotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True)

Bases: Module

Applies Rotary Position Embedding (RoPE) to transformer activations.

When called, RotaryEmbedding computes the frequency tensor for complex exponentials and applies it to input tensors. It accepts a TensorValueLike of shape (batch, seq_len, n_kv_heads, head_dim) along with optional start_pos and seq_len arguments and returns a TensorValue of the same shape with rotary positional embeddings applied. RotaryEmbedding supports both interleaved and non-interleaved RoPE variants.

Parameters:

  • dim (int) – The model’s hidden dimension.
  • n_heads (int) – The number of attention heads.
  • theta (float) – The base for computing RoPE frequencies. Controls the frequency scaling of the sinusoidal components.
  • max_seq_len (int) – The maximum sequence length for model input.
  • head_dim (int | None) – The per-head dimension. Defaults to dim // n_heads if None.
  • _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – A pre-computed frequency tensor. Defaults to None.
  • interleaved (bool) – Whether to apply RoPE using interleaved complex representation. Defaults to True.
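To make the transformation concrete, here is a minimal NumPy sketch of applying interleaved RoPE to a tensor with the shape described above, (batch, seq_len, n_heads, head_dim). This illustrates the math only; it is not the MAX implementation, and the function name is hypothetical.

```python
import numpy as np

def apply_rope_interleaved(x, theta=10000.0):
    """Apply interleaved RoPE to x of shape (batch, seq_len, n_heads, head_dim).

    Illustrative NumPy sketch of the math, not the MAX implementation.
    """
    _, seq_len, _, head_dim = x.shape
    # Per-pair inverse frequencies: theta ** (-2i / head_dim) for i = 0, 1, ...
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)  # (head_dim // 2,)
    angles = np.outer(np.arange(seq_len), inv_freq)              # (seq_len, head_dim // 2)
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    # Interleaved layout: rotate adjacent element pairs (x0, x1), (x2, x3), ...
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out
```

Because each pair is rotated by a position-dependent angle, the output has the same shape and per-vector norm as the input, and position 0 (angle zero) is left unchanged.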

compute_scale()

compute_scale(user_scale=None)

Returns the attention scale factor.

Parameters:

user_scale (float | None) – A custom scale factor. Defaults to None, in which case the scale is computed as 1 / sqrt(head_dim).

Returns:

The attention scale factor.

Return type:

float
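The documented default can be sketched in plain Python; the helper name here is illustrative, not part of the MAX API.

```python
import math

def default_attention_scale(head_dim, user_scale=None):
    # Mirrors the documented behavior: return user_scale if provided,
    # otherwise 1 / sqrt(head_dim).
    return user_scale if user_scale is not None else 1.0 / math.sqrt(head_dim)
```

For example, with head_dim=64 the default scale is 1 / 8 = 0.125.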

dim

dim: int

The model’s hidden dimension.

freqs_cis

property freqs_cis: TensorValue

The reshaped frequency tensor used for applying RoPE.

Retrieves the base frequency tensor from freqs_cis_base() and reshapes it from (max_seq_len * 2, head_dim // 2, 2) to (max_seq_len * 2, head_dim).
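The reshape simply flattens the trailing (head_dim // 2, 2) axes, so each cos/sin pair becomes two adjacent entries along the head dimension. A small NumPy sketch of the shape change:

```python
import numpy as np

max_seq_len, head_dim = 16, 8
# Base tensor: last axis holds the (cos, sin) pair for each frequency.
base = np.random.default_rng(0).standard_normal((max_seq_len * 2, head_dim // 2, 2))
# Flattening the last two axes interleaves cos/sin along the head dimension.
flat = base.reshape(max_seq_len * 2, head_dim)
```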

freqs_cis_base()

freqs_cis_base()

Computes the frequency cosine-sine tensor for Rotary Position Embedding.

Scales the tensor using the theta parameter. Based on RoFormer: Enhanced Transformer with Rotary Position Embedding.

Returns:

The frequency tensor with shape (max_seq_len * 2, head_dim // 2, 2).

Return type:

TensorValue
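The standard RoFormer frequency table can be sketched as follows. This is an illustrative reconstruction of the math under the documented shapes, assuming the usual theta ** (-2i / head_dim) frequency schedule; it is not MAX's code, and the function name is hypothetical.

```python
import numpy as np

def freqs_cis_base_sketch(head_dim, max_seq_len, theta=10000.0):
    """Cos/sin frequency table per RoFormer; illustrative only.

    Returns shape (max_seq_len * 2, head_dim // 2, 2), where [..., 0] holds
    cosines and [..., 1] holds sines.
    """
    # Inverse frequencies for each rotated pair, scaled by theta.
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)  # (head_dim // 2,)
    positions = np.arange(max_seq_len * 2)
    angles = np.outer(positions, inv_freq)                       # (max_seq_len * 2, head_dim // 2)
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)
```

At position 0 every angle is zero, so the first row is all (cos, sin) = (1, 0) pairs.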

head_dim

head_dim: int

The per-head dimension. Equal to dim // n_heads if not specified.

interleaved

interleaved: bool = True

Whether to apply RoPE using interleaved complex representation.
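The two variants differ in which elements of the head dimension are paired as the real and imaginary parts of each rotated complex number. Assuming the common conventions (adjacent pairs for interleaved, front/back halves otherwise), the pairings look like this:

```python
import numpy as np

head_dim = 8
x = np.arange(head_dim, dtype=float)

# Interleaved: adjacent elements form pairs — (x0, x1), (x2, x3), ...
interleaved_pairs = x.reshape(head_dim // 2, 2)

# Non-interleaved (half-split): element i pairs with element i + head_dim // 2.
half_pairs = np.stack([x[: head_dim // 2], x[head_dim // 2:]], axis=-1)
```

Both layouts rotate head_dim // 2 pairs; only the memory layout of each pair differs.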

max_seq_len

max_seq_len: int

The maximum sequence length for model input.

n_heads

n_heads: int

The number of attention heads.

theta

theta: float

The base for computing RoPE frequencies. Controls the frequency scaling of the sinusoidal components.