LatentAttentionWithRope
class max.nn.attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, o_proj_dtype=None, o_proj_quant_config=None, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384, graph_mode=None, norm_dtype=None)
Implementation of Latent Attention with RoPE. A construction sketch follows the parameter list below.
Parameters:
- rope (RotaryEmbedding) – The RoPE layer from which to borrow the freqs_cis value.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype (DType) – DType of the weights, currently only bfloat16 is supported.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- o_proj_dtype (DType | None) – Optional dtype override for the output projection.
- o_proj_quant_config (QuantConfig | None) – Optional quantization config for the output projection.
- scale (float | None) – Value used to scale the results of the attention output.
- q_lora_rank (int | None) – Optional LoRA rank for Q projection.
- kv_lora_rank (int) – LoRA rank for KV projections.
- qk_nope_head_dim (int) – Head dimension for non-positional encoding part.
- qk_rope_head_dim (int) – Head dimension for rope part.
- v_head_dim (int) – Head dimension for value.
- buffer_size (int) – Buffer size, in tokens, for storing temporary results during prefill.
- graph_mode (str | None) – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
- norm_dtype (DType | None)
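The sketch below shows how these parameters fit together when constructing the layer. This is a minimal sketch: the rope (RotaryEmbedding) and kv_params (KVCacheParams) objects are assumed to be built elsewhere, the head-count and hidden-size values are purely illustrative, and the import paths follow the module names used in this reference.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention import LatentAttentionWithRope

# `rope` (RotaryEmbedding) and `kv_params` (KVCacheParams) are placeholders
# assumed to be constructed elsewhere; their constructors are not covered here.
attention = LatentAttentionWithRope(
    rope=rope,
    num_attention_heads=128,     # illustrative value
    num_key_value_heads=128,     # illustrative value
    hidden_size=7168,            # illustrative value
    kv_params=kv_params,
    dtype=DType.bfloat16,        # currently only bfloat16 is supported
    devices=[DeviceRef.GPU(0)],  # the first device is used if several are given
    kv_lora_rank=512,            # defaults, spelled out for clarity
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
)
```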
create_mla_prefill_metadata()
create_mla_prefill_metadata(input_row_offsets, kv_collection)
Creates the prefill planning metadata required by MLA prefill kernels (a call sketch follows the return type below).
Parameters:
- input_row_offsets (TensorValue) – Ragged row offsets tensor describing the token boundaries for each sequence in the batch.
- kv_collection (KVCacheInputsPerDevice[TensorValue, BufferValue]) – Paged KV cache values for the current device.
Returns:
An MLAPrefillMetadata instance containing buffer row offsets, cache offsets, and buffer lengths for the prefill step.
Return type:
MLAPrefillMetadata
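A hedged call sketch follows; input_row_offsets and kv_collection are placeholder names assumed to come from the surrounding graph and the paged KV cache manager.

```python
# Placeholders: `input_row_offsets` (a ragged TensorValue) and `kv_collection`
# (paged KV cache values) are assumed to be produced elsewhere in the graph.
prefill_metadata = attention.create_mla_prefill_metadata(
    input_row_offsets=input_row_offsets,
    kv_collection=kv_collection,
)
# `prefill_metadata` carries the buffer row offsets, cache offsets, and buffer
# lengths consumed by the MLA prefill kernels.
```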
rope
rope: RotaryEmbedding
shard()
shard(devices)
Creates sharded views of this Module across multiple devices.
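A sketch of sharding across two GPUs. The ShardingStrategy import path and the choice of a replicate strategy are assumptions; this page does not document which strategies the module supports.

```python
from max.graph import DeviceRef, ShardingStrategy  # import path assumed

# Assumption: a replicate strategy is valid for this module.
attention.sharding_strategy = ShardingStrategy.replicate(num_devices=2)
shards = attention.shard([DeviceRef.GPU(0), DeviceRef.GPU(1)])
```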
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
w_k
property w_k: TensorValue
Returns prefill K-projection weights with shape [H*qk_nope_dim, kv_rank].
w_uk
property w_uk: TensorValue
Returns decode K-projection weights with shape [H, qk_nope_dim, kv_rank].
w_uv
property w_uv: TensorValue
Returns decode V-projection weights with shape [H, kv_rank, v_dim].
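As a worked shape example using the constructor defaults (qk_nope_head_dim=128, kv_lora_rank=512, v_head_dim=128), with H standing for num_attention_heads:

```python
# Shapes implied by the property descriptions above with default dimensions;
# H = num_attention_heads. Illustrative only.
#   attention.w_k.shape  -> [H * 128, 512]  # [H * qk_nope_head_dim, kv_lora_rank]
#   attention.w_uk.shape -> [H, 128, 512]   # [H, qk_nope_head_dim, kv_lora_rank]
#   attention.w_uv.shape -> [H, 512, 128]   # [H, kv_lora_rank, v_head_dim]
```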
wqkv
property wqkv: TensorValue
Returns the concatenation of the q_a_proj and kv_a_proj_with_mqa weights.