Python class

LatentAttentionWithRope

class max.nn.attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, o_proj_dtype=None, o_proj_quant_config=None, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384, graph_mode=None, norm_dtype=None)

Bases: Module, Shardable

Implements multi-head latent attention (MLA) with rotary position embeddings (RoPE).

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype (DType) – Data type of the weights. Currently only bfloat16 is supported.
  • devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • o_proj_dtype (DType | None) – Optional dtype override for the output projection.
  • o_proj_quant_config (QuantConfig | None) – Optional quantization config for the output projection.
  • scale (float | None) – Value used to scale the results of the attention output.
  • q_lora_rank (int | None) – Optional LoRA rank for Q projection.
  • kv_lora_rank (int) – LoRA rank for KV projections.
  • qk_nope_head_dim (int) – Head dimension of the query/key part that carries no positional encoding.
  • qk_rope_head_dim (int) – Head dimension of the query/key part that RoPE is applied to.
  • v_head_dim (int) – Head dimension for value.
  • buffer_size (int) – Size of the buffer that stores temporary results during prefill, in units of tokens.
  • graph_mode (str | None) – Pipeline role to use for the attention layer. Should be “prefill”, “decode”, or “auto”.
  • norm_dtype (DType | None)
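
The head-dimension parameters above are related by simple arithmetic, which the following plain-Python sketch makes explicit. It is not MAX code: `MLADims` and `check_graph_mode` are hypothetical helpers, and the 1/sqrt(qk_head_dim) value used when `scale` is None is an assumption borrowed from conventional scaled dot-product attention, not something this page states.

```python
import math
from dataclasses import dataclass

VALID_GRAPH_MODES = ("prefill", "decode", "auto")


@dataclass(frozen=True)
class MLADims:
    """Hypothetical helper mirroring the constructor defaults listed above."""

    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128

    @property
    def qk_head_dim(self) -> int:
        # Each query/key head concatenates a non-positional part and a RoPE part.
        return self.qk_nope_head_dim + self.qk_rope_head_dim

    def default_scale(self) -> float:
        # Assumption: conventional 1/sqrt(d) scaling when `scale` is None.
        return 1.0 / math.sqrt(self.qk_head_dim)


def check_graph_mode(graph_mode):
    """Accept None or one of the three documented mode strings."""
    if graph_mode is not None and graph_mode not in VALID_GRAPH_MODES:
        raise ValueError(f"graph_mode must be one of {VALID_GRAPH_MODES} or None")
    return graph_mode


dims = MLADims()
print(dims.qk_head_dim)  # 192 with the default head dimensions
```

With the defaults, each query/key head is 192 wide while each value head is 128 wide, so the two cannot be assumed equal when sizing buffers.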

create_mla_prefill_metadata()

create_mla_prefill_metadata(input_row_offsets, kv_collection)

Creates the prefill planning metadata required by MLA prefill kernels.

Parameters:

  • input_row_offsets
  • kv_collection

Returns:

An MLAPrefillMetadata instance containing buffer row offsets, cache offsets, and buffer lengths for the prefill step.

Return type:

MLAPrefillMetadata
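
To make terms such as "row offsets" and "buffer lengths" concrete, here is a plain-Python sketch; it is not the MAX implementation. It assumes that `input_row_offsets` follows the common ragged-batch convention (a prefix sum of per-sequence token counts) and that a long context is processed in chunks of at most `buffer_size` tokens. `row_offsets_from_lengths` and `chunk_lengths` are hypothetical helpers, and the actual layout of `MLAPrefillMetadata` may differ.

```python
from itertools import accumulate


def row_offsets_from_lengths(seq_lens):
    """Prefix-sum offsets: sequence i occupies rows [offsets[i], offsets[i + 1])."""
    return [0, *accumulate(seq_lens)]


def chunk_lengths(total_tokens, buffer_size=16384):
    """Split total_tokens into consecutive chunks of at most buffer_size tokens."""
    if buffer_size <= 0:
        raise ValueError("buffer_size must be positive")
    full, rest = divmod(total_tokens, buffer_size)
    return [buffer_size] * full + ([rest] if rest else [])


# Three sequences of 3, 5, and 2 tokens packed into one ragged batch.
offsets = row_offsets_from_lengths([3, 5, 2])

# A 40,000-token context split against the default 16,384-token buffer.
chunks = chunk_lengths(40000)
```

Under these assumptions, a larger `buffer_size` means fewer chunks per prefill step at the cost of more temporary memory.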

rope

rope: RotaryEmbedding

shard()

shard(devices)

Creates sharded views of this Module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LatentAttentionWithRope instances, one for each device.

Return type:

list[LatentAttentionWithRope]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Returns the sharding strategy of this Module, if one is set.

w_k

property w_k: TensorValue

Returns prefill K-projection weights with shape [H*qk_nope_dim, kv_rank].

w_uk

property w_uk: TensorValue

Returns decode K-projection weights with shape [H, qk_nope_dim, kv_rank].

w_uv

property w_uv: TensorValue

Returns decode V-projection weights with shape [H, kv_rank, v_dim].
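
The shapes of `w_k`, `w_uk`, and `w_uv` are consistent with the weight-absorption form of MLA, in which the decode path works directly in the compressed latent space. The sketch below only checks shape arithmetic in plain Python. It assumes that H is `num_attention_heads`, qk_nope_dim is `qk_nope_head_dim`, kv_rank is `kv_lora_rank`, and v_dim is `v_head_dim`, and that `w_uk` is a per-head view of the same data as `w_k`; none of this is confirmed by this page. The head count of 16 is arbitrary, and `batched_matmul_shape` is a hypothetical helper.

```python
def batched_matmul_shape(a, b):
    """Shape of a per-head matmul: [H, m, k] @ [H, k, n] -> [H, m, n]."""
    (ha, m, ka), (hb, kb, n) = a, b
    if ha != hb or ka != kb:
        raise ValueError(f"incompatible shapes {a} and {b}")
    return (ha, m, n)


# H is an arbitrary example; the other values are the constructor defaults.
H, qk_nope_dim, kv_rank, v_dim = 16, 128, 512, 128

w_k = (H * qk_nope_dim, kv_rank)  # prefill K-projection, 2-D
w_uk = (H, qk_nope_dim, kv_rank)  # decode K-projection, one matrix per head
w_uv = (H, kv_rank, v_dim)        # decode V-projection, one matrix per head

# Both K-projection layouts hold the same number of elements.
same_size = w_k[0] * w_k[1] == w_uk[0] * w_uk[1] * w_uk[2]

# One decode token: the per-head query [H, 1, qk_nope_dim] maps into the latent space.
q_latent = batched_matmul_shape((H, 1, qk_nope_dim), w_uk)

# The attention output in latent space [H, 1, kv_rank] maps to the value dimension.
out = batched_matmul_shape((H, 1, kv_rank), w_uv)
```

Read this way, decode never needs to expand the cached latent into per-head keys and values; only the query and the final output pass through the per-head projections.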

wqkv

property wqkv: TensorValue

Returns the concatenation of q_a_proj and kv_a_proj_with_mqa weight vectors.