Python class

LatentAttentionWithRope

class max.nn.attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, o_proj_dtype=None, o_proj_quant_config=None, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384, graph_mode=None, norm_dtype=None)

Bases: Module, Shardable

Implementation of Multi-head Latent Attention (MLA) with RoPE (rotary position embeddings).

Parameters:

  • rope (RotaryEmbedding) – The RoPE layer from which the freqs_cis value is borrowed.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype (DType) – DType of the weights; currently only bfloat16 is supported.
  • devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • o_proj_dtype (DType | None) – Optional dtype override for the output projection.
  • o_proj_quant_config (QuantConfig | None) – Optional quantization config for the output projection.
  • scale (float | None) – Value used to scale the attention output.
  • q_lora_rank (int | None) – Optional LoRA rank for the Q projection.
  • kv_lora_rank (int) – LoRA rank for the KV projections.
  • qk_nope_head_dim (int) – Head dimension for the non-positional-encoding part of the query/key.
  • qk_rope_head_dim (int) – Head dimension for the RoPE part of the query/key.
  • v_head_dim (int) – Head dimension for the value.
  • buffer_size (int) – Buffer size, in tokens, for storing temporary results during prefill.
  • graph_mode (str | None) – Pipeline role to use for the attention layer. One of “prefill”, “decode”, or “auto”.
  • norm_dtype (DType | None) – Optional dtype override for the normalization layers.
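
Example: a minimal construction sketch. The rope and kv_params objects, the hyperparameter values, and the import paths for DType and DeviceRef below are illustrative assumptions, not prescribed usage.

    from max.dtype import DType
    from max.graph import DeviceRef
    from max.nn.attention import LatentAttentionWithRope

    # rope and kv_params are assumed to be a configured RotaryEmbedding and
    # KVCacheParams built elsewhere; their constructors are not shown here.
    attn = LatentAttentionWithRope(
        rope=rope,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_size=2048,
        kv_params=kv_params,
        dtype=DType.bfloat16,        # currently the only supported dtype
        devices=[DeviceRef.GPU(0)],  # first device is used for computation
        kv_lora_rank=512,            # defaults shown explicitly for clarity
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        graph_mode="auto",
    )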

create_mla_prefill_metadata()

create_mla_prefill_metadata(input_row_offsets, kv_collection)

Creates the prefill planning metadata required by MLA prefill kernels.

Parameters:

  • input_row_offsets – Row offsets delimiting each sequence in the ragged input batch.
  • kv_collection – The KV cache collection from which cache offsets are read.

Returns:

An MLAPrefillMetadata instance containing buffer row offsets, cache offsets, and buffer lengths for the prefill step.

Return type:

MLAPrefillMetadata
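
A hypothetical call sketch; attn, input_row_offsets, and kv_collection are assumed to come from the surrounding graph-construction code.

    # Build the planning metadata consumed by the MLA prefill kernels.
    metadata = attn.create_mla_prefill_metadata(input_row_offsets, kv_collection)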

rope

rope: RotaryEmbedding

shard()

shard(devices)

Creates sharded views of this Module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LatentAttentionWithRope instances, one for each device.

Return type:

list[LatentAttentionWithRope]
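
A hypothetical sharding sketch. The ShardingStrategy import path and the tensor_parallel constructor are assumptions; consult the ShardingStrategy docs for the strategies actually available.

    from max.graph import DeviceRef
    from max.nn import ShardingStrategy  # assumed import path

    devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]
    # Assign a strategy first (a setter is implied by the Shardable base);
    # ShardingStrategy.tensor_parallel here is an assumed constructor.
    attn.sharding_strategy = ShardingStrategy.tensor_parallel(len(devices))
    shards = attn.shard(devices)  # one LatentAttentionWithRope per device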

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

w_k

property w_k: TensorValue

Returns prefill K-projection weights with shape [H*qk_nope_dim, kv_rank].

w_uk

property w_uk: TensorValue

Returns decode K-projection weights with shape [H, qk_nope_dim, kv_rank].

w_uv

property w_uv: TensorValue

Returns decode V-projection weights with shape [H, kv_rank, v_dim].
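
To make the decode-path shapes concrete, here is an illustrative NumPy sketch of the absorbed-weight MLA decode that w_uk and w_uv are laid out for. It mirrors the shapes documented above; it is not the library's kernel, and the softmax/scaling of the scores is omitted.

    import numpy as np

    H, qk_nope_dim, kv_rank, v_dim, T = 8, 128, 512, 128, 4

    w_uk = np.zeros((H, qk_nope_dim, kv_rank))  # decode K projection
    w_uv = np.zeros((H, kv_rank, v_dim))        # decode V projection
    q_nope = np.zeros((H, T, qk_nope_dim))      # non-positional query part
    latent = np.zeros((T, kv_rank))             # compressed KV cache entries

    # Map queries into the latent space, attend over cached latents, then
    # map the attended latents back out to per-head value vectors.
    q_latent = np.einsum("htd,hdr->htr", q_nope, w_uk)  # (H, T, kv_rank)
    scores = q_latent @ latent.T                        # (H, T, T)
    ctx = np.einsum("hts,sr->htr", scores, latent)      # (H, T, kv_rank)
    out = np.einsum("htr,hrv->htv", ctx, w_uv)          # (H, T, v_dim)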

wqkv

property wqkv: TensorValue

Returns the concatenation of the q_a_proj and kv_a_proj_with_mqa weights.
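
An illustrative sketch of how a fused A-projection of this shape is typically consumed in MLA (a sketch only: the q_lora_rank value, the row-major weight orientation, and the split order are assumptions).

    import numpy as np

    hidden_size, q_lora_rank = 2048, 1536       # illustrative values
    kv_lora_rank, qk_rope_head_dim = 512, 64    # class defaults
    T = 4

    # Assumed layout: q_a_proj rows stacked over kv_a_proj_with_mqa rows.
    wqkv = np.zeros((q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size))
    x = np.zeros((T, hidden_size))

    fused = x @ wqkv.T
    q_latent, kv_latent, k_rope = np.split(
        fused, [q_lora_rank, q_lora_rank + kv_lora_rank], axis=-1
    )
    # q_latent: (T, q_lora_rank); kv_latent: (T, kv_lora_rank);
    # k_rope:   (T, qk_rope_head_dim)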