For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Python class
MultiheadAttention
class max.nn.attention.MultiheadAttention(num_attention_heads, hidden_size, devices=None, dtype=float32, scale=None, qkv_has_bias=False, o_proj_has_bias=False, stacked_qkv=False)
Bases: Module
Multihead attention that handles both single and distributed computation.
Parameters:
- num_attention_heads (int): The number of attention heads.
- hidden_size (int): The dimension of the hidden states (embed_dim).
- devices (Sequence[DeviceRef] | None): The device(s) on which to place the weights and run the computation. If multiple devices are provided, the layer uses distributed computation.
- dtype (DType): DType of the QKV and output projection weights.
- scale (float | None): Value used to scale the results of the attention output.
- qkv_has_bias (bool): Whether to use a bias on the query, key, and value (QKV) projections.
- o_proj_has_bias (bool): Whether to use a bias on the output projection.
- stacked_qkv (bool): Whether to use a single stacked QKV weight matrix.
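The sketch below illustrates how the parameters above typically relate to one another in a multihead attention layer. It does not use MAX. `AttentionShapes` is a hypothetical helper written only for this illustration, and the weight names (`qkv_proj`, `q_proj`, `o_proj`, and so on) are placeholders, not names confirmed by the library. It also makes three assumptions that this page does not state: that `hidden_size` must be divisible by `num_attention_heads`, that a `scale` of `None` falls back to the conventional `1 / sqrt(head_dim)`, and that each projection is a square `hidden_size` x `hidden_size` matrix.

```python
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class AttentionShapes:
    """Hypothetical helper (not part of MAX) mirroring the constructor arguments."""

    num_attention_heads: int
    hidden_size: int
    scale: Optional[float] = None
    qkv_has_bias: bool = False
    o_proj_has_bias: bool = False
    stacked_qkv: bool = False

    def __post_init__(self) -> None:
        # Assumption: the hidden dimension is split evenly across heads.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")

    @property
    def head_dim(self) -> int:
        """Per-head dimension obtained by splitting hidden_size across heads."""
        return self.hidden_size // self.num_attention_heads

    @property
    def effective_scale(self) -> float:
        # Assumption: with scale=None, use the conventional 1/sqrt(head_dim).
        if self.scale is not None:
            return self.scale
        return 1.0 / math.sqrt(self.head_dim)

    def weight_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Return illustrative (out_features, in_features) shapes per weight."""
        h = self.hidden_size
        shapes: Dict[str, Tuple[int, ...]] = {}
        if self.stacked_qkv:
            # One matrix holding the Q, K, and V projections stacked row-wise.
            shapes["qkv_proj.weight"] = (3 * h, h)
            if self.qkv_has_bias:
                shapes["qkv_proj.bias"] = (3 * h,)
        else:
            # Three separate projection matrices.
            for name in ("q_proj", "k_proj", "v_proj"):
                shapes[f"{name}.weight"] = (h, h)
                if self.qkv_has_bias:
                    shapes[f"{name}.bias"] = (h,)
        shapes["o_proj.weight"] = (h, h)
        if self.o_proj_has_bias:
            shapes["o_proj.bias"] = (h,)
        return shapes


cfg = AttentionShapes(num_attention_heads=8, hidden_size=512, stacked_qkv=True)
print(cfg.head_dim)         # 64
print(cfg.effective_scale)  # 0.125
print(cfg.weight_shapes())
```

With `stacked_qkv=True`, the three projections share one weight of shape `(3 * hidden_size, hidden_size)`. With `stacked_qkv=False`, the same parameters are held as three separate matrices. Check the actual layout and default scale against the library source before relying on them.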