Python class
LatentAttentionWithRope
class max.nn.attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, o_proj_dtype=None, o_proj_quant_config=None, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384, graph_mode=None, norm_dtype=None)
Implementation of multi-head latent attention (MLA) with rotary position embeddings (RoPE).
Parameters:
- rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – The number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype (DType) – DType of the weights. Currently only bfloat16 is supported.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- o_proj_dtype (DType | None) – Optional dtype override for the output projection.
- o_proj_quant_config (QuantConfig | None) – Optional quantization config for the output projection.
- scale (float | None) – Value used to scale the results of the attention output.
- q_lora_rank (int | None) – Optional LoRA rank for the Q projection.
- kv_lora_rank (int) – LoRA rank for the KV projections.
- qk_nope_head_dim (int) – Head dimension for the non-positional encoding part.
- qk_rope_head_dim (int) – Head dimension for the rope part.
- v_head_dim (int) – Head dimension for the value.
- buffer_size (int) – Buffer size for storing temporary results during prefill, in units of tokens.
- graph_mode (str | None) – Pipeline role to use for the attention layer. Should be "prefill", "decode", or "auto".
- norm_dtype (DType | None)
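To make the relationship between the head-dimension parameters concrete, the sketch below derives a few sizes from the default values in the signature. It is plain Python and does not call MAX. The `MLADims` helper and its derived properties are hypothetical, and both the latent width and the 1/sqrt(qk_head_dim) scale are assumptions borrowed from common MLA implementations, not behavior this page states.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADims:
    """Hypothetical helper mirroring the dimension defaults in the signature."""

    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128

    @property
    def qk_head_dim(self) -> int:
        # Each query/key head concatenates a non-positional part and a rope part.
        return self.qk_nope_head_dim + self.qk_rope_head_dim

    @property
    def latent_width(self) -> int:
        # Assumption: the per-token cached latent is the compressed KV vector
        # followed by the rope part of the key.
        return self.kv_lora_rank + self.qk_rope_head_dim

    @property
    def default_scale(self) -> float:
        # Assumption: a softmax scale of 1/sqrt(qk_head_dim) when scale is None.
        return 1.0 / math.sqrt(self.qk_head_dim)


dims = MLADims()
print(dims.qk_head_dim, dims.latent_width)  # 192 576
```

Under these assumptions, the cached width per token depends on kv_lora_rank and qk_rope_head_dim only, not on the number of attention heads.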
create_mla_prefill_metadata()
create_mla_prefill_metadata(input_row_offsets, kv_collection)
Creates the prefill planning metadata required by MLA prefill kernels.
Parameters:
- input_row_offsets (TensorValue) – Ragged row offsets tensor describing the token boundaries for each sequence in the batch.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache values for the current device.
Returns:
An MLAPrefillMetadata instance containing buffer row offsets, cache offsets, and buffer lengths for the prefill step.
Return type:
MLAPrefillMetadata
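As an illustration of what a ragged row offsets tensor encodes, the sketch below builds offsets from per-sequence token counts in plain Python. The helper names are hypothetical, and the code does not reproduce the contents of MLAPrefillMetadata. It only shows the ragged-offset convention assumed here, in which sequence i occupies rows offsets[i] up to (but not including) offsets[i + 1].

```python
from itertools import accumulate


def row_offsets_from_lengths(lengths: list[int]) -> list[int]:
    """Return ragged row offsets: a leading 0 followed by the running token total."""
    return [0, *accumulate(lengths)]


def lengths_from_row_offsets(offsets: list[int]) -> list[int]:
    """Recover per-sequence token counts from consecutive offset differences."""
    return [end - start for start, end in zip(offsets, offsets[1:])]


# Hypothetical batch of three sequences with 3, 5, and 2 tokens.
offsets = row_offsets_from_lengths([3, 5, 2])
print(offsets)  # [0, 3, 8, 10]
print(lengths_from_row_offsets(offsets))  # [3, 5, 2]
```

A batch of N sequences therefore needs N + 1 offsets, and the last offset equals the total number of tokens in the batch.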
rope
rope: RotaryEmbedding
shard()
shard(devices)
Creates sharded views of this Module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
w_k
property w_k: TensorValue
Returns prefill K-projection weights with shape [H*qk_nope_dim, kv_rank].
w_uk
property w_uk: TensorValue
Returns decode K-projection weights with shape [H, qk_nope_dim, kv_rank].
w_uv
property w_uv: TensorValue
Returns decode V-projection weights with shape [H, kv_rank, v_dim].
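The documented shapes of w_uk and w_uv can be checked with a small shape walk-through. The sketch below uses nested Python lists and tiny hypothetical sizes, and does not call MAX. It assumes that, during decode, the K up-projection is applied to the non-positional part of the query and the V up-projection is applied to the attention result over the latent cache, as in common MLA implementations. This page documents only the shapes, not that data flow.

```python
def per_head_vecmat(x, w):
    """For every head h, multiply vector x[h] (length A) by matrix w[h] (A x B)."""
    return [
        [sum(x_h[a] * w_h[a][b] for a in range(len(x_h))) for b in range(len(w_h[0]))]
        for x_h, w_h in zip(x, w)
    ]


def shape(t):
    """Return the shape of a rectangular nested list as a tuple."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


# Tiny hypothetical sizes standing in for H, qk_nope_dim, kv_rank, and v_dim.
H, qk_nope_dim, kv_rank, v_dim = 2, 3, 4, 2

w_uk = [[[1.0] * kv_rank for _ in range(qk_nope_dim)] for _ in range(H)]  # [H, qk_nope_dim, kv_rank]
w_uv = [[[1.0] * v_dim for _ in range(kv_rank)] for _ in range(H)]  # [H, kv_rank, v_dim]

q_nope = [[1.0] * qk_nope_dim for _ in range(H)]  # one decode token, [H, qk_nope_dim]
q_latent = per_head_vecmat(q_nope, w_uk)  # [H, kv_rank]

# Stand-in for the attention-weighted sum over cached latents, [H, kv_rank].
attn_latent = [[0.5] * kv_rank for _ in range(H)]
out = per_head_vecmat(attn_latent, w_uv)  # [H, v_dim]

print(shape(q_latent), shape(out))  # (2, 4) (2, 2)
```

Under this assumption, attention scores are computed directly in the kv_rank-dimensional latent space, so the cached entries do not have to be expanded into per-head keys and values.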
wqkv
property wqkv: TensorValue
Returns the concatenation of q_a_proj and kv_a_proj_with_mqa weight vectors.
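A concatenated weight of this kind lets both down-projections run as a single matrix multiply. The sketch below demonstrates the idea in plain Python with tiny hypothetical sizes. It assumes each weight is stored as [out_features, in_features] and that the two are concatenated along the output dimension. The page does not state the concatenation axis, so treat that layout as an assumption.

```python
def matvec(w, x):
    """Compute w @ x for a row-major matrix w and a vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


# Tiny hypothetical sizes; real models use much larger values.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 4, 3, 2, 1

# Assumed layout: [out_features, in_features] for both projections.
q_a_proj = [[float(i + j) for j in range(hidden_size)] for i in range(q_lora_rank)]
kv_a_proj_with_mqa = [
    [float(i * j) for j in range(hidden_size)]
    for i in range(kv_lora_rank + qk_rope_head_dim)
]

# Concatenate along the output (row) dimension.
wqkv = q_a_proj + kv_a_proj_with_mqa

x = [1.0, 2.0, 3.0, 4.0]  # one hidden-state vector
fused = matvec(wqkv, x)

# Splitting the fused output recovers the two separate projections.
q_part, kv_part = fused[:q_lora_rank], fused[q_lora_rank:]
print(q_part == matvec(q_a_proj, x), kv_part == matvec(kv_a_proj_with_mqa, x))
```

With this layout, the fused weight has q_lora_rank + kv_lora_rank + qk_rope_head_dim rows and hidden_size columns.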