RotaryEmbedding
class max.nn.RotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True)
Bases: Module
Applies Rotary Position Embedding (RoPE) to transformer activations.
When called, RotaryEmbedding computes the frequency tensor for complex exponentials and applies it to input tensors. It accepts a TensorValueLike of shape (batch, seq_len, n_kv_heads, head_dim) along with optional start_pos and seq_len arguments, and returns a TensorValue of the same shape with rotary positional embeddings applied. RotaryEmbedding supports both interleaved and non-interleaved RoPE variants; a sketch of the underlying math follows the parameter list below.
Parameters:

- dim (int) – The model’s hidden dimension.
- n_heads (int) – The number of attention heads.
- theta (float) – The base for computing RoPE frequencies. Controls the frequency scaling of the sinusoidal components.
- max_seq_len (int) – The maximum sequence length for model input.
- head_dim (int) – The per-head dimension. Defaults to dim // n_heads if None.
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – A pre-computed frequency tensor. Defaults to None.
- interleaved (bool) – Whether to apply RoPE using the interleaved complex representation. Defaults to True.
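To make the transformation concrete, here is a minimal NumPy sketch of the interleaved RoPE math described above. It illustrates the computation only and is not the MAX implementation: the helper name apply_rope, the use of NumPy, and the default theta value are assumptions.

```python
import numpy as np

def apply_rope(x, theta=10000.0, start_pos=0):
    """Apply interleaved RoPE to x of shape (batch, seq_len, n_kv_heads, head_dim).

    Illustrative sketch only; not the MAX graph implementation.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    # Per-pair frequencies: theta ** (-2i / head_dim) for i in [0, head_dim // 2).
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)
    # Rotation angles for absolute positions [start_pos, start_pos + seq_len).
    positions = np.arange(start_pos, start_pos + seq_len)
    angles = np.outer(positions, inv_freq)        # (seq_len, head_dim // 2)
    cos = np.cos(angles)[None, :, None, :]        # broadcast over batch and heads
    sin = np.sin(angles)[None, :, None, :]
    # Interleaved layout: even/odd feature pairs act as complex components.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

x = np.random.randn(1, 8, 4, 64).astype(np.float32)
rotated = apply_rope(x)
assert rotated.shape == x.shape
```

The non-interleaved variant pairs feature i with feature i + head_dim // 2 instead of adjacent even/odd features; only the indexing changes.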
compute_scale()
compute_scale(user_scale=None)
Returns the attention scale factor.
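The page does not specify how the default scale is derived when user_scale is omitted. A common convention in attention implementations is to fall back to 1 / sqrt(head_dim); the sketch below assumes that convention and should not be read as MAX's documented behavior.

```python
import math

# Assumption: the 1 / sqrt(head_dim) fallback is the usual
# scaled-dot-product convention, not behavior confirmed by this page.
def compute_scale(user_scale=None, head_dim=64):
    return user_scale if user_scale is not None else 1.0 / math.sqrt(head_dim)
```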
dim
dim: int
The model’s hidden dimension.
freqs_cis
property freqs_cis: TensorValue
The reshaped frequency tensor used for applying RoPE.
Retrieves the base frequency tensor from freqs_cis_base() and
reshapes it from (max_seq_len * 2, head_dim // 2, 2) to
(max_seq_len * 2, head_dim).
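As a shape-only illustration (with NumPy standing in for MAX tensors), the property's reshape flattens the trailing (cos, sin) pairs into the head dimension:

```python
import numpy as np

max_seq_len, head_dim = 16, 8
# Stand-in for the base table from freqs_cis_base():
# shape (max_seq_len * 2, head_dim // 2, 2) holding (cos, sin) pairs.
base = np.zeros((max_seq_len * 2, head_dim // 2, 2))
flat = base.reshape(max_seq_len * 2, head_dim)
assert flat.shape == (max_seq_len * 2, head_dim)
```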
freqs_cis_base()
freqs_cis_base()
Computes the frequency cosine-sine tensor for Rotary Position Embedding.
Scales the tensor using the theta parameter. Based on
RoFormer: Enhanced Transformer with Rotary Position Embedding.
Returns:

The frequency tensor with shape (max_seq_len * 2, head_dim // 2, 2).

Return type:

TensorValue
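The RoFormer-style construction behind this table can be sketched as follows. NumPy stands in for the MAX graph API here, and everything beyond the documented output shape (the frequency formula, the default theta) is an assumption based on the RoFormer paper rather than this page.

```python
import numpy as np

def freqs_cis_base(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # Per-pair frequencies: theta ** (-2i / head_dim), as in RoFormer.
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)
    # Positions run to max_seq_len * 2, matching the documented shape.
    positions = np.arange(max_seq_len * 2)
    angles = np.outer(positions, inv_freq)  # (max_seq_len * 2, head_dim // 2)
    # Stack cosine and sine components in the trailing dimension.
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)

table = freqs_cis_base(head_dim=64, max_seq_len=128)
assert table.shape == (128 * 2, 64 // 2, 2)
```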
head_dim
head_dim: int
The per-head dimension. Equal to dim // n_heads if not specified.
interleaved
interleaved: bool = True
Whether to apply RoPE using interleaved complex representation.
max_seq_len
max_seq_len: int
The maximum sequence length for model input.
n_heads
n_heads: int
The number of attention heads.
theta
theta: float
The base for computing RoPE frequencies. Controls the frequency scaling of the sinusoidal components.