
Python module

attention_with_rope

An opaque KV-cache-optimized attention mechanism with RoPE.

AttentionWithRope

class max.nn.legacy.attention.attention_with_rope.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Implementation of attention that uses Rotary Position Embedding (RoPE).
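
For orientation, here is a minimal construction sketch. The rope, kv_params, and device objects are placeholders for a RotaryEmbedding, KV cache parameters, and a DeviceRef that your model code builds elsewhere; the head counts and hidden size are illustrative values, not defaults, and the import paths follow the module path shown above but should be checked against your MAX version.

```python
from max.dtype import DType
from max.nn.legacy.attention.attention_with_rope import AttentionWithRope

rope = ...        # placeholder: a RotaryEmbedding built from your model config
kv_params = ...   # placeholder: KV cache parameters for the opaque cache
device = ...      # placeholder: a DeviceRef for the target accelerator

attention = AttentionWithRope(
    rope=rope,
    num_attention_heads=32,   # illustrative sizes
    num_key_value_heads=8,    # grouped-query attention: fewer KV heads than query heads
    hidden_size=4096,
    kv_params=kv_params,
    devices=[device],
    dtype=DType.bfloat16,
    has_bias=False,
)
```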

Parameters:

qkv_input_scale

property qkv_input_scale: TensorValue | None

The max of the q, k, and v input scale vectors.

qkv_weight_scale

property qkv_weight_scale: TensorValue

The max of the q, k, and v weight scale vectors.

qkv_weight_scale_2

property qkv_weight_scale_2: TensorValue | None

The max of the q, k, and v secondary weight scale vectors.

rope

rope: RotaryEmbedding

shard()

shard(devices)

Create sharded views across devices (tensor-parallel).

Returns one AttentionWithRope per device with appropriately sliced weights.

Parameters:

devices (Iterable[DeviceRef])

Return type:

list[AttentionWithRope]
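
A hedged usage sketch for shard(), reusing the attention instance from the construction example above. It assumes ShardingStrategy is importable from max.nn and provides a tensor_parallel constructor, and that DeviceRef comes from max.graph; verify both against your MAX version.

```python
from max.graph import DeviceRef
from max.nn import ShardingStrategy  # assumed import path

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# A sharding strategy must be set before shard() can slice the weights.
attention.sharding_strategy = ShardingStrategy.tensor_parallel(len(devices))

# One AttentionWithRope view per device, each holding its slice of the
# fused QKV and output-projection weights.
shards = attention.shard(devices)
assert len(shards) == len(devices)
```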

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of the q, k, and v bias vectors.
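
The sketch below is a plain NumPy illustration of what the fused wqkv weight represents, not the MAX implementation: the q, k, and v projection weights are stacked along the output dimension so a single matmul produces all three projections, which are then split back apart.

```python
import numpy as np

hidden_size, head_dim = 4096, 128
n_heads, n_kv_heads = 32, 8  # grouped-query attention

wq = np.random.randn(n_heads * head_dim, hidden_size).astype(np.float32)
wk = np.random.randn(n_kv_heads * head_dim, hidden_size).astype(np.float32)
wv = np.random.randn(n_kv_heads * head_dim, hidden_size).astype(np.float32)

# Fused QKV weight: shape ((n_heads + 2 * n_kv_heads) * head_dim, hidden_size).
wqkv = np.concatenate([wq, wk, wv], axis=0)

x = np.random.randn(10, hidden_size).astype(np.float32)  # 10 tokens
qkv = x @ wqkv.T                                          # one fused projection
q, k, v = np.split(
    qkv, [n_heads * head_dim, (n_heads + n_kv_heads) * head_dim], axis=-1
)
```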

AttentionWithRopeNoOpaque

class max.nn.legacy.attention.attention_with_rope.AttentionWithRopeNoOpaque(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, scale=None)

Attention with RoPE without opaque KV cache.

Assumes:
• no float8
• no stacked qkv
• no bias
• no clip_qkv
• no float8_config

Parameters:

rope

rope: RotaryEmbedding
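
A minimal construction sketch for the non-opaque variant, reusing the rope and kv_params placeholders from the AttentionWithRope example above; the sizes remain illustrative.

```python
from max.dtype import DType
from max.nn.legacy.attention.attention_with_rope import AttentionWithRopeNoOpaque

attention_no_opaque = AttentionWithRopeNoOpaque(
    rope=rope,             # RotaryEmbedding placeholder, as above
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,   # KV cache parameters placeholder, as above
    dtype=DType.bfloat16,
)
```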

DataParallelAttentionWithRope

class max.nn.legacy.attention.attention_with_rope.DataParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Data-parallel implementation of Attention with RoPE.

This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective ops are required; the KV cache remains local to each device.

Notes:

• Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0.

Parameters:
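
The note above is bookkeeping on the caller's side. The NumPy sketch below (not a MAX API) shows one way to split a ragged batch's row offsets across two devices and rebase each slice so it starts at 0.

```python
import numpy as np

# Ragged batch of 4 sequences with lengths 3, 5, 2, 4 -> global row offsets.
input_row_offsets = np.array([0, 3, 8, 10, 14], dtype=np.uint32)

# Sequences 0-1 go to device 0, sequences 2-3 go to device 1.
split = 2
offsets_dev0 = input_row_offsets[: split + 1]                         # [0, 3, 8]
offsets_dev1 = input_row_offsets[split:] - input_row_offsets[split]   # [0, 2, 6]

# xs, kv_collections, and freqs_cis must be sliced the same way so that
# index i lines up with device i.
```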

GGUFQAttentionWithRope

class max.nn.legacy.attention.attention_with_rope.GGUFQAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, quantization_encoding, devices=None, linear_cls=<class 'max.nn.legacy.linear.Linear'>, scale=None, has_bias=False, clip_qkv=None)

Implementation of attention with GGUF quantized weights.

Parameters:

rope

rope: RotaryEmbedding

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of the q, k, and v bias vectors.
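
A hedged construction sketch for the GGUF-quantized variant, reusing the placeholders from the examples above. It assumes quantization_encoding accepts a max.graph.quantization.QuantizationEncoding value such as Q4_K; check the encodings your GGUF checkpoint actually uses.

```python
from max.dtype import DType
from max.graph.quantization import QuantizationEncoding  # assumed import path
from max.nn.legacy.attention.attention_with_rope import GGUFQAttentionWithRope

gguf_attention = GGUFQAttentionWithRope(
    rope=rope,                 # RotaryEmbedding placeholder, as above
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,       # KV cache parameters placeholder, as above
    dtype=DType.float32,       # activation dtype
    quantization_encoding=QuantizationEncoding.Q4_K,  # assumed GGUF weight encoding
)
```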

GPTQAttentionWithRope

class max.nn.legacy.attention.attention_with_rope.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.legacy.linear.Linear'>)

Implementation of the GPTQ attention layer.

Parameters:

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors (packed + scales).

TensorParallelAttentionWithRope

class max.nn.legacy.attention.attention_with_rope.TensorParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Tensor-parallel wrapper that delegates sharding to the base module.

Parameters:
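
A hedged sketch of the tensor-parallel wrapper, again reusing the placeholders from the examples above: it takes the same constructor arguments as AttentionWithRope plus a device list, and relies on the base module's shard() to split its weights across those devices.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.attention.attention_with_rope import TensorParallelAttentionWithRope

tp_attention = TensorParallelAttentionWithRope(
    rope=rope,               # RotaryEmbedding placeholder, as above
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,     # KV cache parameters placeholder, as above
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    dtype=DType.bfloat16,
)
```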
