
LongRoPERotaryEmbedding

class max.nn.LongRoPERotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)

Bases: RotaryEmbedding

Applies rotary position embeddings (RoPE) with LongRoPE scaling, as used by Phi-3.5 models.

Uses a stitched frequency table: positions below original_max_position use short_factor scaling, and positions from original_max_position onward use long_factor scaling.

Parameters:

  • dim (int) – The model’s hidden dimension.
  • n_heads (int) – The number of attention heads.
  • theta (float) – The base for computing frequencies. A common value is 10000.0.
  • max_seq_len (int) – The maximum sequence length for model input.
  • head_dim (int | None) – The per-head dimension. Defaults to None, in which case dim // n_heads is used.
  • _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – A pre-computed frequency tensor. Defaults to None.
  • interleaved (bool) – Whether to apply RoPE using interleaved complex representation. Defaults to True.
  • scaling_params (LongRoPEScalingParams | None) – The LongRoPE scaling configuration. Defaults to None, in which case standard RoPE is used.
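The rotation that theta, head_dim, and interleaved control can be sketched in plain Python. This is a minimal, hypothetical illustration of standard RoPE (the scaling_params=None case) applied to a single head vector; apply_rope is not part of max.nn, and the pairing conventions (adjacent elements when interleaved=True, first-half/second-half elements otherwise) are assumptions based on common RoPE implementations, not a description of MAX's kernels.

```python
import math
from typing import List, Sequence


def apply_rope(x: Sequence[float], pos: int, theta: float = 10000.0,
               interleaved: bool = True) -> List[float]:
    """Hypothetical helper: rotate one head vector x at position pos."""
    head_dim = len(x)
    if head_dim % 2:
        raise ValueError("head_dim must be even")
    half = head_dim // 2
    out = list(x)
    for i in range(half):
        # Standard RoPE frequency for pair i: theta ** (-2i / head_dim).
        angle = pos * theta ** (-(2 * i) / head_dim)
        c, s = math.cos(angle), math.sin(angle)
        # Assumed pairing: adjacent elements when interleaved, else split halves.
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        # 2-D rotation of the (x[a], x[b]) pair by angle.
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out


q = [1.0, 2.0, 3.0, 4.0]
print(apply_rope(q, pos=3, interleaved=True))
print(apply_rope(q, pos=3, interleaved=False))
```

Because each pair is rotated rather than scaled, the vector norm is preserved; the two layouts differ only in which elements are paired.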

compute_scale(user_scale=None)

Returns the attention scale factor with LongRoPE adjustment.

Applies a logarithmic attention factor when the context length exceeds the original training maximum.

Parameters:

user_scale (float | None) – A custom scale factor. Defaults to None, in which case the scale is computed from head_dim and the LongRoPE attention factor.

Returns:

The attention scale factor.

Return type:

float
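A minimal sketch of the documented behavior, assuming the attention factor follows the Hugging Face Phi-3 LongRoPE formula sqrt(1 + ln(s) / ln(original_max_position)) with s = max_seq_len / original_max_position, and that it multiplies the usual 1/sqrt(head_dim) scale. The function name, the use of max_seq_len as the context length, and returning user_scale unchanged are all assumptions, not confirmed MAX behavior.

```python
import math
from typing import Optional


def longrope_attention_scale(head_dim: int, max_seq_len: int,
                             original_max_position: int,
                             user_scale: Optional[float] = None) -> float:
    """Hypothetical sketch of a LongRoPE-adjusted attention scale."""
    if user_scale is not None:
        # Assumption: an explicit scale is used as-is.
        return user_scale
    base_scale = 1.0 / math.sqrt(head_dim)
    ratio = max_seq_len / original_max_position
    if ratio <= 1.0:
        # Context fits the original training window: no adjustment.
        return base_scale
    # Logarithmic attention factor (assumed Phi-3 style formula).
    attention_factor = math.sqrt(
        1.0 + math.log(ratio) / math.log(original_max_position))
    return base_scale * attention_factor


# Illustrative values only.
print(longrope_attention_scale(96, 4096, 4096))
print(longrope_attention_scale(96, 131072, 4096))
```

With a 32x extension over a 4096-token window, the factor is sqrt(1 + ln 32 / ln 4096) = sqrt(17/12), roughly 1.19.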

freqs_cis_base()

Computes the cosine/sine frequency tensor with LongRoPE scaling.

Creates a stitched table where:

  • Positions below original_max_position use short_factor scaling.
  • Positions from original_max_position onward use long_factor scaling.

Returns:

The frequency tensor with shape (max_seq_len * 2, head_dim // 2, 2).

Return type:

TensorValue
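The stitched table can be sketched in plain Python. This hypothetical helper assumes that short_factor and long_factor each hold head_dim // 2 values that divide the standard RoPE inverse frequencies (as in the Hugging Face Phi-3 LongRoPE implementation) and that the last axis stores (cos, sin). The helper and argument names are illustrative, not the MAX API; the result is a nested list with the documented shape (max_seq_len * 2, head_dim // 2, 2).

```python
import math
from typing import List, Sequence


def longrope_freqs_cis(head_dim: int, theta: float, max_seq_len: int,
                       original_max_position: int,
                       short_factor: Sequence[float],
                       long_factor: Sequence[float]) -> List[List[List[float]]]:
    """Hypothetical sketch of a stitched LongRoPE cos/sin table."""
    half = head_dim // 2
    if len(short_factor) != half or len(long_factor) != half:
        raise ValueError("factors need head_dim // 2 entries")
    # Standard RoPE inverse frequencies: theta ** (-2i / head_dim).
    base_inv_freq = [theta ** (-(2 * i) / head_dim) for i in range(half)]
    table = []
    for pos in range(max_seq_len * 2):
        # Stitch point: short_factor below original_max_position,
        # long_factor from original_max_position onward.
        factors = short_factor if pos < original_max_position else long_factor
        row = []
        for inv, f in zip(base_inv_freq, factors):
            angle = pos * inv / f  # each factor divides the frequency
            row.append([math.cos(angle), math.sin(angle)])
        table.append(row)
    return table


table = longrope_freqs_cis(8, 10000.0, 16, 10, [1.0] * 4, [2.0] * 4)
print(len(table), len(table[0]), len(table[0][0]))
```

Dividing by a factor greater than 1 lowers that frequency, so positions past the stitch point rotate more slowly in those dimensions.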