
Python module

multi_latent_attention

An attention mechanism with RoPE that is optimized for the opaque KV cache.

DataParallelLatentAttentionWithRope

class max.nn.attention.multi_latent_attention.DataParallelLatentAttentionWithRope(**kwargs)

Data-parallel implementation of Latent Attention with RoPE.

This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective ops are required; KV-cache remains local to each device.

Notes:

  • signal_buffers is accepted for interface parity with the distributed implementation but is not used here.
  • Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0.
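The second note can be made concrete with a short sketch. It assumes input_row_offsets follows the usual ragged-batch convention: length batch_size + 1, with sequence b owning the flattened tokens [offsets[b], offsets[b + 1]). The helper shard_ragged_batch is hypothetical plain Python, not a MAX API, and its policy of giving each device a contiguous group of whole sequences is only one possible choice.

```python
def shard_ragged_batch(row_offsets, num_devices):
    """Assign whole sequences to devices and rebase each device's offsets to 0.

    row_offsets has length batch_size + 1; sequence b owns the flattened tokens
    [row_offsets[b], row_offsets[b + 1]). Sequences go to devices in contiguous,
    nearly equal groups (a hypothetical policy chosen for illustration).
    Returns one (token_start, token_end, local_offsets) tuple per device.
    """
    batch_size = len(row_offsets) - 1
    base, extra = divmod(batch_size, num_devices)
    shards = []
    seq_start = 0
    for device in range(num_devices):
        # The first `extra` devices take one additional sequence each.
        seq_end = seq_start + base + (1 if device < extra else 0)
        global_offsets = row_offsets[seq_start : seq_end + 1]
        token_start, token_end = global_offsets[0], global_offsets[-1]
        # Rebase so that local_offsets[0] == 0 on every device.
        local_offsets = [o - token_start for o in global_offsets]
        shards.append((token_start, token_end, local_offsets))
        seq_start = seq_end
    return shards


# Four sequences of lengths 3, 2, 4 and 1 split across two devices.
for device, shard in enumerate(shard_ragged_batch([0, 3, 5, 9, 10], 2)):
    print(device, shard)
# 0 (0, 5, [0, 3, 5])
# 1 (5, 10, [0, 4, 5])
```

Each device would then receive x[token_start:token_end] together with its local offsets, so no replica needs to know where its sequences sat in the global batch.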

Initializes the latent attention layer.

Parameters:

  • rope – The RoPE layer from which the freqs_cis value is taken.
  • num_attention_heads – The number of attention heads.
  • num_key_value_heads – The number of key/value heads.
  • hidden_size – The dimension of the hidden states.
  • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype – DType of the weights; currently only bfloat16 is supported.
  • devices – Devices on which to place the weights and run the computation. If multiple devices are provided, the first one is used.
  • linear_cls – Linear class to use for the output dense layer.
  • scale – Value used to scale the results of the attention output.
  • q_lora_rank – Optional LoRA rank for the Q projection.
  • kv_lora_rank – LoRA rank for the KV projections.
  • qk_nope_head_dim – Head dimension for the non-positional-encoding part.
  • qk_rope_head_dim – Head dimension for the RoPE part.
  • v_head_dim – Head dimension for the value.
  • buffer_size – Size, in tokens, of the buffer that stores temporary results during prefill.

DistributedLatentAttentionWithRope

max.nn.attention.multi_latent_attention.DistributedLatentAttentionWithRope

alias of TensorParallelLatentAttentionWithRope

LatentAttentionWithRope

class max.nn.attention.multi_latent_attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384)

Implementation of Latent Attention with RoPE.

Initializes the latent attention layer.

Parameters:

  • rope (RotaryEmbedding) – The RoPE layer from which the freqs_cis value is taken.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype (DType) – DType of the weights; currently only bfloat16 is supported.
  • devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple devices are provided, the first one is used.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • scale (float | None) – Value used to scale the results of the attention output.
  • q_lora_rank (int | None) – Optional LoRA rank for the Q projection.
  • kv_lora_rank (int) – LoRA rank for the KV projections.
  • qk_nope_head_dim (int) – Head dimension for the non-positional-encoding part.
  • qk_rope_head_dim (int) – Head dimension for the RoPE part.
  • v_head_dim (int) – Head dimension for the value.
  • buffer_size (int) – Size, in tokens, of the buffer that stores temporary results during prefill.
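The head-dimension parameters combine as shown below. This sketch assumes the DeepSeek-V2 style MLA layout, in which each query/key head concatenates a non-positional part and a RoPE part, and the KV cache stores one compressed latent plus one shared RoPE key per token. MLADims is a hypothetical helper, not part of max.nn, and the head count of 16 is arbitrary; the other defaults mirror the constructor signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADims:
    """Hypothetical helper mirroring the LatentAttentionWithRope defaults."""

    num_attention_heads: int
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128

    @property
    def qk_head_dim(self) -> int:
        # Each query/key head = non-positional part + RoPE part.
        return self.qk_nope_head_dim + self.qk_rope_head_dim

    @property
    def latent_cache_width(self) -> int:
        # Assumed cache entry per token: compressed KV latent plus one RoPE key
        # shared by all heads, so it does not grow with the head count.
        return self.kv_lora_rank + self.qk_rope_head_dim

    @property
    def mha_cache_width(self) -> int:
        # Width per token if full per-head keys and values were cached instead.
        return self.num_attention_heads * (self.qk_head_dim + self.v_head_dim)


dims = MLADims(num_attention_heads=16)  # hypothetical head count
print(dims.qk_head_dim, dims.latent_cache_width, dims.mha_cache_width)
# 192 576 5120
```

Under these assumptions each cached token takes 576 elements per layer (1,152 bytes in bfloat16) regardless of the number of heads, whereas caching full keys and values for 16 heads would take 5,120 elements.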

rope

rope: RotaryEmbedding
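The rope attribute holds the RotaryEmbedding whose freqs_cis the layer uses, and qk_rope_head_dim sizes the part of each query and key head that receives the rotation. The plain-Python sketch below shows the standard RoPE rotation on such a slice. It assumes the interleaved-pair convention and a base of theta = 10000; rope_factors and apply_rope are hypothetical helpers, and RotaryEmbedding may use a different layout or base frequency.

```python
import cmath
import math


def rope_factors(dim, position, theta=10000.0):
    """Complex rotation factors (freqs_cis) for one position, one per pair of dims."""
    return [cmath.exp(1j * position * theta ** (-2 * i / dim)) for i in range(dim // 2)]


def apply_rope(x, position, theta=10000.0):
    """Rotate consecutive (x[2i], x[2i + 1]) pairs of x by the factors for position."""
    if len(x) % 2:
        raise ValueError("RoPE dimension must be even")
    out = []
    for i, factor in enumerate(rope_factors(len(x), position, theta)):
        z = complex(x[2 * i], x[2 * i + 1]) * factor
        out.extend((z.real, z.imag))
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q = [0.1 * i for i in range(64)]         # stand-in for a query's RoPE slice
k = [0.05 * (64 - i) for i in range(64)]  # stand-in for the shared RoPE key
# The score depends only on the relative offset (2 in both cases).
print(math.isclose(dot(apply_rope(q, 7), apply_rope(k, 5)),
                   dot(apply_rope(q, 12), apply_rope(k, 10)), rel_tol=1e-9))
# True
```

Because the rotation encodes only relative position in the dot product, the cached RoPE key can be rotated once when it is written and reused for every later query.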

shard()

shard(devices)

Creates sharded views of this Module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LatentAttentionWithRope instances, one for each device.

Return type:

list[LatentAttentionWithRope]
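shard() returns one instance per device, and how the weights are split is governed by the module's sharding strategy, which this page does not describe. The sketch below only illustrates one common tensor-parallel scheme under stated assumptions: attention heads are divided into equal contiguous ranges, and the query projection stores qk_head_dim (qk_nope_head_dim + qk_rope_head_dim, 192 by default) output rows per head, head by head. head_shards is a hypothetical helper.

```python
def head_shards(num_attention_heads, num_devices, qk_head_dim=192):
    """Return (head_range, q_proj_row_range) for each device, both [start, end).

    Assumes heads are split into equal contiguous ranges and that the query
    projection stores qk_head_dim output rows per head, head by head.
    """
    if num_attention_heads % num_devices != 0:
        raise ValueError("num_attention_heads must be divisible by num_devices")
    heads_per_device = num_attention_heads // num_devices
    shards = []
    for device in range(num_devices):
        first_head = device * heads_per_device
        end_head = first_head + heads_per_device  # exclusive
        rows = (first_head * qk_head_dim, end_head * qk_head_dim)
        shards.append(((first_head, end_head), rows))
    return shards


for device, shard in enumerate(head_shards(16, 4)):
    print(device, shard)
# 0 ((0, 4), (0, 768))
# 1 ((4, 8), (768, 1536))
# 2 ((8, 12), (1536, 2304))
# 3 ((12, 16), (2304, 3072))
```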

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Gets the module's sharding strategy.

w_uk_uv

property w_uk_uv: list[TensorValue]

The concatenation of q, k, and v weight vectors.
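The property name refers to W_UK and W_UV, the key and value up-projections from the compressed latent in MLA. In DeepSeek-V2 style checkpoints these are stored as one stacked matrix of shape [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] and split per head. The sketch below shows that split on plain nested lists; the stacked layout is an assumption about the checkpoint format, split_kv_up_projection is a hypothetical helper, and this page does not state which tensors w_uk_uv returns beyond its description.

```python
def split_kv_up_projection(weight, num_heads, qk_nope_head_dim, v_head_dim):
    """Split a stacked KV up-projection into per-head W_UK and W_UV blocks.

    weight is a list of rows with shape
    [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]; each head's
    rows are assumed to be laid out as [key rows..., value rows...].
    """
    rows_per_head = qk_nope_head_dim + v_head_dim
    if len(weight) != num_heads * rows_per_head:
        raise ValueError("unexpected number of rows")
    w_uk, w_uv = [], []
    for h in range(num_heads):
        block = weight[h * rows_per_head : (h + 1) * rows_per_head]
        w_uk.append(block[:qk_nope_head_dim])  # [qk_nope_head_dim, kv_lora_rank]
        w_uv.append(block[qk_nope_head_dim:])  # [v_head_dim, kv_lora_rank]
    return w_uk, w_uv


# Toy sizes: 2 heads, qk_nope_head_dim=2, v_head_dim=1, kv_lora_rank=3.
weight = [[10 * r + c for c in range(3)] for r in range(6)]
w_uk, w_uv = split_kv_up_projection(weight, num_heads=2, qk_nope_head_dim=2, v_head_dim=1)
print(w_uk[1])  # [[30, 31, 32], [40, 41, 42]]
print(w_uv[1])  # [[50, 51, 52]]
```

One common reason to keep the two blocks separate is weight absorption: W_UK can be folded into the query side and W_UV into the output side, so attention runs against the cached latent without reconstructing per-head keys and values.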

TensorParallelLatentAttentionWithRope

class max.nn.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope(**kwargs)

Distributed tensor-parallel implementation of Latent Attention with RoPE. Note that tensor parallelism duplicates the MLA KV cache on every device, which is memory-inefficient.
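The cost of that duplication can be estimated with simple arithmetic. This sketch assumes each cached token takes kv_lora_rank + qk_rope_head_dim = 576 elements per layer in bfloat16 (the DeepSeek-V2 style layout with default sizes), and that data parallelism splits tokens evenly across devices. The function and the workload numbers are hypothetical.

```python
def kv_cache_mib_per_device(total_tokens, num_devices, num_layers,
                            tensor_parallel, cache_width=576, dtype_bytes=2):
    """Rough per-device MLA KV-cache size in MiB.

    Under tensor parallelism each device caches every token, because the single
    latent entry per token is shared by all heads and so is not divided when
    heads are divided. Under data parallelism each device caches only its own
    sequences; an even split is assumed here.
    """
    if tensor_parallel:
        tokens_on_device = total_tokens
    else:
        tokens_on_device = -(-total_tokens // num_devices)  # ceiling division
    return tokens_on_device * num_layers * cache_width * dtype_bytes / 2**20


# Hypothetical workload: 65,536 cached tokens, 8 devices, 4 layers.
print(kv_cache_mib_per_device(65536, 8, 4, tensor_parallel=True))   # 288.0
print(kv_cache_mib_per_device(65536, 8, 4, tensor_parallel=False))  # 36.0
```

The eightfold gap equals the device count: data parallelism avoids it by keeping each sequence's cache on a single device, which is the trade-off DataParallelLatentAttentionWithRope makes.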

Initializes the latent attention layer.

Parameters:

  • rope – The RoPE layer from which the freqs_cis value is taken.
  • num_attention_heads – The number of attention heads.
  • num_key_value_heads – The number of key/value heads.
  • hidden_size – The dimension of the hidden states.
  • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype – DType of the weights; currently only bfloat16 is supported.
  • devices – Devices on which to place the weights and run the computation. If multiple devices are provided, the first one is used.
  • linear_cls – Linear class to use for the output dense layer.
  • scale – Value used to scale the results of the attention output.
  • q_lora_rank – Optional LoRA rank for the Q projection.
  • kv_lora_rank – LoRA rank for the KV projections.
  • qk_nope_head_dim – Head dimension for the non-positional-encoding part.
  • qk_rope_head_dim – Head dimension for the RoPE part.
  • v_head_dim – Head dimension for the value.
  • buffer_size – Size, in tokens, of the buffer that stores temporary results during prefill.
