
MultiheadAttention

class max.nn.attention.MultiheadAttention(num_attention_heads, hidden_size, devices=None, dtype=float32, scale=None, qkv_has_bias=False, o_proj_has_bias=False, stacked_qkv=False)

Bases: Module

Multihead attention that handles both single and distributed computation.

Parameters:

  • num_attention_heads (int) – The number of attention heads.
  • hidden_size (int) – The dimension of the hidden states (embed_dim).
  • devices (Sequence[DeviceRef] | None) – Device(s) to place the weights and run the computation. If multiple devices provided, uses distributed computation.
  • dtype (DType) – DType of the QKV and output projection weights.
  • scale (float | None) – Scaling factor applied to the attention scores. If None, a default based on the head dimension is used.
  • qkv_has_bias (bool) – Whether the Q, K, and V projections include a bias term.
  • o_proj_has_bias (bool) – Whether to use a bias on the output projection.
  • stacked_qkv (bool) – Whether to use a single stacked QKV weight matrix.
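
The computation this module performs is standard multihead scaled dot-product attention. The following NumPy sketch is illustrative only, not the MAX API: the function and weight names are hypothetical, and it assumes the common convention that `scale` defaults to `1/sqrt(head_dim)` and that `stacked_qkv=True` means the three projections are stored as one concatenated weight matrix.

```python
import numpy as np

def multihead_attention_sketch(x, num_heads, wq, wk, wv, wo, scale=None):
    """Illustrative forward pass for one unbatched sequence.

    x:  (seq_len, hidden_size) input hidden states.
    wq, wk, wv, wo: (hidden_size, hidden_size) projection weights.
    With stacked_qkv=True, wq/wk/wv would instead be one
    (hidden_size, 3 * hidden_size) matrix split before use.
    """
    seq_len, hidden_size = x.shape
    head_dim = hidden_size // num_heads
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)  # assumed default, per convention

    def split_heads(h):
        # (seq, hidden) -> (heads, seq, head_dim)
        return h.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q = split_heads(x @ wq)
    k = split_heads(x @ wk)
    v = split_heads(x @ wv)

    # Scaled dot-product attention per head.
    scores = (q @ k.transpose(0, 2, 1)) * scale      # (heads, seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)        # softmax over keys
    out = probs @ v                                   # (heads, seq, head_dim)

    # Merge heads and apply the output projection.
    out = out.transpose(1, 0, 2).reshape(seq_len, hidden_size)
    return out @ wo                                   # (seq, hidden_size)
```

The output has the same shape as the input, which is why `hidden_size` must be divisible by `num_attention_heads`.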