
Python module

ragged_attention

An opaque, KV-cache-optimized vanilla attention mechanism with mask variants provided inside the kernel.

RaggedAttention

class max.nn.attention.ragged_attention.RaggedAttention(*, mask_variant, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, clip_qkv=None)

Layer that computes self-attention scores for ragged inputs.

Initializes the attention layer; a construction sketch follows the parameter list below.

Parameters:

  • mask_variant (MHAMaskVariant) – The mask variant applied inside the attention kernel.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – The number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • dtype (DType) – The data type of the layer's weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • stacked_qkv (bool) – Whether the Q, K, and V weights are stacked into a single tensor.
  • scale (float | None) – Value used to scale the results of the attention computation.
  • has_bias (bool) – Whether to use an attention bias.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped to the range [-clip_qkv, clip_qkv].
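
A minimal construction sketch is shown below. The import paths, the KVCacheParams fields, and the MHAMaskVariant.CAUSAL_MASK member are assumptions about a recent MAX release rather than details stated on this page; verify them against your installed version.

```python
# Construction sketch only, not an official example. Import paths,
# KVCacheParams fields, and MHAMaskVariant members are assumptions
# about a recent MAX release; check them against your installed version.
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.ragged_attention import RaggedAttention
from max.nn.kernels import MHAMaskVariant
from max.nn.kv_cache import KVCacheParams

hidden_size = 2048
num_attention_heads = 16
num_key_value_heads = 4                        # grouped-query attention
head_dim = hidden_size // num_attention_heads  # 128

# KV cache parameters shared by the attention layer and the cache manager.
kv_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=num_key_value_heads,
    head_dim=head_dim,
)

attention = RaggedAttention(
    mask_variant=MHAMaskVariant.CAUSAL_MASK,   # mask applied inside the kernel
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    hidden_size=hidden_size,
    kv_params=kv_params,
    devices=[DeviceRef.GPU()],
    dtype=DType.bfloat16,
    scale=head_dim**-0.5,                      # conventional 1/sqrt(head_dim) scaling
)
```

All constructor arguments are keyword-only. Calling the layer additionally requires ragged hidden states (with row offsets) and a KV cache collection produced by the surrounding model; those inputs are outside the scope of this sketch.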

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
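
For shape intuition only, the following sketch assumes each projection weight is stored as (out_features, in_features) and that the concatenation runs along the output dimension; neither assumption is stated on this page.

```python
# Shape sketch for the concatenated QKV weight. The (out_features, in_features)
# layout and the concatenation axis are assumptions, not taken from this page.
num_attention_heads = 16
num_key_value_heads = 4
hidden_size = 2048
head_dim = hidden_size // num_attention_heads  # 128

q_rows = num_attention_heads * head_dim        # 2048 rows for the Q projection
kv_rows = num_key_value_heads * head_dim       # 512 rows each for K and V
wqkv_shape = (q_rows + 2 * kv_rows, hidden_size)
print(wqkv_shape)                              # (3072, 2048)
```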