
Python class

RotaryEmbedding


class max.nn.rope.RotaryEmbedding(weight: max.tensor.Tensor)

Parameters:

weight (Tensor)

dim

property dim: int

forward()

forward(x, start_pos=0)

Applies rotary positional embeddings (RoPE) to x.

seq_len is inferred from the shape of x.

Parameters:

  • x (Tensor) – Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim). x is interpreted as a complex-valued tensor: the head_dim dimension holds alternating (real, imaginary) pairs.
  • start_pos (int | str | Dim | integer[Any]) – Position of the first element of x within the full sequence. Defaults to 0.

Returns:

The input activation tensor with rotary positional embeddings applied. It has the same shape as x.

Return type:

Tensor
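The description of forward() can be illustrated with a small NumPy reference. This is a sketch under stated assumptions, not the MAX implementation. It recomputes the rotation angles inline with an assumed base theta = 10000 (the real class reads them from its weight tensor). It treats each adjacent (real, imaginary) pair of head_dim as one complex number and rotates it by an angle proportional to its absolute position start_pos + t. The function name rope_reference is hypothetical.

```python
import numpy as np


def rope_reference(x: np.ndarray, start_pos: int = 0, theta: float = 10000.0) -> np.ndarray:
    """Hypothetical reference: interleaved RoPE on x of shape (batch, seq_len, n_heads, head_dim).

    theta=10000.0 is an assumed base frequency, not taken from the MAX class.
    """
    _, seq_len, _, head_dim = x.shape  # seq_len is inferred from x's shape
    assert head_dim % 2 == 0, "head_dim must hold whole (real, imag) pairs"

    # One frequency per (real, imag) pair: theta ** (-2i / head_dim).
    inv_freq = theta ** (-np.arange(0, head_dim, 2) / head_dim)  # (head_dim // 2,)

    # Absolute positions covered by x, offset by start_pos.
    positions = np.arange(start_pos, start_pos + seq_len)  # (seq_len,)
    angles = np.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    rotations = np.exp(1j * angles)  # unit-magnitude complex rotations

    # Read alternating elements of head_dim as (real, imaginary) parts.
    x_complex = x[..., 0::2] + 1j * x[..., 1::2]

    # Rotate every pair; broadcast over the batch and head axes.
    rotated = x_complex * rotations[None, :, None, :]

    # Write the rotated pairs back into the interleaved layout.
    out = np.empty_like(x)
    out[..., 0::2] = rotated.real
    out[..., 1::2] = rotated.imag
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4, 8))  # (batch, seq_len, n_heads, head_dim)
y = rope_reference(x)

# Rotating only the last token, offset by start_pos=4, reproduces position 4 of the full pass.
y_last = rope_reference(x[:, 4:5], start_pos=4)
print(y.shape, np.allclose(y_last, y[:, 4:5]))
```

The output keeps the shape of x, and because each pair is only rotated, per-vector norms are unchanged. The start_pos check mirrors incremental decoding, where new tokens arrive one at a time and must be rotated by their absolute position rather than by their index within x.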

max_sequence_length

property max_sequence_length: int

weight

weight: Tensor
