
Python module

multihead_attention

MultiheadAttention

class max.nn.legacy.attention.multihead_attention.MultiheadAttention(num_attention_heads, hidden_size, devices=None, dtype=float32, scale=None, qkv_has_bias=False, o_proj_has_bias=False, stacked_qkv=False)

Multihead attention layer that supports both single-device and distributed (multi-device) computation.
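As a rough, framework-independent illustration of what a multihead attention layer computes, the pure-Python sketch below projects the input to queries, keys, and values, splits them into `num_heads` heads, applies scaled softmax attention per head, and concatenates the head outputs before the output projection. It is not the MAX implementation. It runs on plain lists, applies no causal mask, and assumes weights stored in [out_features, in_features] layout. It also assumes that `scale=None` falls back to 1/sqrt(head_dim), which is the conventional choice but is not confirmed by this page.

```python
import math


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    mx = max(row)
    exps = [math.exp(v - mx) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multihead_attention(x, wq, wk, wv, wo, num_heads, scale=None):
    """x: (seq_len x hidden); each weight: (hidden x hidden) in [out, in] layout."""
    hidden = len(x[0])
    head_dim = hidden // num_heads
    if scale is None:
        # Assumption: scale=None means the conventional 1/sqrt(head_dim).
        scale = 1.0 / math.sqrt(head_dim)
    # Linear projections: y = x @ W^T for weights stored as [out, in].
    q = matmul(x, transpose(wq))
    k = matmul(x, transpose(wk))
    v = matmul(x, transpose(wv))
    concat_heads = [[] for _ in x]  # per-position concatenation of head outputs
    for h in range(num_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        qh = [row[lo:hi] for row in q]
        kh = [row[lo:hi] for row in k]
        vh = [row[lo:hi] for row in v]
        scores = matmul(qh, transpose(kh))  # (seq_len x seq_len)
        probs = [softmax([s * scale for s in row]) for row in scores]
        out_h = matmul(probs, vh)           # (seq_len x head_dim)
        for i, row in enumerate(out_h):
            concat_heads[i].extend(row)
    # Output projection back to the hidden size.
    return matmul(concat_heads, transpose(wo))
```

For example, with identity weight matrices and a single token, each head can only attend to that token, so the output reproduces the input.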

Parameters:

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
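A concatenated QKV weight lets the three input projections run as one matrix multiply whose output is then split back into q, k, and v. The pure-Python sketch below illustrates that equivalence under a stated assumption: each weight is stored in [out_features, in_features] layout and the three are stacked along the output-feature axis (axis 0). The helper names and toy values are hypothetical and are not part of the MAX API.

```python
def linear(x, w):
    """y = x @ w^T for a weight stored in [out_features, in_features] layout."""
    return [[sum(xi * wi for xi, wi in zip(row, wrow)) for wrow in w] for row in x]


def concat_rows(*mats):
    """Stack matrices along axis 0, i.e. the output-feature axis (assumed layout)."""
    return [list(row) for m in mats for row in m]


def split_columns(y, sizes):
    """Split each row of y into consecutive chunks of the given widths."""
    parts, start = [], 0
    for size in sizes:
        parts.append([row[start:start + size] for row in y])
        start += size
    return parts


# Toy 2-dim projections (hypothetical values).
wq = [[1.0, 0.0], [0.0, 1.0]]
wk = [[2.0, 1.0], [0.0, 3.0]]
wv = [[0.5, -1.0], [1.0, 1.0]]
x = [[1.0, 2.0], [3.0, -1.0]]

wqkv = concat_rows(wq, wk, wv)  # shape (6, 2): q rows, then k rows, then v rows
# One fused projection, then split the output columns back into q, k, v.
q, k, v = split_columns(linear(x, wqkv), [2, 2, 2])
```

The fused result matches the three separate projections exactly, which is why a single concatenated weight can stand in for three.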

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of q, k, and v bias weight vectors.
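The bias property follows the same pattern: the q, k, and v bias vectors are concatenated so that a single bias add can follow the fused projection. The sketch below assumes, based only on the `TensorValue | None` return type and the `qkv_has_bias` constructor flag, that the property is `None` when the projections have no bias. `concat_qkv_bias` and `add_bias` are hypothetical helpers, not MAX functions, and the values are toy data.

```python
from typing import List, Optional


def concat_qkv_bias(q_bias: List[float], k_bias: List[float], v_bias: List[float],
                    qkv_has_bias: bool) -> Optional[List[float]]:
    """Hypothetical mirror of wqkv_bias: assumed to be None without a QKV bias."""
    if not qkv_has_bias:
        return None
    return [*q_bias, *k_bias, *v_bias]


def add_bias(rows: List[List[float]], bias: List[float]) -> List[List[float]]:
    """Add a bias vector to every row."""
    return [[val + b for val, b in zip(row, bias)] for row in rows]


# Unbiased projection outputs for one token (hypothetical values).
q_raw, k_raw, v_raw = [[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]
q_b, k_b, v_b = [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]

# Fused path: concatenate the outputs, then add the concatenated bias once.
fused_raw = [q_raw[0] + k_raw[0] + v_raw[0]]
fused_bias = concat_qkv_bias(q_b, k_b, v_b, qkv_has_bias=True)
fused = add_bias(fused_raw, fused_bias)

# Separate path: bias each projection on its own, then concatenate.
separate = add_bias(q_raw, q_b)[0] + add_bias(k_raw, k_b)[0] + add_bias(v_raw, v_b)[0]
```

Both paths produce the same vector, so the concatenated bias is a drop-in replacement for the three individual biases.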
