IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

TensorParallelAttentionWithRope


class max.nn.attention.TensorParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, quant_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06, mask_variant=MHAMaskVariant.CAUSAL_MASK, sliding_window=None)


Bases: AttentionWithRope, DistributedAttentionImpl

Tensor-parallel wrapper that delegates sharding to the base module.

Initializes the distributed (tensor parallel) attention layer.
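This page does not say exactly how the heads are split across devices. As an illustration only, the sketch below shows one common tensor-parallel layout. It assumes that query heads and KV heads are each divided into equal, contiguous blocks, one block per device. `shard_attention_heads` and `HeadShard` are hypothetical names and are not part of `max.nn`. The real class may replicate KV heads or shard in a different way.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HeadShard:
    """Hypothetical record of which heads one device owns."""
    device_index: int
    q_heads: range
    kv_heads: range


def shard_attention_heads(
    num_attention_heads: int, num_key_value_heads: int, num_devices: int
) -> list[HeadShard]:
    """Split Q and KV heads into equal contiguous blocks, one per device.

    Assumption for illustration: both head counts divide evenly by num_devices.
    """
    # The page states that tensor parallel attention needs at least 2 devices.
    if num_devices < 2:
        raise ValueError("tensor parallel attention needs at least 2 devices")
    if num_attention_heads % num_devices or num_key_value_heads % num_devices:
        raise ValueError("head counts must be divisible by the number of devices")
    q_per_device = num_attention_heads // num_devices
    kv_per_device = num_key_value_heads // num_devices
    return [
        HeadShard(
            device_index=i,
            q_heads=range(i * q_per_device, (i + 1) * q_per_device),
            kv_heads=range(i * kv_per_device, (i + 1) * kv_per_device),
        )
        for i in range(num_devices)
    ]
```

With 32 query heads, 8 KV heads, and 4 devices, device 1 would own query heads 8 through 15 and KV heads 2 and 3. In a typical tensor-parallel attention layer, each device computes attention only for its own heads. The partial outputs of the output projection are then summed across devices with an all-reduce. This page does not say whether this class follows exactly that scheme.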

Parameters:

  • rope (RotaryEmbedding) – The rotary embedding (RoPE) layer from which the freqs_cis value is borrowed.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the dtype.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. Must provide at least 2 devices for tensor parallel attention.
  • dtype (DType) – DType of the QKV and output projection weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • stacked_qkv (bool) – Whether the Q, K, and V weights are stacked together.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • quant_config (QuantConfig | None) – Quantization configuration.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Epsilon value added for numerical stability in RMSNorm.
  • mask_variant (MHAMaskVariant) – Attention mask used by the flash-attention kernel. Defaults to MHAMaskVariant.CAUSAL_MASK.
  • sliding_window (int | None)

materialize_kv_from_hidden()

materialize_kv_from_hidden(layer_idx, hiddens, kv_collections, freqs_cis, input_row_offsets)


Projects the hidden states to K/V and writes them into the paged KV cache.

Used by speculative-decoding draft models that build their KV cache from external (e.g. target) hidden states.
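To illustrate what "project to K/V and write into a paged cache" means, here is a toy, pure-Python sketch. It is not the MAX implementation. `ToyPagedKVCache`, `matvec`, `materialize_kv_sketch`, and the `wk`/`wv` weight arguments are all hypothetical. The sketch assumes a ragged batch in which `input_row_offsets[i]` to `input_row_offsets[i + 1]` are the rows of sequence `i`. It also assumes a single layer's cache, which stands in for what `layer_idx` presumably selects. The sketch skips rotary embedding, although the real method's `freqs_cis` argument suggests it is applied to K.

```python
def matvec(weight: list[list[float]], x: list[float]) -> list[float]:
    """Multiply an (out_dim x in_dim) matrix, stored as rows, by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


class ToyPagedKVCache:
    """Hypothetical paged cache: fixed-size physical pages plus per-sequence block tables."""

    def __init__(self, num_pages: int, page_size: int = 4) -> None:
        self.page_size = page_size
        self.k_pages = [[None] * page_size for _ in range(num_pages)]
        self.v_pages = [[None] * page_size for _ in range(num_pages)]
        self.free_pages = list(range(num_pages))
        self.block_tables: dict[int, list[int]] = {}  # seq_id -> physical page ids
        self.lengths: dict[int, int] = {}  # seq_id -> tokens already cached

    def append(self, seq_id: int, k: list[float], v: list[float]) -> None:
        pos = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        # Allocate a new physical page when the sequence's last page is full.
        if pos // self.page_size == len(table):
            table.append(self.free_pages.pop(0))
        page = table[pos // self.page_size]
        slot = pos % self.page_size
        self.k_pages[page][slot] = k
        self.v_pages[page][slot] = v
        self.lengths[seq_id] = pos + 1


def materialize_kv_sketch(cache, hiddens, input_row_offsets, wk, wv) -> None:
    """Project each hidden row to K and V and append it to its sequence's pages."""
    for seq_id in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[seq_id], input_row_offsets[seq_id + 1]
        for row in hiddens[start:end]:
            cache.append(seq_id, matvec(wk, row), matvec(wv, row))
```

The paged layout means a sequence's KV entries do not need to be contiguous in memory. Each sequence's block table records which physical pages hold its tokens, in order.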

Parameters:

Return type:

None