Python module
ragged_attention
A vanilla attention mechanism optimized for an opaque KV cache, with mask variants provided inside the kernel.
RaggedAttention
class max.nn.attention.ragged_attention.RaggedAttention(*, mask_variant, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, clip_qkv=None)
Layer that computes the self-attention score for ragged inputs.
Initializes the attention layer.
Parameters:

- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – The number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV cache params, including the number of KV heads, the head dim, and the data type.
- dtype (DType) – DType of the QKV and output projection weights.
- devices (list[DeviceRef] | None) – Devices to place the weights on and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- stacked_qkv (bool) – Whether the QKV weights are stacked together.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias (bool) – Whether to use an attention bias.
- clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
- mask_variant (MHAMaskVariant) – The mask variant to apply inside the attention kernel.
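The following is a minimal construction sketch based on the signature above. The import paths for MHAMaskVariant, KVCacheParams, DeviceRef, and DType, the KVCacheParams constructor fields, the CAUSAL_MASK enum member, and the DeviceRef.GPU helper are assumptions and may differ from the actual API; the head counts and sizes are placeholder values.

```python
# Sketch only: constructing a RaggedAttention layer from the signature above.
# Import locations and KVCacheParams fields below are assumptions.
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.ragged_attention import RaggedAttention
from max.nn.kernels import MHAMaskVariant      # assumed location of the mask enum
from max.nn.kv_cache import KVCacheParams      # assumed location of the cache params

# Assumed constructor fields for KVCacheParams (number of KV heads, head dim, dtype).
kv_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
)

attention = RaggedAttention(
    mask_variant=MHAMaskVariant.CAUSAL_MASK,   # assumed enum member
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    devices=[DeviceRef.GPU(0)],                # assumed device constructor
    dtype=DType.bfloat16,
    stacked_qkv=False,
    has_bias=False,
)
```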
wqkv
property wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
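To make the concatenation concrete, the sketch below works through one common fused-QKV layout; the exact row ordering and shape used by RaggedAttention are assumptions, not stated on this page.

```python
# Illustrative only: one common layout for a fused QKV weight (an assumption).
# With hidden_size H, head_dim D, N attention heads, and K key/value heads,
# the q, k, and v projections are stacked along the output dimension:
#   wqkv.shape == ((N + 2 * K) * D, H)
# so q, k, and v can be recovered by slicing rows.
N, K, D, H = 32, 8, 128, 4096
q_rows = N * D   # first N*D rows  -> query projection
k_rows = K * D   # next  K*D rows  -> key projection
v_rows = K * D   # last  K*D rows  -> value projection
assert q_rows + k_rows + v_rows == (N + 2 * K) * D
```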