Python class

TensorParallelAttentionWithRope

class max.nn.attention.TensorParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, quant_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Bases: AttentionWithRope, DistributedAttentionImpl

Tensor-parallel wrapper that delegates sharding to the base module.

Initializes the distributed (tensor-parallel) attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rotary embedding layer from which to borrow the freqs_cis value.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the dtype.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. Must provide at least 2 devices for tensor parallel attention.
  • dtype (DType) – DType of the QKV and output projection weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • stacked_qkv (bool) – Whether the Q, K, and V weights are stacked into a single weight.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • quant_config (QuantConfig | None) – Quantization configuration.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped to [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Epsilon used for numerical stability in RMSNorm.
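
A minimal construction sketch follows. Only the TensorParallelAttentionWithRope keyword arguments come from the signature above; the import paths, the RotaryEmbedding, KVCacheParams, and DeviceRef constructor arguments, and all hyperparameter values are illustrative assumptions that should be verified against your installed MAX version.

```python
# Minimal sketch, not a definitive example: import paths and the
# RotaryEmbedding / KVCacheParams / DeviceRef signatures below are
# assumptions; check them against your MAX version.
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention import TensorParallelAttentionWithRope
from max.nn.kv_cache import KVCacheParams
from max.nn.rotary_embedding import RotaryEmbedding

# Tensor-parallel attention shards across devices, so at least two
# devices must be provided (see the `devices` parameter above).
devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# Rotary embedding layer that supplies freqs_cis (arguments assumed).
rope = RotaryEmbedding(
    dim=4096,          # hidden size
    n_heads=32,        # number of attention heads
    theta=10000.0,     # RoPE base frequency
    max_seq_len=4096,
)

# KV cache parameters: number of KV heads, head dim, and dtype
# (field names assumed).
kv_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
)

# Keyword arguments here match the documented signature.
attention = TensorParallelAttentionWithRope(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,   # grouped-query attention: 32 Q heads share 8 KV heads
    hidden_size=4096,
    kv_params=kv_params,
    devices=devices,
    dtype=DType.bfloat16,
    scale=128 ** -0.5,       # 1/sqrt(head_dim), the conventional attention scaling
)
```

Because the class delegates sharding to its AttentionWithRope base, construction mirrors the single-device layer; the main additional requirement is the multi-device devices list.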