Python class

DynamicRotaryEmbedding

class max.nn.DynamicRotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True)

Bases: RotaryEmbedding

Applies RoPE with dynamic scaling for long-context inference.

If the current sequence length exceeds the original maximum, the inverse frequency buffer and the corresponding frequency tensor are recomputed. When the sequence fits within the original maximum again, both are reset to their original high-precision versions.

Parameters:

  • dim (int) – The model’s hidden dimension.
  • n_heads (int) – The number of attention heads.
  • theta (float) – The base for computing RoPE frequencies.
  • max_seq_len (int) – The maximum sequence length for model input.
  • head_dim (int | None) – The per-head dimension. If None, defaults to dim // n_heads.
  • _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – A pre-computed frequency tensor. Defaults to None.
  • interleaved (bool) – Whether to apply RoPE using interleaved complex representation. Defaults to True.
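
To make the relationship between dim, n_heads, head_dim, and theta concrete, here is a minimal pure-Python sketch. It does not use the MAX API. The helper names resolve_head_dim and rope_inv_freq are hypothetical, and the inverse-frequency formula is the standard RoPE definition, assumed here rather than taken from the MAX source.

```python
def resolve_head_dim(dim, n_heads, head_dim=None):
    """Hypothetical helper: apply the documented default, dim // n_heads."""
    return dim // n_heads if head_dim is None else head_dim


def rope_inv_freq(head_dim, theta):
    """Standard RoPE inverse frequencies (assumed): theta ** (-2i / head_dim).

    Produces one frequency per pair of channels, so head_dim // 2 values.
    """
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


# Example values chosen for illustration only.
head_dim = resolve_head_dim(dim=4096, n_heads=32)
inv_freq = rope_inv_freq(head_dim, theta=10000.0)
```

A larger theta lowers every frequency after the first, which is the lever that dynamic scaling schemes typically adjust for long sequences.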

freqs_cis_base()

Computes the frequency cosine-sine tensor using the current inv_freq.

Returns:

The frequency tensor with shape (max_seq_len_cached * 2, head_dim // 2, 2).

Return type:

TensorValue
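
The documented return shape can be illustrated with a small pure-Python sketch that builds a table of the same layout from nested lists. This is not the MAX implementation: the function name freqs_cis_table is hypothetical, and storing (cos, sin) in that order along the last axis is an assumption.

```python
import math


def freqs_cis_table(inv_freq, max_seq_len_cached):
    """Hypothetical sketch: nested lists shaped
    (max_seq_len_cached * 2, len(inv_freq), 2).

    Assumption: the last axis holds (cos, sin) of position * frequency.
    """
    table = []
    for pos in range(max_seq_len_cached * 2):
        row = [[math.cos(pos * f), math.sin(pos * f)] for f in inv_freq]
        table.append(row)
    return table


# Illustrative values: head_dim = 8 gives head_dim // 2 = 4 frequencies.
head_dim = 8
inv_freq = [10000.0 ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
table = freqs_cis_table(inv_freq, max_seq_len_cached=8)
```

With these values the table has 16 rows, each holding 4 (cos, sin) pairs, which matches the (max_seq_len_cached * 2, head_dim // 2, 2) shape described above.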

maybe_update_freqs()

maybe_update_freqs(position_ids)

Updates the frequency buffer if the sequence length exceeds the cached maximum.

Reverts to the original high-precision version if the sequence drops back below the original maximum.

Parameters:

position_ids (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The position IDs tensor used to determine the current sequence length.

Return type:

None
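
The update-and-revert behavior described for maybe_update_freqs() can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the MAX implementation: the class DynamicFreqCache is hypothetical, the sequence length is assumed to be max(position_ids) + 1, and the rescaling rule is one common NTK-aware formulation that the actual class may not use.

```python
class DynamicFreqCache:
    """Hypothetical stand-in for the update/revert bookkeeping."""

    def __init__(self, head_dim, theta, max_seq_len):
        self.head_dim = head_dim
        self.theta = theta
        self.original_max_seq_len = max_seq_len
        self.max_seq_len_cached = max_seq_len
        self.original_inv_freq = self._inv_freq(theta)
        self.inv_freq = list(self.original_inv_freq)

    def _inv_freq(self, theta):
        # Standard RoPE inverse frequencies.
        return [theta ** (-(2 * i) / self.head_dim)
                for i in range(self.head_dim // 2)]

    def maybe_update_freqs(self, position_ids):
        # Assumption: sequence length is the largest position ID plus one.
        seq_len = max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:
            # Assumed NTK-aware rule: enlarge theta with the length ratio.
            ratio = seq_len / self.original_max_seq_len
            exponent = self.head_dim / (self.head_dim - 2)
            self.inv_freq = self._inv_freq(self.theta * ratio ** exponent)
            self.max_seq_len_cached = seq_len
        elif (seq_len < self.original_max_seq_len
              and self.max_seq_len_cached > self.original_max_seq_len):
            # Back under the original maximum: restore the original buffer.
            self.inv_freq = list(self.original_inv_freq)
            self.max_seq_len_cached = self.original_max_seq_len


cache = DynamicFreqCache(head_dim=8, theta=10000.0, max_seq_len=16)
cache.maybe_update_freqs(list(range(32)))   # longer than 16: rescale
scaled_last = cache.inv_freq[-1]
cache.maybe_update_freqs(list(range(4)))    # short again: revert
```

Keeping a copy of the original buffer means the revert path restores exactly the values computed at construction instead of recomputing them.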