
Python class

RaggedAttention

class max.nn.attention.RaggedAttention(*, mask_variant, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, clip_qkv=None)

Bases: Module

Layer that computes self-attention scores for ragged inputs.
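To make "ragged inputs" concrete, the sketch below assumes it means a batch of variable-length sequences packed into one flat token list without padding, with a list of row offsets marking where each sequence starts. The helpers `pack_ragged` and `unpack_ragged` are hypothetical names used for illustration only; they are not part of `max.nn`, and the layer's actual input format may differ.

```python
from itertools import accumulate


def pack_ragged(sequences):
    """Concatenate variable-length sequences into a flat list plus row offsets."""
    flat = [token for seq in sequences for token in seq]
    # offsets[i] is the start of sequence i; offsets[-1] is the total token count.
    offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return flat, offsets


def unpack_ragged(flat, offsets):
    """Recover the original sequences from the flat list and its offsets."""
    return [flat[start:end] for start, end in zip(offsets, offsets[1:])]


# Three sequences of lengths 3, 1, and 2 (hypothetical token IDs).
batch = [[11, 12, 13], [21], [31, 32]]
flat, offsets = pack_ragged(batch)
print(flat)     # [11, 12, 13, 21, 31, 32]
print(offsets)  # [0, 3, 4, 6]
```

Compared with padding every sequence to the longest length, this layout stores only real tokens, so no compute is spent on padding positions.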

Initializes the attention layer.

Parameters:

  • rope – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype (DType) – DType of the layer's weights. Defaults to float32.
  • devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, only the first is used.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • stacked_qkv (bool) – Whether the weights are stacked together.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped to the range [-clip_qkv, clip_qkv].
  • mask_variant (MHAMaskVariant) – The attention mask variant to apply.
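The following sketch illustrates how several of these parameters conventionally relate to each other in grouped-query attention. It is plain Python and does not call `max.nn`. The helper names (`describe_attention_config`, `clamp_values`), the derivation of the head dimension as `hidden_size // num_attention_heads`, and the fallback scale of `1 / sqrt(head_dim)` are assumptions based on common practice; in this class the head dimension is carried by `kv_params`, and the actual default for `scale` is not stated above.

```python
import math


def describe_attention_config(hidden_size, num_attention_heads,
                              num_key_value_heads, scale=None):
    """Derive common attention quantities from the constructor-style arguments."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    # Assumption: the head dimension follows the usual hidden_size / heads split.
    head_dim = hidden_size // num_attention_heads
    return {
        "head_dim": head_dim,
        # Number of query heads that share each key/value head.
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
        # Assumption: fall back to 1/sqrt(head_dim) when no scale is given.
        "scale": scale if scale is not None else 1.0 / math.sqrt(head_dim),
    }


def clamp_values(values, clip_qkv=None):
    """Clamp each value to [-clip_qkv, clip_qkv]; pass through if clip_qkv is None."""
    if clip_qkv is None:
        return list(values)
    return [max(-clip_qkv, min(clip_qkv, v)) for v in values]


cfg = describe_attention_config(
    hidden_size=4096, num_attention_heads=32, num_key_value_heads=8
)
print(cfg["head_dim"], cfg["queries_per_kv_head"])  # 128 4
print(clamp_values([-9.0, 0.5, 12.0], clip_qkv=8.0))  # [-8.0, 0.5, 8.0]
```

Setting `num_key_value_heads` equal to `num_attention_heads` corresponds to standard multi-head attention, while a smaller value means several query heads share one key/value head.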