Python module
attention_with_rope
Attention mechanisms that use RoPE (Rotary Position Embedding), optimized for an opaque KV cache.
AttentionWithRope
class max.nn.legacy.attention.attention_with_rope.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)
Implementation of attention that uses Rotary Position Embedding (RoPE).
Parameters:
- rope (RotaryEmbedding)
- sharding_strategy (ShardingStrategy | None)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- devices (Sequence[DeviceRef] | None)
- dtype (DType)
- linear_cls (Callable[..., Linear])
- stacked_qkv (bool)
- scale (float | None)
- has_bias (bool)
- float8_config (Float8Config | None)
- clip_qkv (float | None)
- use_qk_norm (bool)
- rms_norm_eps (float)
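The following is a minimal construction sketch in Python. Only the keyword arguments come from the signature above; the import paths for DType and DeviceRef, the placeholder rope and kv_params objects, and the concrete head counts and hidden size are assumptions for illustration.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.attention.attention_with_rope import AttentionWithRope

# Placeholders: build these with your model's RotaryEmbedding and KVCacheParams
# (their constructors are not documented on this page).
rope = ...
kv_params = ...

attn = AttentionWithRope(
    rope=rope,
    num_attention_heads=32,        # illustrative model dimensions
    num_key_value_heads=8,         # grouped-query attention: fewer KV heads
    hidden_size=4096,
    kv_params=kv_params,
    devices=[DeviceRef.GPU(0)],
    dtype=DType.bfloat16,
    scale=None,                    # None is assumed to mean the default attention scaling
    has_bias=False,
)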
qkv_input_scale
property qkv_input_scale: TensorValue | None
The max of the q, k, and v input scale vectors.
qkv_weight_scale
property qkv_weight_scale: TensorValue
The max of the q, k, and v weight scale vectors.
qkv_weight_scale_2
property qkv_weight_scale_2: TensorValue | None
The max of the q, k, and v weight scale vectors (second-level scale).
rope
rope: RotaryEmbedding
shard()
shard(devices)
Create sharded views across devices (tensor-parallel).
Returns one AttentionWithRope per device with appropriately sliced weights.
Parameters:
- devices – The devices to create sharded views for.
Return type:
list[AttentionWithRope]
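A short usage sketch for shard(), assuming attn is the AttentionWithRope instance built above and that a tensor-parallel sharding_strategy has already been assigned (the exact ShardingStrategy value is model-specific and not documented on this page).

from max.graph import DeviceRef

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# One AttentionWithRope view per device, with weights sliced for tensor parallelism.
shards = attn.shard(devices)
assert len(shards) == len(devices)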
sharding_strategy
property sharding_strategy: ShardingStrategy | None
The module's sharding strategy.
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias: TensorValue | None
The concatenation of the q, k, and v bias vectors.
AttentionWithRopeNoOpaque
class max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, scale=None)
Attention with RoPE without opaque KV cache.
Assumes:
- no float8 (no float8_config)
- no stacked qkv
- no bias
- no clip_qkv
Parameters:
- rope (RotaryEmbedding)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- devices (Sequence[DeviceRef] | None)
- dtype (DType)
- linear_cls (Callable[..., Linear])
- scale (float | None)
rope
rope: RotaryEmbedding
DataParallelAttentionWithRope
class max.nn.legacy.attention.attention_with_rope.DataParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)
Data-parallel implementation of Attention with RoPE.
This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective ops are required; KV-cache remains local to each device.
Notes:
- Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0.
Parameters:
- rope (RotaryEmbedding)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- devices (Sequence[DeviceRef] | None)
- dtype (DType)
- linear_cls (Callable[..., Linear])
- stacked_qkv (bool)
- scale (float | None)
- has_bias (bool)
- float8_config (Float8Config | None)
- clip_qkv (float | None)
- use_qk_norm (bool)
- rms_norm_eps (float)
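To make the note above concrete, here is a small, hypothetical helper (not part of this API) that slices a global input_row_offsets array into per-device arrays and rebases each slice to start at 0, as the data-parallel module expects. The splitting policy (a fixed number of sequences per device) is purely illustrative.

import numpy as np

def rebase_row_offsets(row_offsets: np.ndarray, splits: list[int]) -> list[np.ndarray]:
    # row_offsets has length num_sequences + 1; splits[i] is the number of
    # sequences assigned to device i.
    out, start = [], 0
    for n in splits:
        chunk = row_offsets[start : start + n + 1]
        out.append(chunk - chunk[0])  # rebase so each device's offsets start at 0
        start += n
    return out

# Four sequences of lengths 3, 5, 2, 4 -> global offsets [0, 3, 8, 10, 14].
global_offsets = np.array([0, 3, 8, 10, 14], dtype=np.int64)
per_device = rebase_row_offsets(global_offsets, splits=[2, 2])
# per_device[0] -> [0, 3, 8]; per_device[1] -> [0, 2, 6]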
GGUFQAttentionWithRope
class max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, quantization_encoding, devices=None, linear_cls=<class 'max.nn.legacy.linear.Linear'>, scale=None, has_bias=False, clip_qkv=None)
Implementation of attention with GGUF quantized weights.
Parameters:
- rope (RotaryEmbedding)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- dtype (DType)
- quantization_encoding (QuantizationEncoding)
- devices (list[DeviceRef] | None)
- linear_cls (Callable[..., Linear])
- scale (float | None)
- has_bias (bool)
- clip_qkv (float | None)
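A construction sketch for the GGUF-quantized variant. The keyword arguments follow the signature above; the QuantizationEncoding import path and the Q4_K member, as well as the rope and kv_params placeholders, are assumptions.

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding
from max.nn.legacy.attention.attention_with_rope import GGUFQAttentionWithRope

attn = GGUFQAttentionWithRope(
    rope=rope,                     # RotaryEmbedding placeholder, as above
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,           # KVCacheParams placeholder, as above
    dtype=DType.float32,           # activation dtype; assumed value
    quantization_encoding=QuantizationEncoding.Q4_K,
    devices=[DeviceRef.GPU(0)],
)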
rope
rope: RotaryEmbedding
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias: TensorValue | None
The concatenation of the q, k, and v bias vectors.
GPTQAttentionWithRope
class max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.legacy.linear.Linear'>)
Implementation of the GPTQ attention layer.
Parameters:
- quantization_config (QuantizationConfig)
- rope (RotaryEmbedding)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- devices (list[DeviceRef] | None)
- dtype (DType)
- scale (float | None)
- linear_cls (Callable[..., Linear])
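A construction sketch for the GPTQ variant. Only the parameter names come from the signature above; the QuantizationConfig construction is not documented on this page, so a placeholder is used.

from max.graph import DeviceRef
from max.nn.legacy.attention.attention_with_rope import GPTQAttentionWithRope

quantization_config = ...   # a QuantizationConfig (bits, group size, etc.)

attn = GPTQAttentionWithRope(
    quantization_config=quantization_config,
    rope=rope,                     # RotaryEmbedding placeholder, as above
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,           # KVCacheParams placeholder, as above
    devices=[DeviceRef.GPU(0)],
)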
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors (packed + scales).
TensorParallelAttentionWithRope
class max.nn.legacy.attention.attention_with_rope.TensorParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)
Tensor-parallel wrapper that delegates sharding to the base module.
Parameters:
- rope (RotaryEmbedding)
- num_attention_heads (int)
- num_key_value_heads (int)
- hidden_size (int)
- kv_params (KVCacheParams)
- devices (Sequence[DeviceRef] | None)
- dtype (DType)
- linear_cls (Callable[..., Linear])
- stacked_qkv (bool)
- scale (float | None)
- has_bias (bool)
- float8_config (Float8Config | None)
- clip_qkv (float | None)
- use_qk_norm (bool)
- rms_norm_eps (float)
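A construction sketch for the tensor-parallel wrapper, assuming the same rope and kv_params placeholders as in the earlier examples. Per the description above, sharding across the listed devices is delegated to the base module; the head counts and hidden size are illustrative.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.attention.attention_with_rope import TensorParallelAttentionWithRope

attn = TensorParallelAttentionWithRope(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],  # weights are split across these devices
    dtype=DType.bfloat16,
)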