Python module


An attention layer that uses only native MAX Graph operations, the naive KV cache, and RoPE (rotary position embedding).
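The sketch below shows, in plain NumPy, what applying RoPE to query or key heads does. It is not the MAX implementation. It assumes the interleaved-pair convention (channels 2i and 2i+1 are rotated together) and a base of theta = 10000. MAX's `RotaryEmbedding` may use a different layout or base, and the helper names `rope_frequencies` and `apply_rope` are hypothetical.

```python
import numpy as np

def rope_frequencies(head_dim: int, seq_len: int, theta: float = 10000.0):
    """Hypothetical helper: cos/sin tables of shape [seq_len, head_dim // 2]."""
    # One inverse frequency per channel pair: theta ** (-2i / head_dim).
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    positions = np.arange(seq_len)
    angles = np.outer(positions, inv_freq)  # rotation angle per (position, pair)
    return np.cos(angles), np.sin(angles)

def apply_rope(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """Rotate interleaved channel pairs of x, shaped [seq_len, n_heads, head_dim]."""
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    # Broadcast the per-position tables across the heads axis.
    c = cos[:, None, :]
    s = sin[:, None, :]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * c - x_odd * s
    out[..., 1::2] = x_even * s + x_odd * c
    return out
```

Because each channel pair is rotated by an angle proportional to its position, the dot product between a rotated query and a rotated key depends only on their relative offset, which is why RoPE is applied to queries and keys before the attention scores are computed.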


class max.nn.attention.naive_attention_with_rope.NaiveAttentionWithRope(n_heads: int, kv_params: KVCacheParams, dim: int, wq: Linear | LinearV2, wk: Linear | LinearV2, wv: Linear | LinearV2, wo: Linear | LinearV2, rope: RotaryEmbedding, scale: float | None = None)


attention(xq: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xk: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xv: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, attn_mask: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, keys: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, values: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
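For intuition about what a naive attention step computes, here is a NumPy sketch of masked scaled dot-product attention. It is an illustration, not the body of `attention`. Assumptions: tensors use a [batch, seq_len, n_heads, head_dim] layout, `keys` and `values` already hold every position to attend to (cached plus new tokens) with as many heads as the queries, the mask is additive (0 where allowed, -inf where masked), and `scale` defaults to 1/sqrt(head_dim) when it is `None`. How the real method combines `xk`/`xv` with the cache is not shown here; `naive_attention` and `causal_mask` are hypothetical names.

```python
import numpy as np

def causal_mask(q_len: int, kv_len: int) -> np.ndarray:
    """Additive causal mask of shape [q_len, kv_len]; the last q_len keys are the new tokens."""
    offset = kv_len - q_len  # number of previously cached positions
    allowed = np.arange(kv_len)[None, :] <= (np.arange(q_len)[:, None] + offset)
    return np.where(allowed, 0.0, -np.inf)

def naive_attention(xq, keys, values, attn_mask, scale=None):
    """softmax(Q K^T * scale + mask) V over [batch, seq, heads, head_dim] tensors."""
    head_dim = xq.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)  # assumed default
    # Move heads ahead of the sequence axis: [batch, heads, seq, head_dim].
    q = xq.transpose(0, 2, 1, 3)
    k = keys.transpose(0, 2, 1, 3)
    v = values.transpose(0, 2, 1, 3)
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale + attn_mask
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v  # [batch, heads, q_len, head_dim]
    return out.transpose(0, 2, 1, 3)
```

With a causal mask, the first query can only attend to the first key, so its output equals the first value vector exactly.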


repeat_kv(kv: TensorValue) → TensorValue

Repeats key/value tensors to match the number of query heads.
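A NumPy sketch of the same idea, as used in grouped-query attention, follows. The real method takes only `kv`; this sketch adds a hypothetical `n_rep` argument (presumably `n_heads` divided by the number of KV heads) and assumes a [batch, seq_len, n_kv_heads, head_dim] layout.

```python
import numpy as np

def repeat_kv(kv: np.ndarray, n_rep: int) -> np.ndarray:
    """Repeat each KV head n_rep times along the heads axis.

    kv:      [batch, seq_len, n_kv_heads, head_dim]
    returns: [batch, seq_len, n_kv_heads * n_rep, head_dim]
    """
    if n_rep == 1:
        return kv  # already one KV head per query head
    # np.repeat keeps copies of a head adjacent: h0, h0, h1, h1, ...
    return np.repeat(kv, n_rep, axis=2)
```

With this ordering, query head `i` reads from KV head `i // n_rep`, so each group of consecutive query heads shares one key/value head.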