
Python module

attention_with_rope

An attention mechanism with rotary position embeddings (RoPE), optimized for an opaque KV cache.

AttentionWithRope

class max.pipelines.nn.attention.attention_with_rope.AttentionWithRope(n_heads: int, kv_params: max.pipelines.kv_cache.cache_params.KVCacheParams, layer_idx: max.graph.value.TensorValue, wqkv: max.graph.value.TensorValue, wo: max.pipelines.nn.linear.Linear, rope: max.pipelines.nn.rotary_embedding.OptimizedRotaryEmbedding, bias: max.graph.value.TensorValue | None = None)

bias

bias: TensorValue | None = None

rope

rope: OptimizedRotaryEmbedding
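
A minimal construction sketch, assuming the caller already has the fused QKV weight, output projection, cache parameters, and rotary embedding on hand. Only the keyword names come from the signature above; the `build_attention` helper itself is hypothetical:

```python
from max.graph.value import TensorValue
from max.pipelines.kv_cache.cache_params import KVCacheParams
from max.pipelines.nn.attention.attention_with_rope import AttentionWithRope
from max.pipelines.nn.linear import Linear
from max.pipelines.nn.rotary_embedding import OptimizedRotaryEmbedding

def build_attention(
    n_heads: int,
    kv_params: KVCacheParams,
    layer_idx: TensorValue,          # layer index as a graph tensor value
    wqkv: TensorValue,               # fused Q/K/V projection weight
    wo: Linear,                      # output projection
    rope: OptimizedRotaryEmbedding,  # rotary embedding applied to Q/K
) -> AttentionWithRope:
    """Hypothetical helper: wire up one attention layer; bias defaults to None."""
    return AttentionWithRope(
        n_heads=n_heads,
        kv_params=kv_params,
        layer_idx=layer_idx,
        wqkv=wqkv,
        wo=wo,
        rope=rope,
    )
```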

AttentionWithRopeQKV

class max.pipelines.nn.attention.attention_with_rope.AttentionWithRopeQKV(n_heads: int, kv_params: max.pipelines.kv_cache.cache_params.KVCacheParams, layer_idx: int, wq: max._mlir._mlir_libs._mlir.ir.Value | max.graph.value.TensorValue | max.graph.type.Shape | max.graph.type.Dim | int | float | numpy.integer | numpy.floating | numpy.ndarray, wk: max._mlir._mlir_libs._mlir.ir.Value | max.graph.value.TensorValue | max.graph.type.Shape | max.graph.type.Dim | int | float | numpy.integer | numpy.floating | numpy.ndarray, wv: max._mlir._mlir_libs._mlir.ir.Value | max.graph.value.TensorValue | max.graph.type.Shape | max.graph.type.Dim | int | float | numpy.integer | numpy.floating | numpy.ndarray, wo: max.pipelines.nn.linear.Linear, rope: max.pipelines.nn.rotary_embedding.OptimizedRotaryEmbedding)

rope

rope: OptimizedRotaryEmbedding
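
This variant takes separate wq, wk, and wv weights instead of a fused wqkv, and its layer_idx is a plain int. Per the signature's union type, the weights accept anything convertible to a TensorValue, including NumPy arrays. A sketch with hypothetical dimensions; the grouped-query K/V sizing is an assumption, not a documented contract:

```python
import numpy as np
from max.pipelines.kv_cache.cache_params import KVCacheParams
from max.pipelines.nn.attention.attention_with_rope import AttentionWithRopeQKV
from max.pipelines.nn.linear import Linear
from max.pipelines.nn.rotary_embedding import OptimizedRotaryEmbedding

def build_attention_qkv(
    kv_params: KVCacheParams,
    wo: Linear,
    rope: OptimizedRotaryEmbedding,
) -> AttentionWithRopeQKV:
    # Hypothetical dimensions; real values come from the model config.
    hidden, n_heads, n_kv_heads, head_dim = 4096, 32, 8, 128
    return AttentionWithRopeQKV(
        n_heads=n_heads,
        kv_params=kv_params,
        layer_idx=0,  # plain int here, unlike AttentionWithRope's TensorValue
        # NumPy arrays are accepted per the TensorValue-like union above;
        # zeros are placeholders, not real checkpoint weights.
        wq=np.zeros((hidden, n_heads * head_dim), dtype=np.float32),
        wk=np.zeros((hidden, n_kv_heads * head_dim), dtype=np.float32),
        wv=np.zeros((hidden, n_kv_heads * head_dim), dtype=np.float32),
        wo=wo,
        rope=rope,
    )
```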

DistributedAttentionWithRope

class max.pipelines.nn.attention.attention_with_rope.DistributedAttentionWithRope(list_of_attentions: List[max.pipelines.nn.attention.attention_with_rope.AttentionWithRope], devices: list[max.graph.type.DeviceRef])

devices

devices: list[max.graph.type.DeviceRef]

list_of_attentions

list_of_attentions: List[AttentionWithRope]
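
Given the parallel list_of_attentions and devices attributes, a reasonable reading is one AttentionWithRope replica per device, in matching order. A hypothetical assembly helper under that assumption:

```python
from max.graph.type import DeviceRef
from max.pipelines.nn.attention.attention_with_rope import (
    AttentionWithRope,
    DistributedAttentionWithRope,
)

def build_distributed(
    replicas: list[AttentionWithRope],
    devices: list[DeviceRef],
) -> DistributedAttentionWithRope:
    """Pair the i-th replica with the i-th device (assumed ordering)."""
    if len(replicas) != len(devices):
        raise ValueError("need exactly one attention replica per device")
    return DistributedAttentionWithRope(
        list_of_attentions=replicas,
        devices=devices,
    )
```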

distribute_value()

max.pipelines.nn.attention.attention_with_rope.distribute_value(v: TensorValue, devices: List[DeviceRef]) → List[TensorValue]
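
Per the signature, this helper takes one TensorValue and returns one value per target device. A small usage sketch; the per-device pairing via zip is only what the parallel return list implies, and replicate_weight is a hypothetical wrapper:

```python
from max.graph.type import DeviceRef
from max.graph.value import TensorValue
from max.pipelines.nn.attention.attention_with_rope import distribute_value

def replicate_weight(
    w: TensorValue, devices: list[DeviceRef]
) -> list[tuple[DeviceRef, TensorValue]]:
    """Return (device, per-device copy) pairs, one per target device."""
    copies = distribute_value(w, devices)  # one TensorValue per device
    return list(zip(devices, copies))
```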