Python class
GPTQAttentionWithRope
class max.nn.attention.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.linear.Linear'>)
Bases: AttentionWithRope
Implementation of the GPTQ attention layer.
Parameters:
- quantization_config (QuantizationConfig) – The GPTQ quantization configuration, including desc_act for activation-order permutation support.
- rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – The number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – The KV cache parameters, including number of KV heads, head dim, and dtype.
- devices (list[DeviceRef] | None) – The device or devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- dtype (DType) – The DType for the output projection weights.
- scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
- linear_cls (Callable[..., Linear]) – The linear class to use for the output projection.
Initializes the attention layer.
Parameters:
- rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
- sharding_strategy – Optional initial sharding strategy.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
- dtype (DType) – DType of the QKV and output projection weights.
- devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
- linear_cls (Callable[..., Linear]) – Linear class to use for projections.
- stacked_qkv – Whether Q/K/V weights are stacked in a single Weight.
- scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
- has_bias – Whether Q/K/V have bias (stacked_qkv forbids bias).
- quant_config – Optional quantization config (dynamic or static).
- clip_qkv – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
- use_qk_norm – Whether to use RMSNorm on Q/K.
- rms_norm_eps – Value to use for numerical stability in RMSNorm.
- _fuse_rope_and_store – If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
- quantization_config (QuantizationConfig)
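For orientation, here is a minimal construction sketch based on the signature above. Only the GPTQAttentionWithRope keyword arguments are taken from the documentation on this page; the import paths and the RotaryEmbedding, KVCacheParams, and QuantizationConfig constructor arguments are illustrative assumptions that may differ between MAX releases.

```python
# Minimal construction sketch. Import paths and the helper-object
# constructor arguments below are assumptions; check your MAX release.
from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationConfig
from max.nn.attention import GPTQAttentionWithRope
from max.nn.kv_cache import KVCacheParams
from max.nn.rotary_embedding import RotaryEmbedding

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads

# Example GPTQ settings: 4-bit weights, group size 128, with desc_act
# enabling activation-order permutation support.
quant_config = QuantizationConfig(
    quant_method="gptq", bits=4, group_size=128, desc_act=True, sym=True
)

# The rope layer that GPTQAttentionWithRope borrows freqs_cis from.
rope = RotaryEmbedding(
    dim=hidden_size,
    n_heads=num_attention_heads,
    theta=10000.0,
    max_seq_len=4096,
)

# KV cache parameters: number of KV heads, head dim, and dtype.
kv_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=num_key_value_heads,
    head_dim=head_dim,
)

attention = GPTQAttentionWithRope(
    quantization_config=quant_config,
    rope=rope,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    hidden_size=hidden_size,
    kv_params=kv_params,
    devices=[DeviceRef.GPU(0)],
    dtype=DType.float32,  # dtype of the output projection weights
)
```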
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors (packed + scales).
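As a brief illustration, the sketch below assumes the layer was constructed as in the earlier example and that the property is read inside a graph-building context where its weights are available; outside such a context it may not be usable.

```python
# Hypothetical usage: wqkv yields a TensorValue holding the packed
# q/k/v GPTQ weights plus their scales.
qkv = attention.wqkv
print(qkv.shape, qkv.dtype)
```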