Python module
attention_with_rope
An opaque KV Cache optimized attention mechanism with Rope.
AttentionWithRope
class max.nn.attention.attention_with_rope.AttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None)
Implementation of attention that uses the rope frequency.
Initializes the attention layer.
Parameters:

- rope (OptimizedRotaryEmbedding) – The rope layer to borrow the freq_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype (DType) – DType of the QKV and output projection weights.
- devices (list[DeviceRef] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- stacked_qkv (bool) – Whether the weights are stacked together.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias (bool) – Whether to use an attention bias.
- clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
- float8_config (Float8Config | None)
qkv_input_scale
property qkv_input_scale*: TensorValue | None*
The max of q, k, and v scale input vectors.
qkv_weight_scale
property qkv_weight_scale*: TensorValue*
The max of q, k, and v scale weight vectors.
rope
rope*: OptimizedRotaryEmbedding*
wqkv
property wqkv*: TensorValue*
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias*: TensorValue | None*
The concatenation of q, k, and v bias weight vectors.
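A minimal construction sketch is shown below. It assumes that a rope layer (an OptimizedRotaryEmbedding) and a matching KVCacheParams instance have already been created by the surrounding model code; the head counts and hidden size are illustrative Llama-style values rather than anything specified on this page.

```python
from max.nn.attention.attention_with_rope import AttentionWithRope

# `rope` and `kv_params` are placeholders assumed to be constructed elsewhere
# (an OptimizedRotaryEmbedding and a KVCacheParams whose kv-head count and
# head dim agree with the numbers below).
attention = AttentionWithRope(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,   # grouped-query attention: fewer KV heads than Q heads
    hidden_size=4096,
    kv_params=kv_params,
    has_bias=False,          # no attention bias
    clip_qkv=None,           # set to a float to clamp QKV weights to [-clip_qkv, clip_qkv]
)
```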
AttentionWithRopeQKV
class max.nn.attention.attention_with_rope.AttentionWithRopeQKV(n_heads: 'int', kv_params: 'KVCacheParams', wq: 'TensorValueLike', wk: 'TensorValueLike', wv: 'TensorValueLike', wo: 'LinearV1', scale: 'float', rope: 'OptimizedRotaryEmbedding')
Parameters:

- n_heads (int)
- kv_params (KVCacheParams)
- wq (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wk (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wv (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wo (LinearV1)
- scale (float)
- rope (OptimizedRotaryEmbedding)
rope
rope*: OptimizedRotaryEmbedding*
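For this unfused variant, wq, wk, and wv accept any of the documented TensorValueLike forms, including plain NumPy arrays. The sketch below uses zero arrays with illustrative shapes; rope, kv_params, and the wo output projection (a LinearV1) are placeholders assumed to exist, and the 1/sqrt(head_dim) scale is a conventional choice, not one stated on this page.

```python
import numpy as np

from max.nn.attention.attention_with_rope import AttentionWithRopeQKV

hidden, n_heads = 4096, 32
head_dim = hidden // n_heads

# `rope`, `kv_params`, and `wo` are placeholders assumed to be built elsewhere.
attn = AttentionWithRopeQKV(
    n_heads=n_heads,
    kv_params=kv_params,
    wq=np.zeros((hidden, hidden), dtype=np.float32),  # ndarray is an accepted form
    wk=np.zeros((hidden, hidden), dtype=np.float32),  # shapes here are illustrative
    wv=np.zeros((hidden, hidden), dtype=np.float32),
    wo=wo,
    scale=head_dim**-0.5,  # conventional 1/sqrt(head_dim) scaling (assumption)
    rope=rope,
)
```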
AttentionWithRopeV1
class max.nn.attention.attention_with_rope.AttentionWithRopeV1(n_heads, kv_params, wqkv, wo, scale, rope, bias=None, perm_idx=None, quantization_config=None)
Implementation of attention that uses the rope frequency.
Deprecated: Use AttentionWithRope instead.
Parameters:

- n_heads (int)
- kv_params (KVCacheParams)
- wqkv (TensorValue)
- wo (LinearV1)
- scale (float)
- rope (OptimizedRotaryEmbedding)
- bias (TensorValue | None)
- perm_idx (TensorValue | None)
- quantization_config (QuantizationConfig | None)
bias
bias*: TensorValue | None* = None
perm_idx
perm_idx*: TensorValue | None* = None
quantization_config
quantization_config*: QuantizationConfig | None* = None
rope
rope*: OptimizedRotaryEmbedding*
DistributedAttentionWithRope
class max.nn.attention.attention_with_rope.DistributedAttentionWithRope(**kwargs)
Initializes the attention layer.
Parameters:

- rope – The rope layer to borrow the freq_cis value from.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – Number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype – DType of the QKV and output projection weights.
- devices – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- linear_cls – Linear class to use for the output dense layer.
- stacked_qkv – Whether the weights are stacked together.
- scale – Value used to scale the results of the attention output.
- has_bias – Whether to use an attention bias.
- clip_qkv – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
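A minimal sketch of the distributed variant follows. It assumes rope and kv_params are built elsewhere and that DeviceRef.GPU handles from max.graph identify the participating GPUs; the keyword arguments otherwise mirror AttentionWithRope.

```python
from max.graph import DeviceRef
from max.nn.attention.attention_with_rope import DistributedAttentionWithRope

# `rope` and `kv_params` are placeholders assumed to be constructed elsewhere.
devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

attention = DistributedAttentionWithRope(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    devices=devices,   # unlike the base class, every listed device participates
)
```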
GGUFQAttentionWithRope
class max.nn.attention.attention_with_rope.GGUFQAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, quantization_encoding, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, has_bias=False, clip_qkv=None)
Implementation of attention with GGUF quantized weights.
Initializes the attention layer.
Parameters:

- rope (OptimizedRotaryEmbedding) – The rope layer to borrow the freq_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- layer_idx – The layer number associated with this Attention block.
- dtype (DType) – DType of the weights; should always be uint8.
- devices (list[DeviceRef] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- quantization_encoding (QuantizationEncoding) – Quantization encoding of the weights.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias (bool) – Whether to use an attention bias.
- clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
rope
rope*: OptimizedRotaryEmbedding*
wqkv
property wqkv*: TensorValue*
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias*: TensorValue | None*
The concatenation of q, k, and v bias weight vectors.
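A construction sketch, assuming rope, kv_params, and a GGUF quantization encoding (called encoding here) come from the surrounding model code, and that DType is importable from max.dtype; per the parameter notes above, dtype is passed as uint8.

```python
from max.dtype import DType
from max.nn.attention.attention_with_rope import GGUFQAttentionWithRope

# `rope`, `kv_params`, and `encoding` (a QuantizationEncoding for the GGUF
# checkpoint) are placeholders assumed to be defined elsewhere.
attention = GGUFQAttentionWithRope(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    dtype=DType.uint8,               # GGUF-quantized weights are stored as uint8
    quantization_encoding=encoding,
)
```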
GPTQAttentionWithRope
class max.nn.attention.attention_with_rope.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.linear.Linear'>)
Implementation of the GPT-Q attention layer.
Initializes the attention layer.
Parameters:

- rope (OptimizedRotaryEmbedding) – The rope layer to borrow the freq_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype (DType) – DType of the QKV and output projection weights.
- devices (list[DeviceRef] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- stacked_qkv – Whether the weights are stacked together.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias – Whether to use an attention bias.
- clip_qkv – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
- quantization_config (QuantizationConfig)
wqkv
property wqkv*: TensorValue*
The concatenation of q, k, and v weight vectors.
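A sketch along the same lines, assuming quantization_config (a QuantizationConfig describing the GPTQ checkpoint), rope, and kv_params are defined by the surrounding model code.

```python
from max.nn.attention.attention_with_rope import GPTQAttentionWithRope

# `quantization_config`, `rope`, and `kv_params` are placeholders assumed to
# be constructed elsewhere.
attention = GPTQAttentionWithRope(
    quantization_config=quantization_config,
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
)
```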
LatentAttentionWithRope
class max.nn.attention.attention_with_rope.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, has_bias=False, clip_qkv=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384)
Implementation of Latent Attention with Rope.
Initializes the attention layer.
Parameters:

- rope (OptimizedRotaryEmbedding) – The rope layer to borrow the freq_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- layer_idx – The layer number associated with this Attention block.
- dtype (DType) – DType of the weights; should always be uint8.
- devices (list[DeviceRef] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- quantization_encoding – Quantization encoding of the weights.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias (bool) – Whether to use an attention bias.
- clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
- buffer_size (int) – Buffer size for storing the temporary results during prefill, in units of tokens.
- q_lora_rank (int | None)
- kv_lora_rank (int)
- qk_nope_head_dim (int)
- qk_rope_head_dim (int)
- v_head_dim (int)
rope
rope*: OptimizedRotaryEmbedding*
w_uk_uv
property w_uk_uv*: list[TensorValue]*
The concatenation of q, k, and v weight vectors.
wqkv
property wqkv*: TensorValue*
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias*: TensorValue | None*
The concatenation of q, k, and v bias weight vectors.
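The sketch below spells out the documented low-rank defaults explicitly. The rope, kv_params, and dtype objects, as well as the head count, hidden size, and q_lora_rank values, are illustrative placeholders rather than values taken from this page.

```python
from max.nn.attention.attention_with_rope import LatentAttentionWithRope

# `rope`, `kv_params`, and `dtype` are placeholders assumed to be defined
# with the rest of the model configuration.
attention = LatentAttentionWithRope(
    rope=rope,
    num_attention_heads=128,   # illustrative
    num_key_value_heads=128,   # illustrative
    hidden_size=7168,          # illustrative
    kv_params=kv_params,
    dtype=dtype,
    q_lora_rank=1536,          # illustrative; defaults to None
    kv_lora_rank=512,          # documented default
    qk_nope_head_dim=128,      # documented default
    qk_rope_head_dim=64,       # documented default
    v_head_dim=128,            # documented default
    buffer_size=16384,         # prefill scratch buffer, in tokens (documented default)
)
```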
distribute_value()
max.nn.attention.attention_with_rope.distribute_value(v, devices)
Parameters:

- v (TensorValue)
- devices (list[DeviceRef])
Return type:
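A usage sketch: v stands in for a TensorValue that exists while a graph is being built, and DeviceRef.GPU handles from max.graph are assumed. The return type is not shown on this page; one replicated value per device is the expected shape of the result, but that is an inference.

```python
from max.graph import DeviceRef
from max.nn.attention.attention_with_rope import distribute_value

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# `v` is a placeholder for a TensorValue available during graph construction.
replicas = distribute_value(v, devices)  # expected: one copy of `v` per device
```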