Python module
rotary_embedding
The rotary position embedding (RoPE) layers used within the model.
DeepseekYarnRopeScalingParams
class max.nn.rotary_embedding.DeepseekYarnRopeScalingParams(scaling_factor: float, original_max_position_embeddings: int, beta_fast: int, beta_slow: int, mscale: float, mscale_all_dim: float)
beta_fast
beta_fast*: int*
Fast interpolation rate.
beta_slow
beta_slow*: int*
Slow interpolation rate.
mscale
mscale*: float*
Scaling factor for middle frequencies.
mscale_all_dim
mscale_all_dim*: float*
Scaling factor applied to all dimensions.
original_max_position_embeddings
original_max_position_embeddings*: int*
Original maximum sequence length during training.
scaling_factor
scaling_factor*: float*
Scaling factor for frequency interpolation.
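The fields above only store values. For reference, the publicly released DeepSeek-V2 code combines mscale and mscale_all_dim into a single multiplier for the cos/sin tables, using a logarithmic helper. The sketch below reproduces that formula in plain Python. It is not MAX's implementation. The helper names yarn_get_mscale and yarn_cos_sin_scale are hypothetical, and whether MAX uses exactly this formula is an assumption.

```python
import math


# Hypothetical helpers (not part of max.nn), following the DeepSeek-V2 reference formula.
def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Logarithmic magnitude correction; no correction when not extending context."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_cos_sin_scale(scaling_factor: float, mscale: float, mscale_all_dim: float) -> float:
    """Multiplier applied to the cos/sin tables in the reference code."""
    return yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim)


# Illustrative values only.
print(yarn_get_mscale(40.0, 1.0))
print(yarn_cos_sin_scale(40.0, 1.0, 1.0))  # 1.0: the two corrections cancel
```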
DeepseekYarnRotaryEmbedding
class max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: DeepseekYarnRopeScalingParams | None = None)
YaRN (Yet another RoPE extensioN method) Rotary Position Embedding layer.
This layer implements YaRN rotary position embeddings which extend RoPE to longer sequences. It computes position-dependent rotation matrices using a combination of linear interpolation and frequency scaling to enable extrapolation beyond the original training context length.
Unlike the parent class, this class does not apply the frequencies to the input tensor. Instead, it returns the frequencies so that they can be applied later in a kernel.
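The following is a minimal NumPy sketch of how YaRN-style frequencies can be computed from the DeepseekYarnRopeScalingParams fields. It assumes the blending scheme of the public DeepSeek-V2 reference code, which works as follows:

- Dimensions that complete more than beta_fast rotations over original_max_position_embeddings keep their original frequency.
- Dimensions that complete fewer than beta_slow rotations are divided by scaling_factor.
- A linear ramp blends the two in between.

All function names below are hypothetical and are not part of max.nn.

```python
import math

import numpy as np


# Hypothetical helpers (not part of max.nn), following the DeepSeek-V2 reference scheme.
def yarn_find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    """Dimension index whose wavelength fits `num_rotations` turns into the original context."""
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def yarn_find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings):
    """Integer index range [low, high] over which the two frequency sets are blended."""
    low = math.floor(yarn_find_correction_dim(beta_fast, dim, base, max_position_embeddings))
    high = math.ceil(yarn_find_correction_dim(beta_slow, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)


def yarn_linear_ramp_mask(low, high, n):
    """0 below `low`, 1 above `high`, linear in between."""
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (np.arange(n, dtype=np.float64) - low) / (high - low)
    return np.clip(ramp, 0.0, 1.0)


def yarn_inv_freq(dim, base, scaling_factor, original_max_position_embeddings, beta_fast, beta_slow):
    exponents = np.arange(0, dim, 2, dtype=np.float64) / dim
    freq_extra = 1.0 / (base ** exponents)  # original RoPE frequencies (extrapolation)
    freq_inter = 1.0 / (scaling_factor * base ** exponents)  # interpolated frequencies
    low, high = yarn_find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
    # Weight 1 keeps the original frequency (high-frequency dims), weight 0 interpolates.
    extra_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
    return freq_inter * (1.0 - extra_weight) + freq_extra * extra_weight


inv_freq = yarn_inv_freq(dim=64, base=10000.0, scaling_factor=40.0,
                         original_max_position_embeddings=4096, beta_fast=32, beta_slow=1)
print(inv_freq.shape)  # (32,)
```

The result holds one inverse frequency per rotated pair. In the reference code, this array is then multiplied by the positions and passed through cos/sin to build the tables a kernel consumes.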
scaling_params
scaling_params*: DeepseekYarnRopeScalingParams | None* = None
LinearScalingParams
class max.nn.rotary_embedding.LinearScalingParams(factor: float)
factor
factor*: float*
Main scaling factor for the frequency components of RoPE.
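A single linear factor is commonly applied as position interpolation. Every RoPE frequency is divided by factor, which is equivalent to compressing positions by factor. The sketch below assumes that reading of LinearScalingParams.factor. It is plain NumPy, and the helper names are hypothetical.

```python
import numpy as np


# Hypothetical helpers (not part of max.nn).
def rope_inv_freq(head_dim: int, theta: float) -> np.ndarray:
    """Standard RoPE inverse frequencies theta ** (-2i / head_dim)."""
    return 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))


def linear_scaled_inv_freq(head_dim: int, theta: float, factor: float) -> np.ndarray:
    """Position interpolation: divide every frequency by `factor`."""
    return rope_inv_freq(head_dim, theta) / factor


# With factor=4, position 8 gets the same angles that position 2 had before scaling.
scaled = np.outer([8.0], linear_scaled_inv_freq(8, 10000.0, 4.0))
original = np.outer([2.0], rope_inv_freq(8, 10000.0))
print(np.allclose(scaled, original))  # True
```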
Llama3RopeScalingParams
class max.nn.rotary_embedding.Llama3RopeScalingParams(factor: float, low_freq_factor: float, high_freq_factor: float, orig_max_position: int)
factor
factor*: float*
Main scaling factor for the frequency components of RoPE.
high_freq_factor
high_freq_factor*: float*
Factor to scale the high-frequency components of RoPE.
low_freq_factor
low_freq_factor*: float*
Factor to scale the low-frequency components of RoPE.
orig_max_position
orig_max_position*: int*
The original maximum position length supported by the model.
Llama3RotaryEmbedding
class max.nn.rotary_embedding.Llama3RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: Llama3RopeScalingParams | None = None)
RotaryEmbedding for Llama 3 that takes RoPE scaling into account.
scaling_params
scaling_params*: Llama3RopeScalingParams | None* = None
Scaling parameters that enable Llama to function with a longer context length.
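For reference, the publicly released Llama 3.1 code rescales RoPE frequencies by wavelength:

- Short wavelengths are left unchanged.
- Wavelengths longer than orig_max_position / low_freq_factor are divided by factor.
- The band in between is smoothly interpolated.

The NumPy sketch below follows that scheme using the Llama3RopeScalingParams field names. That Llama3RotaryEmbedding matches it exactly is an assumption, and the function name is hypothetical. The example values are the ones published for Llama 3.1.

```python
import math

import numpy as np


# Hypothetical helper (not part of max.nn), following the Llama 3.1 reference scheme.
def llama3_scale_inv_freq(
    inv_freq: np.ndarray,
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> np.ndarray:
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Weight in [0, 1] for wavelengths between the two thresholds.
    smooth = (orig_max_position / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    blended = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    return np.where(
        wavelen < high_freq_wavelen,
        inv_freq,  # high frequencies: unchanged
        np.where(wavelen > low_freq_wavelen, inv_freq / factor, blended),  # low: divided by factor
    )


head_dim, theta = 128, 500000.0
inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))
scaled = llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                               high_freq_factor=4.0, orig_max_position=8192)
```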
OptimizedRotaryEmbedding
class max.nn.rotary_embedding.OptimizedRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)
Optimized version of RotaryEmbedding that uses a 2D frequency tensor representation.
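The 2D layout is not spelled out here. One plausible reading, assumed in this sketch, is that the (rows, head_dim // 2, 2) cos/sin tensor described under RotaryEmbedding.freqs_cis_base is flattened into (rows, head_dim). Each row then alternates cos and sin values. The NumPy sketch below shows only that reshape, and the function name is hypothetical.

```python
import numpy as np


# Hypothetical helper (not part of max.nn): assumed 3D -> 2D flattening.
def flatten_freqs_cis(freqs_cis_3d: np.ndarray) -> np.ndarray:
    rows, half, pair = freqs_cis_3d.shape
    if pair != 2:
        raise ValueError("last axis must hold (cos, sin) pairs")
    # Row layout becomes [cos_0, sin_0, cos_1, sin_1, ...].
    return freqs_cis_3d.reshape(rows, half * 2)


angles = np.outer(np.arange(4.0), np.array([1.0, 0.5]))
freqs_3d = np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # (4, 2, 2)
freqs_2d = flatten_freqs_cis(freqs_3d)
print(freqs_2d.shape)  # (4, 4)
```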
freqs_cis
property freqs_cis
RotaryEmbedding
class max.nn.rotary_embedding.RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)
RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
dim
dim*: int*
freqs_cis
property freqs_cis*: TensorValue*
freqs_cis_base()
freqs_cis_base() → TensorValue
Computes the frequency tensor for complex exponentials (cis) for max_seq_len * 2 positions, scaled by the theta parameter. This tensor is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len * 2, dim // (2 * n_heads), 2).
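The NumPy sketch below reproduces the computation described above and yields the documented shape (max_seq_len * 2, dim // (2 * n_heads), 2). It rests on three assumptions: the per-head dimension is dim // n_heads, the frequencies are the standard RoPE values theta ** (-2i / head_dim), and the last axis holds cos at index 0 and sin at index 1. It is illustrative only, not the MAX implementation, and the function name is hypothetical.

```python
import numpy as np


# Hypothetical standalone function mirroring RotaryEmbedding.freqs_cis_base (not MAX code).
def freqs_cis_base_np(dim: int, n_heads: int, theta: float, max_seq_len: int) -> np.ndarray:
    head_dim = dim // n_heads
    # Inverse frequencies theta ** (-2i / head_dim), one per rotated pair.
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))
    positions = np.arange(max_seq_len * 2, dtype=np.float64)
    angles = np.outer(positions, inv_freq)  # (max_seq_len * 2, head_dim // 2)
    # Assumed layout: index 0 = cos, index 1 = sin.
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)


freqs = freqs_cis_base_np(dim=64, n_heads=4, theta=10000.0, max_seq_len=8)
print(freqs.shape)  # (16, 8, 2)
```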
interleaved
interleaved*: bool* = True
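The interleaved flag is not documented here. By common convention, the two layouts pair elements differently:

- An interleaved layout rotates adjacent pairs (x[0], x[1]), (x[2], x[3]), and so on.
- A non-interleaved (half-split) layout pairs x[i] with x[i + head_dim // 2].

The NumPy sketch below assumes that convention. apply_rope is a hypothetical helper, not a max.nn function.

```python
import numpy as np


# Hypothetical helper (not part of max.nn); the layout convention is an assumption.
def apply_rope(x: np.ndarray, cos: np.ndarray, sin: np.ndarray, interleaved: bool = True) -> np.ndarray:
    """Rotate x of shape (seq_len, head_dim) by angles given as cos/sin of shape (seq_len, head_dim // 2)."""
    if interleaved:
        # Pairs are adjacent: (x[0], x[1]), (x[2], x[3]), ...
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        out = np.empty_like(x)
        out[..., 0::2] = x_even * cos - x_odd * sin
        out[..., 1::2] = x_even * sin + x_odd * cos
        return out
    # Half-split: x[i] is paired with x[i + head_dim // 2].
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


seq_len, head_dim = 4, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
angles = np.outer(np.arange(seq_len), inv_freq)
x = np.random.default_rng(0).standard_normal((seq_len, head_dim))
y = apply_rope(x, np.cos(angles), np.sin(angles), interleaved=True)
```

Both layouts apply the same rotation and differ only by a fixed permutation of the head dimension. The setting therefore has to match how the checkpoint's query and key projections were laid out.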
max_seq_len
max_seq_len*: int*
The maximum sequence length for the model’s input.
n_heads
n_heads*: int*
theta
theta*: float*
Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.