Python module
rotary_embedding
The rotary position embedding (RoPE) used within the model.
DeepseekYarnRopeScalingParams
class max.nn.rotary_embedding.DeepseekYarnRopeScalingParams(scaling_factor: float, original_max_position_embeddings: int, beta_fast: int, beta_slow: int, mscale: float, mscale_all_dim: float)
Parameters:
beta_fast
beta_fast: int
Fast interpolation rate.
beta_slow
beta_slow: int
Slow interpolation rate.
mscale
mscale: float
Scaling factor for middle frequencies.
mscale_all_dim
mscale_all_dim: float
Scaling factor applied to all dimensions.
original_max_position_embeddings
original_max_position_embeddings: int
Original maximum sequence length during training.
scaling_factor
scaling_factor: float
Scaling factor for frequency interpolation.
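The two magnitude-scaling attributes above combine into a single attention-scale adjustment. A minimal pure-Python sketch of the commonly used YaRN mscale formula (illustrative only, not the MAX API; the function name and example values are assumptions):

```python
import math

def yarn_get_mscale(scale: float, mscale: float) -> float:
    # Standard YaRN magnitude scaling: identity below scale 1.0,
    # otherwise a log-ramped adjustment.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Illustrative values; the real ones come from the model config.
scaling_factor = 40.0
mscale = 1.0
mscale_all_dim = 1.0

# The two mscale terms combine into one softmax-scale adjustment.
scale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
    scaling_factor, mscale_all_dim
)
```

When `mscale` equals `mscale_all_dim`, the two terms cancel and the attention scale is left unchanged.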
DeepseekYarnRotaryEmbedding
class max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)
Deepseek’s YaRN (Yet another RoPE eNhancement) Rotary Position Embedding layer.
Unlike Llama3RotaryEmbedding, the dim argument here is the rope dimension of the model, not the hidden dimension.
Parameters:
- dim (int)
- n_heads (int)
- theta (float)
- max_seq_len (int)
- device (DeviceRef)
- head_dim (int | None)
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- interleaved (bool)
- scaling_params (DeepseekYarnRopeScalingParams | None)
compute_scale()
compute_scale(user_scale=None)
freqs_cis_base()
freqs_cis_base()
Computes the frequency tensor for complex exponentials (cis) for a given seq_len. The tensor is scaled by the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See 'RoFormer: Enhanced Transformer with Rotary Position Embedding' (arxiv.org/pdf/2104.09864).
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len, rope_dim // 2, 2)
Return type:
TensorValue
scaling_params
scaling_params: DeepseekYarnRopeScalingParams | None = None
DynamicRotaryEmbedding
class max.nn.rotary_embedding.DynamicRotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True)
RotaryEmbedding with dynamic scaling support for long-context inference.
Dynamically updates the inv_freq and corresponding freqs_cis buffer if the current sequence length exceeds the original max, or resets to the original high-precision version for short sequences.
Parameters:
- dim (int)
- n_heads (int)
- theta (float)
- max_seq_len (int)
- device (DeviceRef)
- head_dim (int | None)
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- interleaved (bool)
freqs_cis_base()
freqs_cis_base()
Computes freqs_cis dynamically using the current self.inv_freq.
Return type:
TensorValue
maybe_update_freqs()
maybe_update_freqs(position_ids)
Update freqs_cis if the sequence exceeds max_seq_len_cached, or revert to the original version if back below the threshold.
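The update/revert behavior described above can be sketched in pure Python. This is an illustrative NTK-style dynamic-scaling sketch, not the MAX implementation; the class name, the exact base-rescaling formula, and the attribute names are assumptions:

```python
import math

class DynamicFreqs:
    """Sketch of dynamic RoPE scaling: rescale inv_freq for long sequences,
    restore the original high-precision version for short ones."""

    def __init__(self, dim: int, theta: float, max_seq_len: int):
        self.dim, self.theta, self.max_seq_len = dim, theta, max_seq_len
        self.original_inv_freq = self._inv_freq(theta)
        self.inv_freq = self.original_inv_freq
        self.max_seq_len_cached = max_seq_len

    def _inv_freq(self, base: float) -> list[float]:
        # Standard RoPE frequency schedule: base^(-2i/dim).
        return [base ** (-2 * i / self.dim) for i in range(self.dim // 2)]

    def maybe_update_freqs(self, seq_len: int) -> None:
        if seq_len > self.max_seq_len_cached:
            # NTK-style dynamic scaling: grow the effective base with
            # sequence length so angles stay in the trained range.
            scale = seq_len / self.max_seq_len
            base = self.theta * (scale ** (self.dim / (self.dim - 2)))
            self.inv_freq = self._inv_freq(base)
            self.max_seq_len_cached = seq_len
        elif seq_len <= self.max_seq_len and self.max_seq_len_cached > self.max_seq_len:
            # Back below the original threshold: revert to the original.
            self.inv_freq = self.original_inv_freq
            self.max_seq_len_cached = self.max_seq_len
```

A sequence longer than the cached maximum triggers a rescale; a subsequent short sequence restores the original frequencies.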
LinearScalingParams
class max.nn.rotary_embedding.LinearScalingParams(factor: float)
Parameters:
- factor (float)
factor
factor: float
Main scaling factor for the frequency components of the rope.
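Linear (position-interpolation) scaling divides position indices by `factor` before the angles are computed, squeezing a longer context into the trained position range. A minimal sketch of this scheme (illustrative, not the MAX API):

```python
def rope_angles(positions, dim: int, theta: float, factor: float = 1.0):
    # inv_freq[i] = theta^(-2i/dim), the standard RoPE schedule.
    inv_freq = [theta ** (-2 * i / dim) for i in range(dim // 2)]
    # Linear scaling divides positions by `factor` before forming angles.
    return [[(p / factor) * f for f in inv_freq] for p in positions]

unscaled = rope_angles([4], dim=8, theta=10000.0)
scaled = rope_angles([8], dim=8, theta=10000.0, factor=2.0)
# Position 8 with factor 2 lands on the same angles as unscaled position 4.
```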
Llama3RopeScalingParams
class max.nn.rotary_embedding.Llama3RopeScalingParams(factor: float, low_freq_factor: float, high_freq_factor: float, orig_max_position: int)
factor
factor: float
Main scaling factor for the frequency components of the rope.
high_freq_factor
high_freq_factor: float
Factor to scale the high frequency components of the rope.
low_freq_factor
low_freq_factor: float
Factor to scale the low frequency components of the rope.
orig_max_position
orig_max_position: int
The original maximum position length supported by the model.
Llama3RotaryEmbedding
class max.nn.rotary_embedding.Llama3RotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)
RotaryEmbedding for Llama3 that takes rope scaling into account.
Parameters:
- dim (int)
- n_heads (int)
- theta (float)
- max_seq_len (int)
- device (DeviceRef)
- head_dim (int | None)
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- interleaved (bool)
- scaling_params (Llama3RopeScalingParams | None)
scaling_params
scaling_params: Llama3RopeScalingParams | None = None
Scaling parameters that enable Llama to function with a longer context length.
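The Llama 3.x scaling scheme adjusts each frequency based on its wavelength: long wavelengths are fully scaled, short wavelengths are kept as-is, and the band in between is smoothly interpolated. A pure-Python sketch of this widely used formula (illustrative only; the function name is an assumption, and the parameters mirror Llama3RopeScalingParams):

```python
import math

def llama3_scale_inv_freq(inv_freq, factor, low_freq_factor,
                          high_freq_factor, orig_max_position):
    # Wavelengths above low_freq_wavelen are fully scaled; below
    # high_freq_wavelen they are untouched; in between, interpolate.
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    out = []
    for f in inv_freq:
        wavelen = 2 * math.pi / f
        if wavelen > low_freq_wavelen:        # low frequency: scale down
            out.append(f / factor)
        elif wavelen < high_freq_wavelen:     # high frequency: keep as-is
            out.append(f)
        else:                                 # mid band: smooth blend
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * f / factor + smooth * f)
    return out
```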
LongRoPERotaryEmbedding
class max.nn.rotary_embedding.LongRoPERotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)
Rotary position embedding with LongRoPE scaling for Phi-3.5 models.
Initialize LongRoPE rotary embeddings.
Parameters:
- dim (int) – Model dimension
- n_heads (int) – Number of attention heads
- theta (float) – Base for computing frequencies (usually 10000.0)
- max_seq_len (int) – Maximum sequence length
- device (DeviceRef) – Device to place tensors on
- head_dim (int | None) – Head dimension (if None, computed as dim // n_heads)
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None) – Pre-computed frequency tensor (optional)
- interleaved (bool) – Whether to use interleaved RoPE weights
- scaling_params (LongRoPEScalingParams | None) – LongRoPE scaling parameters
compute_scale()
compute_scale(user_scale=None)
Compute attention scale with LongRoPE adjustment.
freqs_cis_base()
freqs_cis_base()
Computes the frequency tensor for complex exponentials (cis) with LongRoPE scaling. Creates a “stitched” table where:
- Positions 0 to original_max_position use short_factor
- Positions from original_max_position onwards use long_factor
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len * 2, head_dim // 2, 2)
Return type:
TensorValue
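The "stitched" table described above can be sketched in pure Python: each dimension's base frequency is divided by a per-dimension rescale factor, and which factor list applies depends on whether the position falls inside the original trained window. Illustrative only, not the MAX API; the function name and example factor values are assumptions:

```python
def longrope_inv_freq(position: int, dim: int, theta: float,
                      short_factor, long_factor, original_max_position: int):
    # Positions inside the trained window use short_factor; positions
    # beyond it use long_factor. Each list has dim // 2 entries.
    factors = short_factor if position < original_max_position else long_factor
    return [theta ** (-2 * i / dim) / factors[i] for i in range(dim // 2)]

short = [1.0, 1.0]   # illustrative per-dimension factors (close to 1.0)
long = [4.0, 8.0]    # illustrative factors for long sequences
inside = longrope_inv_freq(10, dim=4, theta=10000.0, short_factor=short,
                           long_factor=long, original_max_position=4096)
beyond = longrope_inv_freq(5000, dim=4, theta=10000.0, short_factor=short,
                           long_factor=long, original_max_position=4096)
```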
LongRoPEScalingParams
class max.nn.rotary_embedding.LongRoPEScalingParams(short_factor, long_factor, original_max_position, max_position_embeddings)
Parameters for LongRoPE scaling as used in Phi-3.5 models.
Parameters:
long_factor
Scaling factors for long sequences (can be much larger).
max_position_embeddings
max_position_embeddings: int
Current max position embeddings after scaling.
original_max_position
original_max_position: int
Original max position embeddings the model was trained with.
short_factor
Scaling factors for short sequences (typically close to 1.0).
RotaryEmbedding
class max.nn.rotary_embedding.RotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True)
RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
Parameters:
- dim (int)
- n_heads (int)
- theta (float)
- max_seq_len (int)
- device (DeviceRef)
- head_dim (int | None)
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None)
- interleaved (bool)
compute_scale()
compute_scale(user_scale=None)
device
device: DeviceRef
dim
dim: int
freqs_cis
property freqs_cis: TensorValue
freqs_cis_base()
freqs_cis_base()
Computes the frequency tensor for complex exponentials (cis) for a given seq_len. The tensor is scaled by the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See 'RoFormer: Enhanced Transformer with Rotary Position Embedding' (arxiv.org/pdf/2104.09864).
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len * 2, head_dim // 2, 2)
Return type:
TensorValue
head_dim
head_dim: int
Computed as dim // n_heads if not specified in the config.
interleaved
interleaved: bool = True
max_seq_len
max_seq_len: int
The maximum sequence length of the model's input.
n_heads
n_heads: int
theta
theta: float
Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
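The freqs_cis_base computation described above can be sketched in pure Python: build the theta-based frequency schedule, then pair cos/sin of position × frequency for every position ("cis" = cos + i·sin). Illustrative only, not the MAX implementation (which returns a TensorValue of length max_seq_len * 2):

```python
import math

def freqs_cis_base(seq_len: int, dim: int, theta: float = 10000.0):
    # inv_freq[i] = theta^(-2i/dim), the standard RoPE frequency schedule.
    inv_freq = [theta ** (-2 * i / dim) for i in range(dim // 2)]
    # Each entry pairs (cos, sin) of position * frequency.
    return [[[math.cos(p * f), math.sin(p * f)] for f in inv_freq]
            for p in range(seq_len)]

table = freqs_cis_base(seq_len=4, dim=8)
# Shape is (seq_len, dim // 2, 2); position 0 is all (1.0, 0.0) pairs.
```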
YarnRotaryEmbedding
class max.nn.rotary_embedding.YarnRotaryEmbedding(dim, n_heads, theta, max_seq_len, device, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)
Generic YaRN (Yet another RoPE eNhancement) Rotary Position Embedding layer.
This implementation provides YARN scaling for models that require it, with configurable parameters for beta_fast, beta_slow, and scaling factor.
Initialize YarnRotaryEmbedding.
Parameters:
- dim (int) – The dimension of the rotary embedding (usually hidden_size).
- n_heads (int) – Number of attention heads.
- theta (float) – Base frequency for rotary embeddings.
- max_seq_len (int) – Maximum sequence length.
- device (DeviceRef) – Device to place the embeddings on.
- head_dim (int | None) – Optional head dimension override.
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]] | None) – Optional precomputed frequencies.
- interleaved (bool) – Whether to use interleaved complex format.
- scaling_params (YarnScalingParams | None) – YARN scaling parameters.
freqs_cis_base()
freqs_cis_base()
Computes the frequency tensor for complex exponentials (cis) with YARN scaling applied.
Return type:
TensorValue
scaling_params
scaling_params: YarnScalingParams | None = None
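The beta_fast/beta_slow parameters control which dimensions are interpolated versus extrapolated via a linear ramp. A pure-Python sketch of the standard YaRN frequency blend (illustrative only, not the MAX implementation; function names follow the common reference code):

```python
import math

def yarn_find_correction_dim(num_rotations, dim, base, max_position):
    # Dimension whose wavelength completes `num_rotations` full turns
    # over the trained context length.
    return (dim * math.log(max_position / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

def yarn_inv_freq(dim, base, factor, beta_fast, beta_slow,
                  original_max_position):
    low = math.floor(
        yarn_find_correction_dim(beta_fast, dim, base, original_max_position))
    high = math.ceil(
        yarn_find_correction_dim(beta_slow, dim, base, original_max_position))
    low, high = max(low, 0), min(high, dim - 1)
    out = []
    for i in range(dim // 2):
        f = base ** (-2 * i / dim)
        # Ramp from 0 at dim `low` to 1 at dim `high`, clamped outside.
        ramp = min(max((i - low) / max(high - low, 1), 0.0), 1.0)
        # ramp == 0: high-frequency dims keep f (extrapolation);
        # ramp == 1: low-frequency dims use f / factor (interpolation).
        out.append(f / factor * ramp + f * (1 - ramp))
    return out
```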
YarnScalingParams
class max.nn.rotary_embedding.YarnScalingParams(factor: float, beta_fast: float, beta_slow: float, original_max_position_embeddings: int, truncate: bool)
Parameters:
beta_fast
beta_fast: float
Yarn parameter for fast frequencies.
beta_slow
beta_slow: float
Yarn parameter for slow frequencies.
factor
factor: float
Main scaling factor for the frequency components of the rope.
original_max_position_embeddings
original_max_position_embeddings: int
The original maximum position length supported by the model.
truncate
truncate: bool
Whether to truncate the frequencies or not.