Python module
multi_latent_attention
An opaque, KV-cache-optimized attention mechanism with RoPE.
DataParallelLatentAttentionWithRope
class max.nn.attention.multi_latent_attention.DataParallelLatentAttentionWithRope(**kwargs)
Data-parallel implementation of Latent Attention with RoPE.
This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective ops are required; the KV cache remains local to each device.
Notes:
- signal_buffers is accepted for interface parity with the distributed implementation but is not used here.
- Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0.
Initializes the latent attention layer.
Parameters:
- rope – The RoPE layer from which to borrow the freqs_cis value.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – Number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype – DType of the weights, currently only bfloat16 is supported.
- devices – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls – Linear class to use for the output dense layer.
- scale – Value used to scale the results of the attention output.
- q_lora_rank – Optional LoRA rank for Q projection.
- kv_lora_rank – LoRA rank for KV projections.
- qk_nope_head_dim – Head dimension for non-positional encoding part.
- qk_rope_head_dim – Head dimension for rope part.
- v_head_dim – Head dimension for value.
- buffer_size – Buffer size for storing temporary results during prefill, in units of tokens.
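Per the notes above, the caller is responsible for distributing the inputs so that index i corresponds to device i, with each device's input_row_offsets rebased to start at 0. The following is a minimal sketch in plain Python (not part of the MAX API) of one way to do that rebasing; the helper name, the two-device split, and the sequence lengths are illustrative assumptions.

```python
def rebase_row_offsets(global_offsets: list[int], start: int, end: int) -> list[int]:
    """Slice global row offsets [start, end] and rebase them to begin at 0.

    Hypothetical helper for illustration only; not provided by max.nn.
    """
    base = global_offsets[start]
    return [offset - base for offset in global_offsets[start : end + 1]]

# Example: three ragged sequences of lengths 4, 2, and 3 tokens.
global_offsets = [0, 4, 6, 9]

# Assume device 0 owns the first two sequences and device 1 owns the last one.
offsets_dev0 = rebase_row_offsets(global_offsets, 0, 2)  # [0, 4, 6]
offsets_dev1 = rebase_row_offsets(global_offsets, 2, 3)  # [0, 3]
```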
DistributedLatentAttentionWithRope
max.nn.attention.multi_latent_attention.DistributedLatentAttentionWithRope
LatentAttentionWithRope
class max.nn.attention.multi_latent_attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384)
Implementation of Latent Attention with RoPE.
Initializes the latent attention layer.
Parameters:
- rope (RotaryEmbedding) – The RoPE layer from which to borrow the freqs_cis value.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype (DType) – DType of the weights, currently only bfloat16 is supported.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- scale (float | None) – Value used to scale the results of the attention output.
- q_lora_rank (int | None) – Optional LoRA rank for Q projection.
- kv_lora_rank (int) – LoRA rank for KV projections.
- qk_nope_head_dim (int) – Head dimension for non-positional encoding part.
- qk_rope_head_dim (int) – Head dimension for rope part.
- v_head_dim (int) – Head dimension for value.
- buffer_size (int) – Buffer size for storing temporary results during prefill, in units of tokens.
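A minimal construction sketch based on the signature shown above. The numeric values loosely follow a DeepSeek-style MLA configuration and are illustrative only; the import paths for DType and DeviceRef, and the use of DeviceRef.GPU, are assumptions, and the rope and kv_params objects are assumed to have been built elsewhere (their constructors are documented on their own pages).

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.multi_latent_attention import LatentAttentionWithRope

# `rope` (a RotaryEmbedding) and `kv_params` (a KVCacheParams) are assumed to be
# constructed elsewhere; their exact constructor signatures are not shown here.
attn = LatentAttentionWithRope(
    rope=rope,
    num_attention_heads=128,    # illustrative value
    num_key_value_heads=128,    # illustrative value
    hidden_size=7168,           # illustrative value
    kv_params=kv_params,
    dtype=DType.bfloat16,       # only bfloat16 is currently supported
    devices=[DeviceRef.GPU(id=0)],
    q_lora_rank=1536,           # optional low-rank Q projection
    kv_lora_rank=512,           # default
    qk_nope_head_dim=128,       # default
    qk_rope_head_dim=64,        # default
    v_head_dim=128,             # default
)
```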
rope
rope: RotaryEmbedding
shard()
shard(devices)
Creates sharded views of this Module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
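A short sketch of the shard() and sharding_strategy members documented above, assuming attn is the LatentAttentionWithRope instance from the previous snippet and that a sharding strategy has already been assigned to it (the setter is not shown on this page). DeviceRef.GPU is used under the same assumption as before.

```python
from max.graph import DeviceRef

# Create one sharded view of the module per device.
devices = [DeviceRef.GPU(id=0), DeviceRef.GPU(id=1)]
replicas = attn.shard(devices)

# Inspect the strategy in effect (None if no strategy has been set).
print(attn.sharding_strategy)
```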
w_uk_uv
property w_uk_uv: list[TensorValue]
The concatenation of q, k, and v weight vectors.
TensorParallelLatentAttentionWithRope
class max.nn.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope(**kwargs)
Distributed tensor-parallel implementation of Latent Attention with RoPE. Note that using tensor parallelism for MLA causes the KV cache to be duplicated across all devices, which is not memory-efficient.
Initializes the latent attention layer.
Parameters:
- rope – The RoPE layer from which to borrow the freqs_cis value.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – Number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype – DType of the weights, currently only bfloat16 is supported.
- devices – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls – Linear class to use for the output dense layer.
- scale – Value used to scale the results of the attention output.
- q_lora_rank – Optional LoRA rank for Q projection.
- kv_lora_rank – LoRA rank for KV projections.
- qk_nope_head_dim – Head dimension for non-positional encoding part.
- qk_rope_head_dim – Head dimension for rope part.
- v_head_dim – Head dimension for value.
- buffer_size – Buffer size for storing temporary results during prefill, in units of tokens.
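The constructor is shown above as **kwargs; the sketch below assumes it accepts the same keyword arguments listed for LatentAttentionWithRope, reuses the illustrative values and the rope / kv_params objects assumed earlier, and is not a definitive usage pattern. As the class description notes, every device listed here holds a full copy of the KV cache.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.multi_latent_attention import (
    TensorParallelLatentAttentionWithRope,
)

# `rope` and `kv_params` are assumed to be constructed elsewhere, as before.
# With tensor parallelism, the KV cache is duplicated on each listed device.
tp_attn = TensorParallelLatentAttentionWithRope(
    rope=rope,
    num_attention_heads=128,    # illustrative value
    num_key_value_heads=128,    # illustrative value
    hidden_size=7168,           # illustrative value
    kv_params=kv_params,
    dtype=DType.bfloat16,
    devices=[DeviceRef.GPU(id=0), DeviceRef.GPU(id=1)],
)
```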