Python class

Llama3RotaryEmbedding

class max.nn.Llama3RotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)

Bases: RotaryEmbedding

Applies RoPE with Llama3-style frequency scaling for extended context lengths.

Parameters:

  • dim (int) – The model’s hidden dimension.
  • n_heads (int) – The number of attention heads.
  • theta (float) – The base for computing RoPE frequencies.
  • max_seq_len (int) – The maximum sequence length for model input.
  • head_dim (int | None) – The per-head dimension. If None, defaults to dim // n_heads.
  • _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – A pre-computed frequency tensor. Defaults to None.
  • interleaved (bool) – Whether to apply RoPE using an interleaved complex representation. Defaults to True.
  • scaling_params (Llama3RopeScalingParams | None) – The Llama3 RoPE scaling configuration. Defaults to None, in which case standard RoPE is used.
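To make the roles of dim, n_heads, theta, and head_dim concrete, here is a minimal pure-Python sketch of how standard (unscaled) RoPE inverse frequencies are conventionally derived from them. The helper name rope_inv_freqs is hypothetical and is not part of max.nn; it illustrates the usual formula rather than this class's actual implementation.

```python
def rope_inv_freqs(dim, n_heads, theta, head_dim=None):
    """Hypothetical helper: standard RoPE inverse frequencies.

    Mirrors the documented default: head_dim falls back to
    dim // n_heads when it is None.
    """
    if head_dim is None:
        head_dim = dim // n_heads
    # One frequency per pair of channels: theta^(-2i / head_dim).
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


# Illustrative values only, not defaults of the class.
freqs = rope_inv_freqs(dim=4096, n_heads=32, theta=500000.0)
print(len(freqs))  # one entry per channel pair
```

The frequencies start at 1.0 and decay geometrically, so later channel pairs rotate more slowly with position.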

scaling_params

scaling_params: Llama3RopeScalingParams | None = None

The Llama3 RoPE scaling configuration for extended context lengths.
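As a rough illustration of what Llama3-style frequency scaling does, the sketch below follows the publicly released Llama 3.1 reference logic in plain Python. The function name and the parameter names (factor, low_freq_factor, high_freq_factor, orig_max_position) are assumptions made for this example; this page does not list the fields of Llama3RopeScalingParams, so they may be named differently there.

```python
import math


def llama3_scale_freqs(freqs, factor, low_freq_factor,
                       high_freq_factor, orig_max_position):
    """Hypothetical sketch of Llama3-style RoPE frequency scaling."""
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    scaled = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies) are left unchanged.
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths (low frequencies) are divided by factor.
            scaled.append(freq / factor)
        else:
            # In between, blend smoothly between the two regimes.
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


# Illustrative settings, assumed for this example only.
base = [1.0, 2 * math.pi / 4096, 1e-4]
print(llama3_scale_freqs(base, factor=8.0, low_freq_factor=1.0,
                         high_freq_factor=4.0, orig_max_position=8192))
```

Only the slowly rotating components are stretched, which is why this scheme can extend the usable context length while leaving local positional detail largely intact.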