Python module
naive_attention_with_rope
An attention layer implemented with only native MAX Graph operations, the naive KV cache, and RoPE (rotary position embeddings).
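For orientation, the sketch below illustrates the standard RoPE computation in plain NumPy. It is an independent illustration of the technique, not this module's implementation; the actual RotaryEmbedding layer is built from MAX Graph operations and may use a different pairing layout (e.g. split halves rather than interleaved pairs), so treat the layout below as an assumption.

```python
import numpy as np

def rope_reference(x: np.ndarray, positions: np.ndarray, theta: float = 10000.0) -> np.ndarray:
    """Illustrative RoPE: rotate consecutive (even, odd) feature pairs by a
    position-dependent angle. x has shape [seq_len, n_heads, head_dim]."""
    head_dim = x.shape[-1]
    # One frequency per feature pair, decaying geometrically with pair index.
    freqs = theta ** (-np.arange(0, head_dim, 2) / head_dim)      # [head_dim // 2]
    angles = positions[:, None] * freqs[None, :]                  # [seq_len, head_dim // 2]
    cos = np.cos(angles)[:, None, :]                              # broadcast over heads
    sin = np.sin(angles)[:, None, :]
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out
```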
NaiveAttentionWithRope
class max.pipelines.nn.attention.naive_attention_with_rope.NaiveAttentionWithRope(n_heads: int, kv_params: max.pipelines.kv_cache.cache_params.KVCacheParams, dim: int, wq: max.pipelines.nn.linear.Linear, wk: max.pipelines.nn.linear.Linear, wv: max.pipelines.nn.linear.Linear, wo: max.pipelines.nn.linear.Linear, rope: max.pipelines.nn.rotary_embedding.RotaryEmbedding)
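A hedged construction sketch follows. It uses only the keyword arguments shown in the signature above; the numeric values are illustrative, and how the projection weights, rotary embedding, and cache parameters are built (weight loading, RotaryEmbedding and KVCacheParams arguments) depends on your pipeline and is represented here only by placeholder comments, not real calls.

```python
from max.pipelines.nn.attention.naive_attention_with_rope import NaiveAttentionWithRope

# Assume these components are built elsewhere in the pipeline:
#   wq, wk, wv, wo : max.pipelines.nn.linear.Linear projections
#   rope           : max.pipelines.nn.rotary_embedding.RotaryEmbedding
#   kv_params      : max.pipelines.kv_cache.cache_params.KVCacheParams
attention_layer = NaiveAttentionWithRope(
    n_heads=32,          # number of query heads (illustrative value)
    kv_params=kv_params,
    dim=4096,            # model (hidden) dimension (illustrative value)
    wq=wq,
    wk=wk,
    wv=wv,
    wo=wo,
    rope=rope,
)
```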
attention()
attention(xq: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xk: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, xv: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, attn_mask: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, keys: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, values: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
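The exact graph built by attention() is defined by the implementation; as a rough guide, the NumPy sketch below shows conventional scaled dot-product attention with an additive mask, which is the computation a naive (non-fused) attention layer typically lowers to. The tensor layout and the handling of the cached keys/values arguments are assumptions.

```python
import numpy as np

def naive_attention_reference(q, k, v, attn_mask):
    """q: [batch, n_heads, seq_len, head_dim]; k, v: same layout (after any
    KV-head repetition); attn_mask: additive mask broadcastable to the
    [batch, n_heads, seq_len, kv_len] score matrix."""
    head_dim = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # [b, h, seq, kv]
    scores = scores + attn_mask                                 # e.g. -inf above the diagonal
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)     # softmax over kv positions
    return weights @ v                                          # [b, h, seq, head_dim]
```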
dim
dim: int
kv_params
kv_params: KVCacheParams
n_heads
n_heads: int
repeat_kv()
repeat_kv(kv: TensorValue) → TensorValue
Repeats key/value tensors to match the number of query heads.
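A minimal sketch of the idea, in NumPy: when the model uses grouped-query attention (fewer KV heads than query heads), each KV head is tiled n_heads // n_kv_heads times so the attention math can treat Q, K, and V as having the same head count. The real method operates on TensorValue graph values and may use different axis conventions; the layout below is an assumption.

```python
import numpy as np

def repeat_kv_reference(kv: np.ndarray, n_rep: int) -> np.ndarray:
    """kv: [batch, seq_len, n_kv_heads, head_dim]
       -> [batch, seq_len, n_kv_heads * n_rep, head_dim]"""
    if n_rep == 1:
        return kv
    batch, seq_len, n_kv_heads, head_dim = kv.shape
    # Insert a repeat axis next to the KV-head axis, broadcast, then merge.
    kv = np.broadcast_to(
        kv[:, :, :, None, :], (batch, seq_len, n_kv_heads, n_rep, head_dim)
    )
    return kv.reshape(batch, seq_len, n_kv_heads * n_rep, head_dim)
```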
rope
rope: RotaryEmbedding
wk
wk: Linear
wo
wo: Linear
wq
wq: Linear
wv
wv: Linear