Python module: multihead_attention
MultiheadAttention
class max.nn.legacy.attention.multihead_attention.MultiheadAttention(num_attention_heads, hidden_size, devices=None, dtype=float32, scale=None, qkv_has_bias=False, o_proj_has_bias=False, stacked_qkv=False)
Multihead attention that handles both single-device and distributed computation.
Parameters:
- num_attention_heads – The number of attention heads.
- hidden_size – The dimension of the hidden states.
- devices – The devices on which to place the weights and run the computation. Defaults to None.
- dtype – The data type of the weights. Defaults to float32.
- scale – The scale factor applied to the attention scores. Defaults to None.
- qkv_has_bias – Whether the QKV projections have bias terms. Defaults to False.
- o_proj_has_bias – Whether the output projection has a bias term. Defaults to False.
- stacked_qkv – Whether the Q, K, and V weights are stored as a single stacked weight. Defaults to False.
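To make the computation concrete, here is a minimal NumPy sketch of what a multihead attention layer does on a single device. This is not the MAX implementation: the fused-QKV layout, the absence of biases and output projection, and the default scale of 1/sqrt(head_dim) are all assumptions made to keep the sketch short.

```python
import numpy as np

def multihead_attention(x, wqkv, num_heads):
    """Sketch of single-device multihead self-attention.

    x:    (seq, hidden) input activations.
    wqkv: (hidden, 3*hidden) fused Q/K/V projection (layout assumed).
    """
    seq, hidden = x.shape
    head_dim = hidden // num_heads

    # One matmul produces Q, K, and V; split them back apart.
    q, k, v = np.split(x @ wqkv, 3, axis=-1)

    # Reshape (seq, hidden) -> (num_heads, seq, head_dim).
    def split_heads(t):
        return t.reshape(seq, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = map(split_heads, (q, k, v))

    # Scaled dot-product attention per head (scale assumed 1/sqrt(head_dim)).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)

    out = weights @ v  # (num_heads, seq, head_dim)
    # Merge the heads back into a (seq, hidden) output.
    return out.transpose(1, 0, 2).reshape(seq, hidden)
```

The per-head attention weights are a softmax over the key axis, so the heads attend independently and their outputs are concatenated along the hidden dimension.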
wqkv
property wqkv: TensorValue
The concatenation of the q, k, and v weight tensors.
wqkv_bias
property wqkv_bias: TensorValue | None
The concatenation of the q, k, and v bias vectors, if present.
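The stacked-QKV idea behind these properties can be illustrated with plain NumPy. The per-projection weight shapes and the concatenation axis below are assumptions for illustration, not the documented MAX layout.

```python
import numpy as np

hidden = 8
rng = np.random.default_rng(0)

# Hypothetical per-projection weights and biases; (in_dim, out_dim)
# layout and axis-1 concatenation are assumptions.
wq, wk, wv = (rng.standard_normal((hidden, hidden)) for _ in range(3))
bq, bk, bv = (rng.standard_normal(hidden) for _ in range(3))

wqkv = np.concatenate([wq, wk, wv], axis=1)  # (hidden, 3*hidden)
wqkv_bias = np.concatenate([bq, bk, bv])     # (3*hidden,)

# One fused projection, then split back into Q, K, and V.
x = rng.standard_normal((4, hidden))
q, k, v = np.split(x @ wqkv + wqkv_bias, 3, axis=-1)

# Each split matches the corresponding standalone projection.
assert np.allclose(q, x @ wq + bq)
```

Fusing the three projections into one weight lets a single matmul produce Q, K, and V, which is typically cheaper than three separate matmuls.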