IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

GPTQAttentionWithRope


class max.nn.attention.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.linear.Linear'>, mask_variant=MHAMaskVariant.CAUSAL_MASK)


Bases: AttentionWithRope

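Implementation of the GPTQ attention layer: an AttentionWithRope variant whose Q, K, and V projection weights are stored in GPTQ-quantized form (packed weights plus scales).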

Parameters:

  • quantization_config (QuantizationConfig) – The GPTQ quantization configuration, including desc_act for activation-order permutation support.
  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – The KV cache parameters, including number of KV heads, head dim, and dtype.
  • devices (list[DeviceRef] | None) – The device or devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • dtype (DType) – The DType for the output projection weights.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • linear_cls (Callable[..., Linear]) – The linear class to use for the output projection.
  • mask_variant (MHAMaskVariant) – The attention mask variant. Defaults to MHAMaskVariant.CAUSAL_MASK.
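The `desc_act` option in `quantization_config` refers to GPTQ's activation-order mode, in which input channels are quantized in order of decreasing activation magnitude rather than in their natural order. The sketch below is a generic, pure-Python illustration of that idea, not MAX's implementation: the helpers `group_index` and `dequantize_row` are hypothetical, and the code assumes an asymmetric scheme with one scale and one zero point per group.

```python
def group_index(num_in_features, group_size, perm=None):
    """Map each input channel to its quantization group (hypothetical helper).

    Without activation ordering, channel i belongs to group i // group_size.
    With activation ordering, channels are quantized in the order given by
    `perm` (assumed: descending activation magnitude), so the k-th channel in
    that order belongs to group k // group_size.
    """
    order = perm if perm is not None else list(range(num_in_features))
    g_idx = [0] * num_in_features
    for k, channel in enumerate(order):
        g_idx[channel] = k // group_size
    return g_idx


def dequantize_row(q_row, scales, zeros, g_idx):
    """Dequantize one output row: w[i] = scales[g] * (q[i] - zeros[g]), with g = g_idx[i]."""
    return [scales[g_idx[i]] * (q - zeros[g_idx[i]]) for i, q in enumerate(q_row)]


# Natural order: channels 0-1 share group 0, channels 2-3 share group 1.
print(group_index(4, 2))                     # [0, 0, 1, 1]
# Activation order: channels 3 and 1 were quantized first, so they form group 0.
print(group_index(4, 2, perm=[3, 1, 0, 2]))  # [1, 0, 1, 0]
print(dequantize_row([8, 9, 7, 10], scales=[0.5, 2.0], zeros=[8, 8], g_idx=[0, 0, 1, 1]))
# [0.0, 0.5, -2.0, 4.0]
```

With activation ordering, a channel's group is no longer simply `i // group_size`, which is why GPTQ checkpoints typically carry a per-channel group map (often named `g_idx`) alongside the packed weights.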

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • sharding_strategy – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias – Whether Q/K/V have bias terms; bias is not supported when stacked_qkv is set.
  • quant_config – Optional quantization config (dynamic or static).
  • clip_qkv – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm – Whether to apply RMSNorm to Q and K.
  • rms_norm_eps – Epsilon added in RMSNorm for numerical stability.
  • _fuse_rope_and_store – If True (default), emit a single fused rope+split+store custom op. If False, emit separate rope, split, and store ops to test graph compiler fusion.
  • quantization_config (QuantizationConfig) – The GPTQ quantization configuration.
  • mask_variant (MHAMaskVariant) – The attention mask variant.
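Several of the parameters above correspond to simple numeric operations. The following sketch shows them in plain Python, assuming the usual definitions: the default `scale` of sqrt(1/head_dim), grouped-query attention in which each key/value head serves `num_attention_heads // num_key_value_heads` query heads, the `clip_qkv` clamp, and an RMSNorm of the form x / sqrt(mean(x^2) + eps) * weight for `use_qk_norm` with `rms_norm_eps`. All helper names are hypothetical; MAX performs these steps on graph tensors, not Python lists.

```python
import math


def default_attention_scale(head_dim: int) -> float:
    # Used when `scale` is None: sqrt(1 / head_dim).
    return math.sqrt(1.0 / head_dim)


def query_heads_per_kv_head(num_attention_heads: int, num_key_value_heads: int) -> int:
    # Grouped-query attention (assumed): each KV head is shared by this many query heads.
    if num_attention_heads % num_key_value_heads:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return num_attention_heads // num_key_value_heads


def clip_weights(weights: list[float], clip_qkv: float) -> list[float]:
    # Clamp each weight to [-clip_qkv, clip_qkv].
    return [max(-clip_qkv, min(clip_qkv, w)) for w in weights]


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # RMSNorm: x / sqrt(mean(x^2) + eps) * weight; eps guards against division by zero.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


print(default_attention_scale(64))                 # 0.125
print(query_heads_per_kv_head(32, 8))              # 4
print(clip_weights([-3.0, 0.5, 2.0], 1.0))         # [-1.0, 0.5, 1.0]
print(rms_norm([2.0, 2.0], [1.0, 1.0], eps=0.0))   # [1.0, 1.0]
```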

wqkv

property wqkv: TensorValue


The concatenation of the Q, K, and V weights (packed quantized weights plus scales).
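A hedged sketch of what "packed" commonly means for GPTQ weights: eight 4-bit codes stored per 32-bit word, with the fused QKV weight stacking the Q, K, and V output channels. This assumes 4-bit weights, lowest-nibble-first packing, and a row-per-output-channel layout; the helpers are hypothetical, and MAX's actual layout (including how scales are stored alongside the packed data) may differ.

```python
def pack_int4(codes: list[int]) -> list[int]:
    """Pack 4-bit codes (0-15) into 32-bit words, eight per word, lowest nibble first."""
    if len(codes) % 8:
        raise ValueError("number of codes must be a multiple of 8")
    words = []
    for start in range(0, len(codes), 8):
        word = 0
        for slot, code in enumerate(codes[start:start + 8]):
            if not 0 <= code <= 15:
                raise ValueError("codes must fit in 4 bits")
            word |= code << (4 * slot)
        words.append(word)
    return words


def unpack_int4(words: list[int]) -> list[int]:
    """Inverse of pack_int4: recover eight 4-bit codes from each 32-bit word."""
    return [(word >> (4 * slot)) & 0xF for word in words for slot in range(8)]


def fused_qkv_out_features(num_attention_heads: int, num_key_value_heads: int, head_dim: int) -> int:
    # Q contributes num_attention_heads * head_dim output channels;
    # K and V each contribute num_key_value_heads * head_dim.
    return (num_attention_heads + 2 * num_key_value_heads) * head_dim


def concat_qkv(q_rows: list, k_rows: list, v_rows: list) -> list:
    # Stack per-output-channel rows: all Q rows, then K rows, then V rows.
    return q_rows + k_rows + v_rows


codes = [1, 2, 3, 4, 5, 6, 7, 15]
packed = pack_int4(codes)
print(hex(packed[0]))                       # 0xf7654321
print(unpack_int4(packed) == codes)         # True
print(fused_qkv_out_features(32, 8, 128))   # 6144
```

Because this layout packs codes along the input dimension, it requires the number of input features per row to be a multiple of eight.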