Python class

GPTQAttentionWithRope

class max.nn.attention.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.linear.Linear'>)

Bases: AttentionWithRope

Implementation of the GPTQ attention layer.

Parameters:

  • quantization_config (QuantizationConfig) – The GPTQ quantization configuration, including desc_act for activation-order permutation support (see the dequantization sketch after this list).
  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – The KV cache parameters, including number of KV heads, head dim, and dtype.
  • devices (list[DeviceRef] | None) – The device or devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • dtype (DType) – The DType for the output projection weights.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • linear_cls (Callable[..., Linear]) – The linear class to use for the output projection.
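
To make the packing and desc_act parameters concrete, the following sketch dequantizes a GPTQ-packed weight with per-group scales and a g_idx group map. The field names (qweight, scales, g_idx), the 4-bit packing, and the symmetric zero point are assumptions about the common GPTQ checkpoint layout, not part of this class's API.

    # Illustration of the GPTQ layout that quantization_config describes.
    # The qweight/scales/g_idx names, 4-bit packing, and zero point of 8
    # are assumptions about the common GPTQ checkpoint format.
    import numpy as np

    in_features, out_features = 8, 4
    bits, group_size = 4, 4

    rng = np.random.default_rng(0)
    # Eight 4-bit values are packed into each uint32 along the input dim.
    qweight = rng.integers(0, 2**32, size=(in_features * bits // 32, out_features), dtype=np.uint32)
    scales = rng.random((in_features // group_size, out_features), dtype=np.float32)
    # g_idx maps each input channel to its quantization group. With
    # desc_act=True the channels are permuted into activation order;
    # with desc_act=False the groups stay contiguous, as here.
    g_idx = np.arange(in_features) // group_size

    # Unpack the 4-bit values, then dequantize with per-group scales.
    shifts = np.arange(0, 32, bits, dtype=np.uint32)
    unpacked = (qweight[:, None, :] >> shifts[None, :, None]) & ((1 << bits) - 1)
    int_weight = unpacked.reshape(in_features, out_features)
    weight = (int_weight.astype(np.float32) - 8.0) * scales[g_idx]
    print(weight.shape)  # (in_features, out_features)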

Initializes the attention layer; a construction sketch follows the parameter list below.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • sharding_strategy – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – The KV cache parameters, including the number of KV heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias – Whether Q/K/V have bias (stacked_qkv forbids bias).
  • quant_config – Optional quantization config (dynamic or static).
  • clip_qkv – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm – Whether to use RMSNorm on Q/K.
  • rms_norm_eps – Value to use for numerical stability in RMSNorm.
  • _fuse_rope_and_store – If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
  • quantization_config (QuantizationConfig) – The GPTQ quantization configuration, including desc_act for activation-order permutation support.
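
For orientation, here is a minimal construction sketch built from the parameters documented above. Only GPTQAttentionWithRope's own keyword arguments come from this page; the import paths and the constructor fields of RotaryEmbedding, KVCacheParams, and QuantizationConfig are assumptions about typical MAX usage and may differ from the actual API.

    # Construction sketch; import paths and the helper constructors'
    # fields (RotaryEmbedding, KVCacheParams, QuantizationConfig) are
    # assumptions, not verified API.
    from max.dtype import DType
    from max.graph import DeviceRef
    from max.graph.quantization import QuantizationConfig
    from max.nn.attention import GPTQAttentionWithRope
    from max.nn.kv_cache import KVCacheParams
    from max.nn.rotary_embedding import RotaryEmbedding

    hidden_size = 4096
    num_attention_heads = 32
    num_key_value_heads = 8
    head_dim = hidden_size // num_attention_heads
    device = DeviceRef.GPU()

    rope = RotaryEmbedding(          # assumed constructor fields
        dim=hidden_size,
        n_heads=num_attention_heads,
        theta=10000.0,
        max_seq_len=4096,
        device=device,
    )
    kv_params = KVCacheParams(       # assumed constructor fields
        dtype=DType.bfloat16,
        n_kv_heads=num_key_value_heads,
        head_dim=head_dim,
    )
    quant_config = QuantizationConfig(  # assumed constructor fields
        quant_method="gptq",
        bits=4,
        group_size=128,
        desc_act=False,
        sym=True,
    )

    attention = GPTQAttentionWithRope(
        quantization_config=quant_config,
        rope=rope,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        hidden_size=hidden_size,
        kv_params=kv_params,
        devices=[device],
        dtype=DType.float32,   # dtype of the output projection weights
    )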

wqkv

property wqkv: TensorValue

The concatenation of the q, k, and v weights (GPTQ-packed values plus their scales).
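
As a layout illustration only, the "packed + scales" concatenation can be pictured with NumPy stand-ins as below; the array names, shapes, and concatenation axis are assumptions rather than the property's actual internals.

    # Layout illustration with NumPy stand-ins; names, shapes, and axis
    # are assumptions, not the actual implementation.
    import numpy as np

    q_packed, k_packed, v_packed = (np.zeros((512, 64), np.uint32) for _ in range(3))
    q_scales, k_scales, v_scales = (np.ones((32, 64), np.float32) for _ in range(3))

    packed = np.concatenate([q_packed, k_packed, v_packed], axis=0)
    scales = np.concatenate([q_scales, k_scales, v_scales], axis=0)
    # wqkv exposes the analogous fused q/k/v buffer as a single TensorValue.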