Python class
YarnRotaryEmbedding
class max.nn.YarnRotaryEmbedding(dim, n_heads, theta, max_seq_len, head_dim=None, _freqs_cis=None, interleaved=True, scaling_params=None)
Bases: RotaryEmbedding
Applies generic YaRN (Yet another RoPE extensioN) rotary position embedding.
Provides YaRN scaling with configurable beta_fast, beta_slow, and
scaling factor parameters.
Parameters:
- dim (int) – The model’s hidden dimension.
- n_heads (int) – The number of attention heads.
- theta (float) – The base frequency for rotary embeddings.
- max_seq_len (int) – The maximum sequence length for model input.
- head_dim (int | None) – An optional per-head dimension override. Defaults to None.
- _freqs_cis (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray | None) – Optional precomputed frequencies. Defaults to None.
- interleaved (bool) – Whether to use interleaved complex format. Defaults to True.
- scaling_params (YarnScalingParams | None) – The YaRN scaling parameters. Defaults to None.
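To illustrate how beta_fast, beta_slow, and the scaling factor typically interact, here is a minimal pure-Python sketch of the commonly published YaRN frequency blend. It is not taken from the MAX implementation: the helper name `yarn_inv_freq` and the `original_max_seq_len` argument are assumptions made for this example, and the fields of YarnScalingParams may differ.

```python
import math


def yarn_inv_freq(head_dim, theta, factor, beta_fast, beta_slow, original_max_seq_len):
    """Hypothetical helper: YaRN-blended inverse frequencies, one per rotary pair."""

    def correction_dim(num_rotations):
        # Pair index whose frequency completes `num_rotations` full turns
        # over the original (pre-extension) context length.
        return (
            head_dim
            * math.log(original_max_seq_len / (num_rotations * 2 * math.pi))
        ) / (2 * math.log(theta))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), head_dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for i in range(head_dim // 2):
        extrapolated = 1.0 / theta ** (2 * i / head_dim)  # unscaled RoPE frequency
        interpolated = extrapolated / factor  # position-interpolated frequency
        # Ramp is 0 for pairs below `low` and 1 for pairs above `high`.
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        # High-frequency pairs keep the original frequency;
        # low-frequency pairs are divided by the scaling factor.
        inv_freq.append(extrapolated * (1.0 - ramp) + interpolated * ramp)
    return inv_freq


freqs = yarn_inv_freq(
    head_dim=64, theta=10000.0, factor=4.0,
    beta_fast=32, beta_slow=1, original_max_seq_len=4096,
)
```

In this formulation, pairs between the two correction indices are blended linearly, so the scaling factor only affects the slower-rotating dimensions.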
freqs_cis_base()
Computes the frequency cosine-sine tensor with YaRN scaling applied.
Returns:
The frequency tensor with shape (max_seq_len, head_dim // 2, 2).
Return type:
scaling_params
scaling_params: YarnScalingParams | None = None
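To make the shape documented for freqs_cis_base() concrete, the sketch below builds a plain (unscaled) cosine-sine table as nested Python lists. The helper name `freqs_cis_table` is hypothetical, and the (cos, sin) ordering of the last axis is an assumption based on the "cosine-sine" wording. The real method returns a graph tensor, not lists.

```python
import math


def freqs_cis_table(inv_freq, max_seq_len):
    """Hypothetical helper: nested lists shaped (max_seq_len, len(inv_freq), 2)."""
    return [
        # Assumed layout: last axis holds (cos, sin) of position * frequency.
        [[math.cos(pos * f), math.sin(pos * f)] for f in inv_freq]
        for pos in range(max_seq_len)
    ]


head_dim, theta, max_seq_len = 8, 10000.0, 16
# Unscaled RoPE inverse frequencies, one per pair of head dimensions.
inv_freq = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
table = freqs_cis_table(inv_freq, max_seq_len)
shape = (len(table), len(table[0]), len(table[0][0]))
```

Here `shape` corresponds to (max_seq_len, head_dim // 2, 2): one row per position and one (cos, sin) pair per rotary frequency.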