Python module
ragged_attention
An opaque, KV-cache-optimized vanilla attention mechanism, with mask variants provided inside the kernel.
RaggedAttention
class max.nn.legacy.attention.ragged_attention.RaggedAttention(*, mask_variant, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.legacy.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, clip_qkv=None)
Layer that computes the self-attention score for ragged inputs.
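The constructor takes the model's head and hidden dimensions together with a KV cache configuration and a mask variant that is fused into the attention kernel. Below is a minimal construction sketch, not a verified recipe: the import locations of MHAMaskVariant and KVCacheParams, the KVCacheParams field names, and the numeric values are assumptions for illustration and may differ between MAX releases.

```python
# Minimal construction sketch (assumed import paths and illustrative values).
from max.dtype import DType
from max.nn.kernels import MHAMaskVariant          # assumed location of the mask enum
from max.nn.kv_cache import KVCacheParams          # assumed location of the cache config
from max.nn.legacy.attention.ragged_attention import RaggedAttention

# Hypothetical KV cache configuration; field names follow common KVCacheParams usage.
kv_params = KVCacheParams(
    dtype=DType.float32,
    n_kv_heads=8,
    head_dim=64,
)

attention = RaggedAttention(
    mask_variant=MHAMaskVariant.CAUSAL_MASK,  # mask is applied inside the kernel
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=2048,
    kv_params=kv_params,
    dtype=DType.float32,
    stacked_qkv=False,   # keep separate q, k, and v projections
    has_bias=False,
)
```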
Parameters:
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
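As a hedged sketch, the fused weight can be read while a graph is being built (the property returns a TensorValue); how the weight is stored depends on whether the layer was constructed with stacked_qkv.

```python
# Sketch only: reading the fused QKV projection weight during graph
# construction. The property returns a max.graph TensorValue, so it is
# typically accessed while the computation graph is being built.
qkv_weight = attention.wqkv  # q, k, and v projection weights concatenated together
```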