Python class

AttentionWithRope

class max.nn.attention.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=Linear, stacked_qkv=False, scale=None, has_bias=False, quant_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06, _fuse_rope_and_store=True)

Bases: Module, Shardable

Implementation of attention that uses Rotary Position Embedding (RoPE).

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The RotaryEmbedding layer that supplies the freqs_cis values.
  • sharding_strategy (ShardingStrategy | None) – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple devices are provided, the first is used for initial weight placement; use shard() to create per-device views.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv (bool) – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias (bool) – Whether the Q/K/V projections have a bias term (incompatible with stacked_qkv).
  • quant_config (QuantConfig | None) – Optional quantization config (dynamic or static).
  • clip_qkv (float | None) – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Value to use for numerical stability in RMSNorm.
  • _fuse_rope_and_store (bool) – If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
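
A minimal construction sketch follows. The AttentionWithRope keyword arguments match the signature above, but the RotaryEmbedding and KVCacheParams argument names and the exact import paths are assumptions for illustration; verify them against the API reference for your installed MAX version.

    from max.dtype import DType
    from max.graph import DeviceRef
    from max.nn import AttentionWithRope, RotaryEmbedding
    from max.nn.kv_cache import KVCacheParams

    # Llama-style decoder dimensions (illustrative values).
    hidden_size = 4096
    num_attention_heads = 32
    num_key_value_heads = 8
    head_dim = hidden_size // num_attention_heads  # 128

    # NOTE: the RotaryEmbedding and KVCacheParams argument names below are
    # assumptions; check the API reference for your MAX version.
    rope = RotaryEmbedding(
        dim=hidden_size,
        n_heads=num_attention_heads,
        theta=10000.0,
        max_seq_len=8192,
    )
    kv_params = KVCacheParams(
        dtype=DType.bfloat16,
        n_kv_heads=num_key_value_heads,
        head_dim=head_dim,
    )

    attn = AttentionWithRope(
        rope=rope,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        hidden_size=hidden_size,
        kv_params=kv_params,
        devices=[DeviceRef.GPU(0)],
        dtype=DType.bfloat16,
    )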

qkv_input_scale

property qkv_input_scale: TensorValue | None

The max of the q, k, and v input scale vectors.

qkv_weight_scale

property qkv_weight_scale: TensorValue

The max of the q, k, and v weight scale vectors.

qkv_weight_scale_2

property qkv_weight_scale_2: TensorValue | None

The max of the q, k, and v second-level weight scale vectors.

rope

rope: RotaryEmbedding

shard()

shard(devices)

Create sharded views across devices (tensor-parallel).

Returns one AttentionWithRope per device with appropriately sliced weights.

Parameters:

  • devices (Iterable[DeviceRef]) – The devices to create sharded views on.

Return type:

list[AttentionWithRope]
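
A hedged tensor-parallel sharding sketch, continuing from the construction example above. ShardingStrategy.tensor_parallel is an assumption about the factory name; confirm it for your MAX version.

    from max.graph import DeviceRef
    from max.nn import ShardingStrategy

    devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

    # Declare how the weights should be sliced, then materialize one
    # per-device view of the layer built above.
    attn.sharding_strategy = ShardingStrategy.tensor_parallel(len(devices))
    shards = attn.shard(devices)  # list[AttentionWithRope], one per device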

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

wqkv

property wqkv: TensorValue

The concatenation of the q, k, and v weight matrices.
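
As a worked shape example, under an assumption this page does not confirm (the conventional layout that stacks Q, K, and V along the output dimension), the fused weight for the configuration above would have the following shape:

    # Pure shape bookkeeping; no MAX APIs involved.
    num_attention_heads = 32
    num_key_value_heads = 8
    head_dim = 128
    hidden_size = 4096

    q_rows = num_attention_heads * head_dim    # 4096
    kv_rows = num_key_value_heads * head_dim   # 1024 each for K and V
    wqkv_rows = q_rows + 2 * kv_rows           # 6144
    # Expected wqkv shape: (wqkv_rows, hidden_size) == (6144, 4096)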

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of the q, k, and v bias vectors.