

TransposedRotaryEmbedding


class max.nn.rope.rope.TransposedRotaryEmbedding(weight)

Parameters:

weight (Tensor)

forward()

forward(x, start_pos=0)

Applies rotary positional embeddings (RoPE) to x.
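For reference, here is the rotation RoPE performs, written for one head vector of length d = head_dim at sequence index t. It uses the real/imaginary split described for x below: a_k = x[k] and b_k = x[k + d/2]. The frequency schedule θ_k is the common choice from the standard RoPE formulation. It is an assumption here, because this page does not say how `weight` encodes the frequencies.

```latex
\begin{aligned}
p_t &= \texttt{start\_pos} + t, \qquad \theta_k = 10000^{-2k/d}, \quad k = 0, \dots, \tfrac{d}{2} - 1, \\
z_k' &= (a_k + \mathrm{i}\, b_k)\, e^{\mathrm{i}\, p_t \theta_k}, \\
a_k' &= a_k \cos(p_t \theta_k) - b_k \sin(p_t \theta_k), \\
b_k' &= a_k \sin(p_t \theta_k) + b_k \cos(p_t \theta_k).
\end{aligned}
```

The rotation changes only the phase of each z_k, so |z_k'| = |z_k|. The rotated parts are written back in the same layout, so the output has the same shape as x.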

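Compared to the traditional RotaryEmbedding, the layout of x is transposed within the final dimension (head_dim). Instead of storing each real/imaginary pair in adjacent elements, x stores all real parts in the first half of head_dim and all imaginary parts in the second half.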
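The following minimal NumPy sketch makes the "transposed" layout concrete. It assumes that the traditional layout stores each (real, imaginary) pair in adjacent elements. The helpers `interleaved_to_half_split` and `half_split_to_interleaved` are hypothetical and not part of max.nn. They show that the two layouts differ only by a transpose of head_dim when it is viewed as a 2-D block.

```python
import numpy as np


def interleaved_to_half_split(x: np.ndarray) -> np.ndarray:
    """[re0, im0, re1, im1, ...] -> [re0, re1, ..., im0, im1, ...] along the last axis."""
    *lead, d = x.shape
    if d % 2 != 0:
        raise ValueError("last dimension must be even")
    # View the last axis as (d/2, 2) pairs, transpose to (2, d/2), and flatten again.
    return x.reshape(*lead, d // 2, 2).swapaxes(-1, -2).reshape(*lead, d)


def half_split_to_interleaved(x: np.ndarray) -> np.ndarray:
    """[re0, re1, ..., im0, im1, ...] -> [re0, im0, re1, im1, ...] along the last axis."""
    *lead, d = x.shape
    if d % 2 != 0:
        raise ValueError("last dimension must be even")
    # View the last axis as (2, d/2) halves, transpose to (d/2, 2), and flatten again.
    return x.reshape(*lead, 2, d // 2).swapaxes(-1, -2).reshape(*lead, d)


# One head_dim = 6 vector holding the complex numbers 1+10j, 2+20j, 3+30j.
interleaved = np.array([1.0, 10.0, 2.0, 20.0, 3.0, 30.0])
half_split = interleaved_to_half_split(interleaved)
print(half_split.tolist())  # [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]
```

Both helpers only permute head_dim. Converting data between the two layouts therefore moves values around without changing any of them.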

seq_len is inferred from the shape of x.

Parameters:

  • x (Tensor) – Activation tensor with shape (batch, seq_len, n_kv_heads, head_dim). x is interpreted as a complex-valued tensor: the first half of head_dim holds the real parts and the second half holds the imaginary parts.
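  • start_pos (int | str | Dim | integer[Any]) – Starting position of the input tensor in the sequence, so the tokens along seq_len are at positions start_pos through start_pos + seq_len - 1. Defaults to 0, which is also used if None is passed.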
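The parameters above can be tied together in a NumPy reference sketch of the forward pass. The sketch is hypothetical. `rope_half_split` is not a max.nn function. It derives frequencies from an assumed `theta=10000.0` schedule instead of reading `weight`, and it accepts only a plain int for start_pos. It shows how start_pos offsets the positions and how the real and imaginary halves of head_dim are rotated.

```python
import numpy as np


def rope_half_split(x: np.ndarray, start_pos: int = 0, theta: float = 10000.0) -> np.ndarray:
    """Hypothetical reference sketch of RoPE on a half-split (real | imag) layout.

    x has shape (batch, seq_len, n_kv_heads, head_dim). The theta ** (-2k / head_dim)
    frequency schedule is an assumption; this sketch ignores the class's `weight`.
    """
    _, seq_len, _, head_dim = x.shape
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")
    half = head_dim // 2

    # Absolute positions of the tokens in x: start_pos, ..., start_pos + seq_len - 1.
    positions = np.arange(start_pos, start_pos + seq_len, dtype=np.float64)
    # One frequency per complex component k = 0 .. half - 1.
    inv_freq = theta ** (-2.0 * np.arange(half, dtype=np.float64) / head_dim)
    angles = np.outer(positions, inv_freq)  # (seq_len, half)

    # Broadcast the angles over the batch and head axes.
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]

    # First half of head_dim = real parts, second half = imaginary parts.
    real, imag = x[..., :half], x[..., half:]
    # Complex multiplication (real + i*imag) * (cos + i*sin).
    out_real = real * cos - imag * sin
    out_imag = real * sin + imag * cos
    return np.concatenate([out_real, out_imag], axis=-1).astype(x.dtype, copy=False)


# Usage: the output keeps the shape of x, and start_pos shifts the positions.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 6, 3, 8)).astype(np.float32)
full = rope_half_split(x)                      # positions 0..5
tail = rope_half_split(x[:, 4:], start_pos=4)  # positions 4..5
print(full.shape)                              # (2, 6, 3, 8)
print(np.allclose(tail, full[:, 4:], atol=1e-6))  # True
```

In this sketch, rotating the last tokens with start_pos set to their offset gives the same result as rotating the whole sequence from position 0 and then slicing. That property is what makes a start_pos argument useful when x holds only part of a sequence.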

Returns:

A tensor with the same shape as x, containing the input activations with rotary positional embeddings applied.

Return type:

Tensor
