Python class

MultiheadAttention

class max.nn.attention.MultiheadAttention(num_attention_heads, hidden_size, devices=None, dtype=float32, scale=None, qkv_has_bias=False, o_proj_has_bias=False, stacked_qkv=False)

Bases: Module

Multihead attention that handles both single and distributed computation.

Parameters:

  • num_attention_heads (int) – The number of attention heads.
  • hidden_size (int) – The dimension of the hidden states (embed_dim).
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple devices are provided, distributed computation is used.
  • dtype (DType) – DType of the QKV and output projection weights.
  • scale (float | None) – Value used to scale the results of the attention output.
  • qkv_has_bias (bool) – Whether to use a bias on the Q, K, and V projections.
  • o_proj_has_bias (bool) – Whether to use a bias on the output projection.
  • stacked_qkv (bool) – Whether to use a single stacked QKV weight matrix.
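
To make the relationship between these parameters concrete, the following is a minimal standalone sketch in plain Python. It does not call the MAX API. The helper `attention_shapes` and the `AttentionShapes` container are hypothetical names introduced here for illustration, and two points are assumptions rather than documented behavior of this class: that a `scale` of `None` falls back to the conventional `1 / sqrt(head_dim)`, and that projection weights are laid out as `(out_features, in_features)`.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class AttentionShapes:
    """Hypothetical container (not part of max.nn) for derived sizes."""

    head_dim: int
    scale: float
    qkv_weight_shapes: List[Tuple[int, int]]
    o_proj_weight_shape: Tuple[int, int]


def attention_shapes(
    num_attention_heads: int,
    hidden_size: int,
    scale: Optional[float] = None,
    stacked_qkv: bool = False,
) -> AttentionShapes:
    """Derive per-head size, scale, and weight shapes from the constructor arguments."""
    # The hidden dimension is split evenly across the heads.
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    head_dim = hidden_size // num_attention_heads

    # Assumption: with no explicit scale, use the conventional 1 / sqrt(head_dim).
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    # Assumption: weights are stored as (out_features, in_features).
    if stacked_qkv:
        # One matrix holding the Q, K, and V projections stacked together.
        qkv_shapes = [(3 * hidden_size, hidden_size)]
    else:
        # Three separate projection matrices, one each for Q, K, and V.
        qkv_shapes = [(hidden_size, hidden_size)] * 3

    return AttentionShapes(
        head_dim=head_dim,
        scale=scale,
        qkv_weight_shapes=qkv_shapes,
        o_proj_weight_shape=(hidden_size, hidden_size),
    )


shapes = attention_shapes(num_attention_heads=8, hidden_size=512)
print(shapes.head_dim, shapes.scale, len(shapes.qkv_weight_shapes))
```

Under these assumptions, 8 heads over a hidden size of 512 give a per-head dimension of 64, and passing `stacked_qkv=True` replaces the three separate projection matrices with a single one three times as tall.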