Python module
multi_latent_attention
An opaque, KV-cache-optimized attention mechanism with RoPE.
DataParallelLatentAttentionWithRope
class max.nn.attention.multi_latent_attention.DataParallelLatentAttentionWithRope(**kwargs)
Data-parallel implementation of Latent Attention with RoPE.
This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective operations are required, and the KV cache remains local to each device.
Notes:
- signal_buffers is accepted for interface parity with the distributed implementation but is not used here.
- Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0.
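The rebasing requirement in the second note is easy to get wrong, so here is a minimal sketch in plain Python (no MAX types) of how a caller might split a ragged batch across two devices and rebase each device's input_row_offsets to start at 0. The helper name, the split point, and the sequence lengths are illustrative, not part of the MAX API.

```python
def rebase_row_offsets(row_offsets: list[int], start: int, end: int) -> list[int]:
    """Take the offsets covering sequences [start, end) and rebase them to begin at 0."""
    base = row_offsets[start]
    return [offset - base for offset in row_offsets[start : end + 1]]


# Global ragged batch of 4 sequences with lengths 3, 5, 2, and 4; offsets are
# cumulative token counts, so the batch spans tokens [0, 14).
global_row_offsets = [0, 3, 8, 10, 14]

# Assign the first two sequences to device 0 and the remaining two to device 1.
per_device_row_offsets = [
    rebase_row_offsets(global_row_offsets, 0, 2),  # device 0 -> [0, 3, 8]
    rebase_row_offsets(global_row_offsets, 2, 4),  # device 1 -> [0, 2, 6]
]

# Index i corresponds to device i, matching the contract described above.
print(per_device_row_offsets)
```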
Initializes the latent attention layer.
Parameters:
- rope – The RoPE layer from which the freqs_cis value is borrowed.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – The number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype – DType of the weights; currently only bfloat16 is supported.
- devices – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls – Linear class to use for the output dense layer.
- scale – Value used to scale the results of the attention output.
- q_lora_rank – Optional LoRA rank for the Q projection.
- kv_lora_rank – LoRA rank for the KV projections.
- qk_nope_head_dim – Head dimension for the non-positional encoding (NoPE) part.
- qk_rope_head_dim – Head dimension for the RoPE part.
- v_head_dim – Head dimension for the value projection.
- buffer_size – Buffer size, in tokens, for storing temporary results during prefill.
- graph_mode – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
create_mla_inputs()
create_mla_inputs(input_row_offsets_, kv_collections)
Parameters:
- input_row_offsets_ (list[TensorValue])
- kv_collections (list[PagedCacheValues])
Return type:
LatentAttentionWithRope
class max.nn.attention.multi_latent_attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384, graph_mode=None)
Implementation of Latent Attention with RoPE.
Initializes the latent attention layer.
Parameters:
- rope (RotaryEmbedding) – The RoPE layer from which the freqs_cis value is borrowed.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – The number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype (DType) – DType of the weights; currently only bfloat16 is supported.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- scale (float | None) – Value used to scale the results of the attention output.
- q_lora_rank (int | None) – Optional LoRA rank for the Q projection.
- kv_lora_rank (int) – LoRA rank for the KV projections.
- qk_nope_head_dim (int) – Head dimension for the non-positional encoding (NoPE) part.
- qk_rope_head_dim (int) – Head dimension for the RoPE part.
- v_head_dim (int) – Head dimension for the value projection.
- buffer_size (int) – Buffer size, in tokens, for storing temporary results during prefill.
- graph_mode (str | None) – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
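The following is a minimal construction sketch, not a definitive recipe: only the LatentAttentionWithRope signature above comes from this page, while the import paths for KVCacheParams and RotaryEmbedding and the model dimensions are assumptions chosen for illustration (the MLA head dimensions simply restate the documented defaults).

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.multi_latent_attention import LatentAttentionWithRope
from max.nn.kv_cache import KVCacheParams  # assumed import path
from max.nn.rotary_embedding import RotaryEmbedding  # assumed import path


def build_mla_layer(
    rope: RotaryEmbedding, kv_params: KVCacheParams
) -> LatentAttentionWithRope:
    """Build an MLA layer; the model dimensions below are illustrative only."""
    return LatentAttentionWithRope(
        rope=rope,
        num_attention_heads=128,
        num_key_value_heads=128,
        hidden_size=5120,
        kv_params=kv_params,
        dtype=DType.bfloat16,        # only bfloat16 is currently supported
        devices=[DeviceRef.GPU(0)],  # the first device is used for the weights
        kv_lora_rank=512,            # documented defaults, shown explicitly
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        graph_mode="auto",
    )
```

The rope and kv_params arguments are expected to come from the surrounding model's rotary embedding and KV-cache configuration rather than being built inside the attention layer.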
create_mla_inputs()
create_mla_inputs(input_row_offsets, kv_collection)
Parameters:
- input_row_offsets (TensorValue)
- kv_collection (PagedCacheValues)
Return type:
rope
rope: RotaryEmbedding
shard()
shard(devices)
Creates sharded views of this Module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
w_uk_uv
property w_uk_uv: list[TensorValue]
The concatenation of q, k, and v weight vectors.
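A hedged sharding sketch follows. Only shard() and the sharding_strategy property are documented on this page, so the ShardingStrategy import path, its replicate() constructor, and the need to assign a strategy before calling shard() are assumptions based on how other MAX modules are commonly sharded.

```python
from max.graph import DeviceRef
from max.graph.weight import ShardingStrategy  # assumed import path
from max.nn.attention.multi_latent_attention import LatentAttentionWithRope


def replicate_mla_layer(
    attn: LatentAttentionWithRope, devices: list[DeviceRef]
) -> list[LatentAttentionWithRope]:
    # Assumption: a strategy is assigned before shard() is called; replicate()
    # mirrors the data-parallel wrapper, which keeps a full copy per device.
    attn.sharding_strategy = ShardingStrategy.replicate(len(devices))
    # shard() creates one sharded view of the module per device (see above).
    return attn.shard(devices)
```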
TensorParallelLatentAttentionWithRope
class max.nn.attention.multi_latent_attention.TensorParallelLatentAttentionWithRope(**kwargs)
Distributed tensor-parallel implementation of Latent Attention with RoPE. Note that using tensor parallelism for MLA causes the KV cache to be duplicated across all devices, which is not efficient.
Initializes the latent attention layer.
Parameters:
- rope – The RoPE layer from which the freqs_cis value is borrowed.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – The number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype – DType of the weights; currently only bfloat16 is supported.
- devices – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls – Linear class to use for the output dense layer.
- scale – Value used to scale the results of the attention output.
- q_lora_rank – Optional LoRA rank for the Q projection.
- kv_lora_rank – LoRA rank for the KV projections.
- qk_nope_head_dim – Head dimension for the non-positional encoding (NoPE) part.
- qk_rope_head_dim – Head dimension for the RoPE part.
- v_head_dim – Head dimension for the value projection.
- buffer_size – Buffer size, in tokens, for storing temporary results during prefill.
- graph_mode – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
create_mla_inputs()
create_mla_inputs(input_row_offsets_, kv_collections)
Parameters:
- input_row_offsets_ (list[TensorValue])
- kv_collections (list[PagedCacheValues])
Return type: