Python module

rotary_embedding

Rotary position embedding (RoPE) layers used within the model.

Llama3RopeScalingParams

class max.nn.rotary_embedding.Llama3RopeScalingParams(factor: float, low_freq_factor: float, high_freq_factor: float, orig_max_position: int)

factor

factor*: float*

Main scaling factor for the frequency components of the RoPE.

high_freq_factor

high_freq_factor*: float*

Factor to scale the high-frequency components of the RoPE.

low_freq_factor

low_freq_factor*: float*

Factor to scale the low-frequency components of the RoPE.

orig_max_position

orig_max_position*: int*

The original maximum position length supported by the model.
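As a rough illustration of how these four fields can interact, the sketch below rescales a list of RoPE frequencies band by band, following the publicly described Llama 3.1 scheme. It is an assumption that this class is applied the same way; `ScalingParams` and `scale_frequencies` are hypothetical names that are not part of `max.nn.rotary_embedding`, and the numeric values are illustrative only.

```python
import math
from dataclasses import dataclass


@dataclass
class ScalingParams:
    """Hypothetical stand-in that mirrors the four documented fields."""

    factor: float
    low_freq_factor: float
    high_freq_factor: float
    orig_max_position: int


def scale_frequencies(freqs: list[float], p: ScalingParams) -> list[float]:
    """Rescale RoPE frequencies band by band (assumed Llama 3.1-style rule)."""
    low_freq_wavelen = p.orig_max_position / p.low_freq_factor
    high_freq_wavelen = p.orig_max_position / p.high_freq_factor
    scaled = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: left unchanged.
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: divided by the main factor.
            scaled.append(freq / p.factor)
        else:
            # Transition band: blend linearly between the two regimes.
            smooth = (p.orig_max_position / wavelen - p.low_freq_factor) / (
                p.high_freq_factor - p.low_freq_factor
            )
            scaled.append((1 - smooth) * freq / p.factor + smooth * freq)
    return scaled


# Illustrative values only; they are not read from any model configuration.
params = ScalingParams(
    factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, orig_max_position=8192
)
head_dim = 128
theta = 500000.0
base = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
scaled = scale_frequencies(base, params)
```

Under this rule, components whose wavelength is much shorter than the original context window keep their frequency, while components with longer wavelengths are slowed down by `factor`, which is what would let positions beyond `orig_max_position` remain distinguishable.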

Llama3RotaryEmbedding

class max.nn.rotary_embedding.Llama3RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: Llama3RopeScalingParams | None = None)

RotaryEmbedding for Llama 3 that takes RoPE scaling into account.

scaling_params

scaling_params*: Llama3RopeScalingParams | None* = None

Scaling parameters that enable Llama to function with a longer context length.

OptimizedRotaryEmbedding

class max.nn.rotary_embedding.OptimizedRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)

Optimized version of RotaryEmbedding using a 2D frequency tensor representation.

freqs_cis

property freqs_cis

RotaryEmbedding

class max.nn.rotary_embedding.RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)

RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.

dim

dim*: int*

freqs_cis

property freqs_cis*: TensorValue*

freqs_cis_base()

freqs_cis_base() → TensorValue

Computes the frequency tensor for complex exponentials (cis) for a given seq_len. The tensor is scaled with the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).

  • Returns:

    The frequency tensor for complex exponentials, with shape (max_seq_len * 2, dim // (2 * n_heads), 2).
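To make the documented shape concrete, here is a minimal pure-Python sketch of such a table. It assumes the standard RoFormer formulation, where pair `i` rotates at frequency `theta ** (-2 * i / head_dim)`, and it assumes the trailing dimension holds the real and imaginary parts (cosine, sine). `freqs_cis_table` is a hypothetical helper, not the library's method, and it returns nested lists rather than a `TensorValue`.

```python
import math


def freqs_cis_table(
    dim: int, n_heads: int, theta: float, max_seq_len: int
) -> list[list[list[float]]]:
    """Build a (max_seq_len * 2, dim // (2 * n_heads), 2) table of [cos, sin]."""
    head_dim = dim // n_heads
    n_pairs = dim // (2 * n_heads)
    # One rotation frequency per pair of channels within a head.
    inv_freqs = [1.0 / theta ** (2 * i / head_dim) for i in range(n_pairs)]
    table = []
    for position in range(max_seq_len * 2):
        row = []
        for freq in inv_freqs:
            angle = position * freq
            # cis(angle) = cos(angle) + i * sin(angle), stored as [real, imag].
            row.append([math.cos(angle), math.sin(angle)])
        table.append(row)
    return table


# Illustrative sizes only.
table = freqs_cis_table(dim=64, n_heads=4, theta=10000.0, max_seq_len=8)
shape = (len(table), len(table[0]), len(table[0][0]))
```

With these sizes, `shape` matches the documented formula: `max_seq_len * 2` rows, `dim // (2 * n_heads)` frequency pairs per row, and two values per pair.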

interleaved

interleaved*: bool* = True

max_seq_len

max_seq_len*: int*

The maximum sequence length for the model’s input.

n_heads

n_heads*: int*

theta

theta*: float*

Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
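One way to see the effect of `theta` is to compare the wavelengths (in token positions) that the sinusoidal components span for two values. The sketch below assumes the standard RoFormer frequency formula `theta ** (-2 * i / head_dim)`; the helper name `wavelengths` and both `theta` values are illustrative and not taken from this module.

```python
import math


def wavelengths(theta: float, head_dim: int) -> list[float]:
    """Wavelength, in positions, of each sinusoidal component."""
    return [2 * math.pi * theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


# Two illustrative settings with the same head size.
small_theta = wavelengths(10000.0, 128)
large_theta = wavelengths(500000.0, 128)
```

The fastest component has the same wavelength in both cases (one full rotation every `2 * pi` positions), while a larger `theta` stretches the slowest components over many more positions.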