Python class

AttentionWithRope

class max.nn.attention.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, quant_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06, mask_variant=MHAMaskVariant.CAUSAL_MASK, sliding_window=None, _fuse_rope_and_store=True)

Bases: Module, Shardable

Implementation of attention that uses Rotary Position Embedding (RoPE).
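To illustrate what "uses RoPE" means, here is a minimal pure-Python sketch that rotates a single query or key head vector by its token position. It assumes the common "rotate-half" pairing (element i with element i + head_dim/2) and a base of 10000. `rope_frequencies` and `apply_rope` are hypothetical helpers, and MAX's `RotaryEmbedding` and fused kernels may pair dimensions or schedule frequencies differently.

```python
import math
from typing import List


def rope_frequencies(head_dim: int, theta: float = 10000.0) -> List[float]:
    """Per-pair rotation frequencies theta^(-2i/head_dim) (hypothetical helper)."""
    return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def apply_rope(x: List[float], position: int, theta: float = 10000.0) -> List[float]:
    """Rotate one head vector by its position ("rotate-half" layout, assumed)."""
    half = len(x) // 2
    out = list(x)
    for i, freq in enumerate(rope_frequencies(len(x), theta)):
        angle = position * freq
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        # 2-D rotation of the pair (a, b) by `angle`.
        out[i] = a * cos_a - b * sin_a
        out[i + half] = a * sin_a + b * cos_a
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.1, -0.4, 0.7, 0.2]
k = [0.5, 0.3, -0.2, 0.9]
# Both query/key pairs are 2 positions apart, so their scores match.
s1 = dot(apply_rope(q, 5), apply_rope(k, 3))
s2 = dot(apply_rope(q, 12), apply_rope(k, 10))
print(abs(s1 - s2) < 1e-9)  # True
```

The final check shows why RoPE suits attention: after rotation, the query-key dot product depends only on the distance between the two positions, not on their absolute values.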

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rotary embedding layer that supplies the freqs_cis values.
  • sharding_strategy (ShardingStrategy | None) – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the dtype.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple devices are provided, the first is used for weight placement.
  • dtype (DType) – DType of the QKV and output projection weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv (bool) – Whether the Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim), that is, 1/sqrt(head_dim).
  • has_bias (bool) – Whether the Q/K/V projections have biases. Biases are not supported when stacked_qkv is True.
  • quant_config (QuantConfig | None) – Optional quantization config (dynamic or static).
  • clip_qkv (float | None) – If provided, clamp the Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to apply RMSNorm to Q and K.
  • rms_norm_eps (float) – Epsilon added for numerical stability in RMSNorm.
  • mask_variant (MHAMaskVariant) – The attention mask variant; defaults to MHAMaskVariant.CAUSAL_MASK.
  • sliding_window (int | None) – Optional sliding-window size for local attention.
  • _fuse_rope_and_store (bool) – If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
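A hedged, plain-Python sketch of how a few of these arguments interact; it does not call MAX. `AttentionConfig`, `clip_weights`, and `rms_norm` are hypothetical names. `head_dim` is passed directly here (in the real layer it comes from kv_params), the even-grouping check for grouped-query attention is an assumption, and the RMSNorm omits the learned gain.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AttentionConfig:
    """Hypothetical stand-in for a few AttentionWithRope constructor arguments."""

    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int  # assumed given directly; the real layer reads it from kv_params
    stacked_qkv: bool = False
    has_bias: bool = False
    scale: Optional[float] = None

    def validate(self) -> None:
        # Assumption: grouped-query attention needs query heads to split evenly over KV heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
        # Documented constraint: stacked_qkv forbids bias.
        if self.stacked_qkv and self.has_bias:
            raise ValueError("has_bias=True is not supported with stacked_qkv=True")

    @property
    def effective_scale(self) -> float:
        # Documented default: sqrt(1 / head_dim), i.e. 1 / sqrt(head_dim).
        return self.scale if self.scale is not None else math.sqrt(1.0 / self.head_dim)

    @property
    def queries_per_kv_head(self) -> int:
        return self.num_attention_heads // self.num_key_value_heads


def clip_weights(w: List[float], clip_qkv: Optional[float]) -> List[float]:
    """Clamp weights to [-clip_qkv, clip_qkv]; a no-op when clip_qkv is None."""
    if clip_qkv is None:
        return list(w)
    return [min(max(x, -clip_qkv), clip_qkv) for x in w]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm without a learned gain, as it might be applied to Q/K when use_qk_norm=True."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


cfg = AttentionConfig(num_attention_heads=32, num_key_value_heads=8, head_dim=128)
cfg.validate()
print(cfg.queries_per_kv_head)                       # 4
print(clip_weights([-3.0, 0.5, 8.0], clip_qkv=2.0))  # [-2.0, 0.5, 2.0]
```

With 32 query heads and 8 KV heads, each KV head serves 4 query heads, and the default scale is about 0.0884 for head_dim=128.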

materialize_kv_from_hidden()

materialize_kv_from_hidden(layer_idx, hidden, kv_collection, freqs_cis, input_row_offsets)

Project hidden states to K/V and write them into the paged KV cache.

Used by speculative-decoding draft models that build their KV cache from external (e.g. target) hidden states.

Parameters:

  • layer_idx
  • hidden
  • kv_collection
  • freqs_cis
  • input_row_offsets

Return type:

None
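A toy model of the idea behind materialize_kv_from_hidden. It assumes a ragged batch in which sequence i owns rows input_row_offsets[i] to input_row_offsets[i + 1], and a cache made of fixed-size pages tracked per sequence by a block table. Everything here (`ToyPagedKVCache`, `materialize_kv_sketch`, `PAGE_SIZE`) is hypothetical: the real method operates on graph values and a kv_collection, and presumably applies RoPE to K using freqs_cis, which this sketch skips.

```python
from typing import Dict, List

PAGE_SIZE = 4  # tokens per page (assumed value for illustration)
Vec = List[float]


def matvec(w: List[Vec], x: Vec) -> Vec:
    """Compute w @ x for a row-major [out, in] weight matrix."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


class ToyPagedKVCache:
    """Hypothetical paged cache: each sequence maps to a list of fixed-size pages."""

    def __init__(self) -> None:
        self.pages: List[list] = []                  # physical pages of (k, v) slots
        self.block_table: Dict[int, List[int]] = {}  # sequence -> physical page ids
        self.lengths: Dict[int, int] = {}            # sequence -> tokens stored

    def append(self, seq: int, k: Vec, v: Vec) -> None:
        n = self.lengths.get(seq, 0)
        table = self.block_table.setdefault(seq, [])
        if n % PAGE_SIZE == 0:  # no page yet, or the last page is full
            self.pages.append([None] * PAGE_SIZE)
            table.append(len(self.pages) - 1)
        self.pages[table[n // PAGE_SIZE]][n % PAGE_SIZE] = (k, v)
        self.lengths[seq] = n + 1


def materialize_kv_sketch(hidden: List[Vec], input_row_offsets: List[int],
                          w_k: List[Vec], w_v: List[Vec], cache: ToyPagedKVCache) -> None:
    """Project each sequence's hidden rows to K/V and append them to its pages."""
    for seq in range(len(input_row_offsets) - 1):
        for row in range(input_row_offsets[seq], input_row_offsets[seq + 1]):
            cache.append(seq, matvec(w_k, hidden[row]), matvec(w_v, hidden[row]))


hidden = [[float(r), 1.0] for r in range(7)]  # 7 tokens, hidden size 2
offsets = [0, 5, 7]                           # seq 0: rows 0-4, seq 1: rows 5-6
w_k = [[1.0, 0.0], [0.0, 1.0]]
w_v = [[2.0, 0.0], [0.0, 2.0]]
cache = ToyPagedKVCache()
materialize_kv_sketch(hidden, offsets, w_k, w_v, cache)
print(cache.block_table)  # {0: [0, 1], 1: [2]}
```

Sequence 0 has five tokens, so it spills into a second page; sequence 1 fits in one page of its own.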

qkv_input_scale

property qkv_input_scale: TensorValue | None

The max of the q, k, and v input scale vectors.

qkv_weight_scale

property qkv_weight_scale: TensorValue

The max of the q, k, and v weight scale vectors.

qkv_weight_scale_2

property qkv_weight_scale_2: TensorValue | None

The max of the q, k, and v weight_scale_2 vectors.
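One plausible reading of why these scale properties take a max: a fused QKV matmul can apply only one scale, and the largest of the three per-tensor scales guarantees that no projection overflows the quantized range. The sketch assumes scalar per-tensor scales, the FP8 E4M3 format (largest finite value 448), and the common amax/448 scaling recipe; none of these details is confirmed by this page, and the helper names are hypothetical.

```python
import math
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3fn value


def per_tensor_scale(w: List[float]) -> float:
    """Common recipe (assumed): map the tensor's absolute max onto the FP8 range."""
    return max(abs(x) for x in w) / FP8_E4M3_MAX


def shared_qkv_scale(q_scale: float, k_scale: float, v_scale: float) -> float:
    """One scale for the fused QKV projection: the largest per-tensor scale."""
    return max(q_scale, k_scale, v_scale)


wq = [0.02, -0.9, 0.4]
wk = [1.5, -0.3, 0.1]
wv = [-0.05, 0.6, 0.2]
s = shared_qkv_scale(per_tensor_scale(wq), per_tensor_scale(wk), per_tensor_scale(wv))

# Dividing all three projections by the shared scale keeps every value in range;
# only the tensor with the largest amax (wk) reaches the limit.
fused = [x / s for x in wq + wk + wv]
peak = max(abs(x) for x in fused)
print(math.isclose(peak, FP8_E4M3_MAX))  # True
```

The trade-off is precision: after sharing, tensors with smaller scales (wq and wv here) use only part of the FP8 range.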

rope

rope: RotaryEmbedding

The rotary embedding layer that supplies the freqs_cis values.

shard()

shard(devices)

Create sharded views across devices (tensor-parallel).

Returns one AttentionWithRope per device with appropriately sliced weights.

Parameters:

devices (Iterable[DeviceRef])

Return type:

list[AttentionWithRope]
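This page does not spell out how the weights are sliced, so here is a minimal sketch of the usual tensor-parallel layout for attention, in which each device receives a contiguous block of whole query heads and KV heads. `shard_head_ranges` is a hypothetical helper that assumes heads divide evenly across devices. MAX may handle other cases differently (for example, when there are fewer KV heads than devices), and the output projection would need a matching split along its input dimension, which is not shown.

```python
from typing import Dict, List, Tuple


def shard_head_ranges(num_heads: int, num_kv_heads: int, head_dim: int,
                      num_devices: int) -> List[Dict[str, Tuple[int, int]]]:
    """Per-device [start, stop) output-row ranges within each of the Q, K, V weights.

    Assumes every device gets the same number of whole query heads and KV heads.
    """
    if num_heads % num_devices or num_kv_heads % num_devices:
        raise ValueError("heads must divide evenly across devices in this sketch")
    q_rows = (num_heads // num_devices) * head_dim
    kv_rows = (num_kv_heads // num_devices) * head_dim
    ranges = []
    for d in range(num_devices):
        kv = (d * kv_rows, (d + 1) * kv_rows)
        ranges.append({"q": (d * q_rows, (d + 1) * q_rows), "k": kv, "v": kv})
    return ranges


ranges = shard_head_ranges(num_heads=32, num_kv_heads=8, head_dim=128, num_devices=4)
print(ranges[1])  # {'q': (1024, 2048), 'k': (256, 512), 'v': (256, 512)}
```

Here each of the four devices holds 8 query heads and 2 KV heads, so grouped-query attention can run locally on every device without cross-device communication until the output projection.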

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of the q, k, and v bias vectors.
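To show what the fused wqkv and wqkv_bias values enable, the sketch below stacks toy Q, K, and V weights along the output dimension, runs one matmul, and splits the result. It assumes an [out, in] weight layout and Q-then-K-then-V order; `concat_qkv`, `matvec`, and `split_qkv` are hypothetical helpers, not MAX APIs.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def concat_qkv(wq: Matrix, wk: Matrix, wv: Matrix) -> Matrix:
    """Stack Q, K and V weights ([out, in] each) along the output dimension."""
    return wq + wk + wv


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def split_qkv(y: List[float], q_dim: int, kv_dim: int) -> Tuple[List[float], List[float], List[float]]:
    """Split a fused projection output back into q, k and v."""
    return y[:q_dim], y[q_dim:q_dim + kv_dim], y[q_dim + kv_dim:q_dim + 2 * kv_dim]


# Toy shapes: hidden=2, q_dim=4 (2 heads x head_dim 2), kv_dim=2 (1 KV head).
wq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
wk = [[0.5, 0.5], [1.0, -1.0]]
wv = [[3.0, 0.0], [0.0, 3.0]]
bq, bk, bv = [0.1] * 4, [0.2] * 2, [0.3] * 2

wqkv = concat_qkv(wq, wk, wv)  # shape [8, 2]
wqkv_bias = bq + bk + bv       # biases concatenated in the same order

x = [1.0, 2.0]
fused = [y + b for y, b in zip(matvec(wqkv, x), wqkv_bias)]
q, k, v = split_qkv(fused, q_dim=4, kv_dim=2)
```

Because each output row is computed independently, the single fused projection reproduces the three separate Q, K, and V projections, so one matmul can replace three.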