Python module

ragged_attention

An opaque, KV-cache-optimized vanilla attention mechanism, with mask variants applied inside the kernel.

RaggedAttention

class max.nn.legacy.attention.ragged_attention.RaggedAttention(*, mask_variant, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, clip_qkv=None)

Layer that computes self-attention scores for ragged (variable-length, unpadded) inputs.

Parameters:
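The constructor arguments shown in the signature above can be supplied as in the minimal sketch below. It assumes that mask_variant expects the MHAMaskVariant enum from max.nn.kernels and that kv_params expects a KVCacheParams object from max.nn.kv_cache; those import paths, as well as all numeric values, are illustrative assumptions and may differ across MAX versions.

```python
from max.dtype import DType
from max.nn.kernels import MHAMaskVariant      # assumed location of the mask-variant enum
from max.nn.kv_cache import KVCacheParams      # assumed KV-cache parameter type
from max.nn.legacy.attention.ragged_attention import RaggedAttention

# Hypothetical model configuration, for illustration only.
kv_params = KVCacheParams(
    dtype=DType.float32,
    n_kv_heads=8,
    head_dim=128,
)

attention = RaggedAttention(
    mask_variant=MHAMaskVariant.CAUSAL_MASK,  # mask is applied inside the kernel
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    dtype=DType.float32,
    stacked_qkv=False,
    has_bias=False,
)
```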

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
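As a sketch of how the property might be read (assuming it is accessed inside a graph-building context, where TensorValue operations are valid), continuing from the construction example above:

```python
# Returns a single TensorValue holding the concatenated Q, K, and V
# projection weights; with stacked_qkv=True the layer is expected to
# store one fused QKV weight instead of separate projections.
fused_qkv = attention.wqkv
print(fused_qkv.shape)
```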
