
Python module

naive_attention_with_rope

An attention layer implemented with native MAX Graph operations, the naive KV cache, and rotary position embeddings (RoPE).

NaiveAttentionWithRope

class max.pipelines.nn.attention.naive_attention_with_rope.NaiveAttentionWithRope(n_heads: int, kv_params: max.pipelines.kv_cache.cache_params.KVCacheParams, dim: int, wq: max.pipelines.nn.linear.Linear, wk: max.pipelines.nn.linear.Linear, wv: max.pipelines.nn.linear.Linear, wo: max.pipelines.nn.linear.Linear, rope: max.pipelines.nn.rotary_embedding.RotaryEmbedding)
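The layer is assembled from pre-built projection layers and a rotary embedding, per the constructor signature above. The sketch below is illustrative only: `load_linear` is a hypothetical helper standing in for however your checkpoint produces `Linear` layers, and `kv_params`/`rope` are assumed to be constructed elsewhere (consult the `KVCacheParams`, `Linear`, and `RotaryEmbedding` docs for their actual constructors).

```python
from max.pipelines.kv_cache import KVCacheParams
from max.pipelines.nn.attention.naive_attention_with_rope import (
    NaiveAttentionWithRope,
)
from max.pipelines.nn.linear import Linear
from max.pipelines.nn.rotary_embedding import RotaryEmbedding

def load_linear(weights, name: str) -> Linear:
    # Hypothetical helper: build a Linear from checkpoint weights.
    # The real Linear constructor may differ; see max.pipelines.nn.linear.
    ...

def build_attention(
    n_heads: int,
    dim: int,
    kv_params: KVCacheParams,   # constructed per your cache configuration
    rope: RotaryEmbedding,      # constructed per your model's RoPE settings
    weights,
) -> NaiveAttentionWithRope:
    # All keyword arguments below follow the documented constructor signature.
    return NaiveAttentionWithRope(
        n_heads=n_heads,
        kv_params=kv_params,
        dim=dim,
        wq=load_linear(weights, "wq"),
        wk=load_linear(weights, "wk"),
        wv=load_linear(weights, "wv"),
        wo=load_linear(weights, "wo"),
        rope=rope,
    )
```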

attention()

attention(xq: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xk: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xv: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, attn_mask: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, keys: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, values: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
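The method itself carries no upstream description; the call below is a shape-level usage sketch, assuming `xq`/`xk`/`xv` are the already-projected query/key/value tensors for the current tokens and `keys`/`values` come from the naive KV cache. All argument names follow the documented signature; the mask semantics are an assumption.

```python
# Illustrative call only; the inputs are assumed to be max.graph
# TensorValues produced earlier in the graph.
out = attn_layer.attention(
    xq=xq,                  # projected queries for the new tokens
    xk=xk,                  # projected keys for the new tokens
    xv=xv,                  # projected values for the new tokens
    attn_mask=mask,         # attention mask (semantics assumed)
    keys=cached_keys,       # keys from the naive KV cache
    values=cached_values,   # values from the naive KV cache
)
```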

dim

dim: int

kv_params

kv_params: KVCacheParams

n_heads

n_heads: int

repeat_kv()

repeat_kv(kv: TensorValue) → TensorValue

Repeats key/value tensors to match the number of query heads.
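This is the standard grouped-query-attention trick: when the model has fewer key/value heads than query heads, each KV head is tiled so the shapes line up for attention. A minimal NumPy sketch of the idea, assuming a (batch, seq, heads, head_dim) layout; the actual method operates on MAX Graph TensorValues, not arrays:

```python
import numpy as np

def repeat_kv_reference(kv: np.ndarray, n_rep: int) -> np.ndarray:
    # kv: (batch, seq_len, n_kv_heads, head_dim); layout is an assumption.
    # Tile each KV head n_rep times along the head axis so that
    # n_kv_heads * n_rep == n_query_heads.
    if n_rep == 1:
        return kv
    return np.repeat(kv, n_rep, axis=2)

kv = np.random.randn(1, 4, 2, 8)            # 2 KV heads
out = repeat_kv_reference(kv, n_rep=4)      # match 8 query heads
assert out.shape == (1, 4, 8, 8)
```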

rope

rope: RotaryEmbedding

wk

wk: Linear

wo

wo: Linear

wq

wq: Linear

wv

wv: Linear