Python class
AttentionWithRope
class max.nn.attention.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, quant_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06, mask_variant=MHAMaskVariant.CAUSAL_MASK, sliding_window=None, _fuse_rope_and_store=True)
Implementation of attention that uses Rotary Position Embedding (RoPE).
Initializes the attention layer.
Parameters:
- rope (RotaryEmbedding): The rope layer to borrow the freqs_cis value from.
- sharding_strategy (ShardingStrategy | None): Optional initial sharding strategy.
- num_attention_heads (int): The number of attention heads.
- num_key_value_heads (int): Number of key/value heads.
- hidden_size (int): The dimension of the hidden states.
- kv_params (KVCacheParams): KV Cache params, including number of kv heads, head dim, and dtype.
- dtype (DType): DType of the QKV and output projection weights.
- devices (Sequence[DeviceRef] | None): Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
- linear_cls (Callable[..., Linear]): Linear class to use for projections.
- stacked_qkv (bool): Whether Q/K/V weights are stacked in a single Weight.
- scale (float | None): Optional attention scale; defaults to sqrt(1/head_dim).
- has_bias (bool): Whether Q/K/V have bias (stacked_qkv forbids bias).
- quant_config (QuantConfig | None): Optional quantization config (dynamic or static).
- clip_qkv (float | None): If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
- use_qk_norm (bool): Whether to use RMSNorm on Q/K.
- rms_norm_eps (float): Value to use for numerical stability in RMSNorm.
- _fuse_rope_and_store (bool): If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
- mask_variant (MHAMaskVariant)
- sliding_window (int | None)
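The relationship between a few of these parameters can be sketched in plain Python. This is a minimal illustration rather than MAX code: the helper names are hypothetical, the divisibility check is an assumption about grouped-query attention in general, and head_dim is assumed to equal hidden_size // num_attention_heads, which need not hold when kv_params specifies a different head dimension.

```python
import math


def default_attention_scale(head_dim: int) -> float:
    """Default documented for scale=None: sqrt(1 / head_dim)."""
    return math.sqrt(1.0 / head_dim)


def queries_per_kv_head(num_attention_heads: int, num_key_value_heads: int) -> int:
    """Number of query heads sharing each key/value head (grouped-query attention)."""
    # Assumption: the query heads divide evenly among the key/value heads.
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return num_attention_heads // num_key_value_heads


def clip(values: list[float], clip_qkv: float) -> list[float]:
    """Clamp each value to [-clip_qkv, clip_qkv], as described for clip_qkv."""
    return [max(-clip_qkv, min(clip_qkv, v)) for v in values]


# Hypothetical configuration, chosen only for illustration.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads  # assumption, see lead-in

scale = default_attention_scale(head_dim)
group_size = queries_per_kv_head(num_attention_heads, num_key_value_heads)
clipped = clip([-9.0, 0.5, 12.0], clip_qkv=8.0)
```

With these illustrative numbers, four query heads share each key/value head, and values outside the clipping range are pulled back to its boundary.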
materialize_kv_from_hidden()
materialize_kv_from_hidden(layer_idx, hidden, kv_collection, freqs_cis, input_row_offsets)
Project hidden to K/V and write into the paged KV cache.
Used by speculative-decoding draft models that build their KV cache from external (e.g. target) hidden states.
Parameters:
- layer_idx (TensorValue)
- hidden (TensorValue)
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue])
- freqs_cis (TensorValue)
- input_row_offsets (TensorValue)
Return type:
None
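As a rough mental model of materialize_kv_from_hidden, the toy sketch below projects hidden rows to K and V and appends them to a per-sequence cache. It is not the MAX implementation: it skips RoPE and paging, uses a plain dict as a stand-in for the KV cache, and assumes input_row_offsets holds cumulative row boundaries of a ragged batch (sequence i occupies rows offsets[i] up to offsets[i + 1]). All function and variable names are hypothetical.

```python
def matvec(weight: list[list[float]], x: list[float]) -> list[float]:
    """Multiply an (out_dim x in_dim) weight matrix by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def materialize_kv_toy(hidden, w_k, w_v, input_row_offsets, cache) -> None:
    """Project each hidden row to K and V and append it to its sequence's cache entry."""
    for seq in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[seq], input_row_offsets[seq + 1]
        keys, values = cache.setdefault(seq, ([], []))
        for row in hidden[start:end]:
            keys.append(matvec(w_k, row))
            values.append(matvec(w_v, row))
    # Nothing is returned: the cache is updated in place.


hidden = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # 3 token rows, hidden_size 2
w_k = [[1.0, 1.0]]  # projects each row to a 1-dim key
w_v = [[2.0, 0.0]]  # projects each row to a 1-dim value
offsets = [0, 2, 3]  # sequence 0 owns rows 0-1, sequence 1 owns row 2
cache: dict = {}
materialize_kv_toy(hidden, w_k, w_v, offsets, cache)
```

The point of the sketch is the data flow: the caller supplies hidden states from elsewhere (for example a target model), and the layer only contributes its K/V projections and the cache write.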
qkv_input_scale
property qkv_input_scale: TensorValue | None
The max of q, k, and v scale input vectors.
qkv_weight_scale
property qkv_weight_scale: TensorValue
The max of q, k, and v scale weight vectors.
qkv_weight_scale_2
property qkv_weight_scale_2: TensorValue | None
The max of q, k, and v scale input vectors.
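One plausible reading of "the max of q, k, and v" in the scale properties above is that a fused QKV projection needs a single shared scale, and taking the largest of the three per-projection scales keeps every quantized value in range. The sketch below illustrates that idea with scalar scales in plain Python. The function name and the handling of missing scales are assumptions, not the library's behavior.

```python
from typing import Optional


def fused_scale(
    q: Optional[float], k: Optional[float], v: Optional[float]
) -> Optional[float]:
    """Return the largest of the three scales, or None if any is missing."""
    # Assumption: a missing scale means no shared scale can be formed.
    if q is None or k is None or v is None:
        return None
    return max(q, k, v)


shared = fused_scale(0.02, 0.05, 0.03)
missing = fused_scale(None, 0.05, 0.03)
```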
rope
rope: RotaryEmbedding
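For readers unfamiliar with RoPE, the rotation it applies can be sketched in plain Python. This illustrates the general technique, not MAX's kernels: it assumes the interleaved pairing convention (elements 2j and 2j + 1 form a pair) and a base of 10000, and rope_freqs and apply_rope are hypothetical names. The complex values returned by rope_freqs play the role that freqs_cis plays in this class.

```python
import cmath


def rope_freqs(head_dim: int, position: int, theta: float = 10000.0) -> list[complex]:
    """One unit-magnitude complex rotation per pair of head dimensions."""
    return [
        cmath.exp(1j * position * theta ** (-2.0 * j / head_dim))
        for j in range(head_dim // 2)
    ]


def apply_rope(x: list[float], freqs: list[complex]) -> list[float]:
    """Rotate each (x[2j], x[2j + 1]) pair by the matching complex frequency."""
    out: list[float] = []
    for j, f in enumerate(freqs):
        rotated = complex(x[2 * j], x[2 * j + 1]) * f
        out.extend([rotated.real, rotated.imag])
    return out


q = [1.0, 0.0, 0.0, 1.0]
q_at_0 = apply_rope(q, rope_freqs(4, position=0))  # position 0 leaves q unchanged
q_at_5 = apply_rope(q, rope_freqs(4, position=5))  # rotated, same vector length
```

Because each pair is only rotated, the vector's length is preserved, and the dot product between a rotated query and a rotated key depends only on the difference between their positions.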
shard()
shard(devices)
Create sharded views across devices (tensor-parallel).
Returns one AttentionWithRope per device with appropriately sliced weights.
Parameters:
- devices (Iterable[DeviceRef])
Return type:
list[AttentionWithRope]
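A rough sketch of what "appropriately sliced" can mean for tensor parallelism is to give each device a contiguous block of attention heads. The helper below is hypothetical and assumes the heads divide evenly across devices. The actual slicing in MAX depends on the configured sharding_strategy.

```python
def head_ranges(num_heads: int, num_devices: int) -> list[range]:
    """Assign each device a contiguous, equally sized block of head indices."""
    # Assumption: an even split; uneven splits are not handled in this sketch.
    if num_heads % num_devices != 0:
        raise ValueError("num_heads must be divisible by num_devices")
    per_device = num_heads // num_devices
    return [range(d * per_device, (d + 1) * per_device) for d in range(num_devices)]


# Hypothetical configuration: 32 query heads and 8 key/value heads on 4 devices.
q_shards = head_ranges(32, 4)
kv_shards = head_ranges(8, 4)
```

Each device then holds only the rows of the Q, K, and V projections that belong to its heads, which is why one layer object per device is returned.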
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
wqkv_bias
property wqkv_bias: TensorValue | None
The concatenation of q, k, and v bias weight vectors.
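The shape of a concatenated QKV weight can be sketched with nested lists. This is an illustration under stated assumptions, not the library's layout guarantee: each projection weight is taken to be (out_features, in_features), Q has num_attention_heads * head_dim rows, K and V each have num_key_value_heads * head_dim rows, and the three are stacked along the output dimension in Q, K, V order. The filled helper and the sizes are hypothetical.

```python
def filled(rows: int, cols: int, value: float) -> list[list[float]]:
    """Build a rows x cols matrix filled with a constant, to make the blocks visible."""
    return [[value] * cols for _ in range(rows)]


hidden_size = 8
num_attention_heads = 4
num_key_value_heads = 2
head_dim = 2

w_q = filled(num_attention_heads * head_dim, hidden_size, 1.0)
w_k = filled(num_key_value_heads * head_dim, hidden_size, 2.0)
w_v = filled(num_key_value_heads * head_dim, hidden_size, 3.0)

# List concatenation stacks the row blocks: Q rows first, then K, then V.
wqkv = w_q + w_k + w_v
wqkv_shape = (len(wqkv), len(wqkv[0]))
```

A single matrix multiply against wqkv then produces Q, K, and V together, which can be split by row offsets afterwards.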