Python module
interfaces
General interface for Attention.
AttentionImpl
class max.nn.attention.interfaces.AttentionImpl(n_heads, kv_params, wqkv, wo, scale)
A generalized attention interface that will be used upstream by a general Transformer. We would expect a separate subclass articulating each variation of attention:
- AttentionWithRope
- AttentionWithAlibi
- VanillaAttentionWithCausalMask
- …
There is a set of shared attributes; however, more may be needed for each individual variant. For example, we may introduce a RotaryEmbedding attribute for the AttentionWithRope class:
@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...
We expect the __call__ abstract method to remain relatively consistent; however, the **kwargs argument is exposed, allowing you to pass additional arguments to each particular variant. For example, we may introduce a VanillaAttentionWithCausalMask class, which includes an attention mask:
@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be used downstream like so:
        op(attn_mask=kwargs["attn_mask"])
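From the caller's side, the variant-specific mask simply travels through **kwargs. The sketch below is illustrative only: attention, x, kv_collection, valid_lengths, and causal_mask are placeholder names assumed to be constructed elsewhere; they are not part of this interface.

# Caller-side sketch: pass the extra argument by keyword so that the
# variant's __call__ can pick it up from kwargs["attn_mask"].
output, kv_collection = attention(
    x,
    kv_collection,
    valid_lengths,
    attn_mask=causal_mask,
)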
Parameters:

- n_heads (int)
- kv_params (KVCacheParams)
- wqkv (TensorValue)
- wo (LinearV1)
- scale (float)
kv_params
kv_params: KVCacheParams
The KV cache parameters, including the number of KV heads, the head dimension, and the data type.
n_heads
n_heads: int
The number of attention heads.
scale
scale: float
The scale factor for the attention.
wo
wo: LinearV1
A linear layer for the output projection.
wqkv
wqkv: TensorValue
The concatenation of q, k, and v weight vectors.
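As a rough sketch of how these attributes fit together (not library code): a concrete subclass is constructed with the documented fields, and scale is conventionally 1 / sqrt(head_dim) for scaled dot-product attention. The builder function and the kv_params.head_dim attribute access below are assumptions made for illustration.

import math

def build_rope_attention(n_heads, kv_params, wqkv, wo, rope):
    # Sketch only: wires the documented attributes into the hypothetical
    # AttentionWithRope variant shown above. `kv_params.head_dim` is an
    # assumed attribute name based on the KVCacheParams description.
    return AttentionWithRope(
        n_heads=n_heads,
        kv_params=kv_params,
        wqkv=wqkv,                                  # fused q/k/v weights (TensorValue)
        wo=wo,                                      # output projection (LinearV1)
        scale=1.0 / math.sqrt(kv_params.head_dim),  # conventional SDPA scale
        rope=rope,                                  # variant-specific attribute
    )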
AttentionImplQKV
class max.nn.attention.interfaces.AttentionImplQKV(n_heads, kv_params, wq, wk, wv, wo, scale)
A generalized attention interface that will be used upstream by a general Transformer. We would expect a separate subclass articulating each variation of attention:
- AttentionWithRope
- AttentionWithAlibi
- VanillaAttentionWithCausalMask
- …
There is a set of shared attributes; however, more may be needed for each individual variant. For example, we may introduce a RotaryEmbedding attribute for the AttentionWithRope class:
@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...
We expect the __call__ abstract method to remain relatively consistent; however, the **kwargs argument is exposed, allowing you to pass additional arguments to each particular variant. For example, we may introduce a VanillaAttentionWithCausalMask class, which includes an attention mask:
@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be used downstream like so:
        op(attn_mask=kwargs["attn_mask"])
Parameters:

- n_heads (int)
- kv_params (KVCacheParams)
- wq (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wk (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wv (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
- wo (LinearV1)
- scale (float)
kv_params
kv_params: KVCacheParams
The KV cache parameters, including the number of KV heads, the head dimension, and the data type.
n_heads
n_heads: int
The number of attention heads.
scale
scale: float
The scale factor for the attention.
wk
wk: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
The k weight vector.
wo
wo: LinearV1
A linear layer for the output projection.
wq
wq: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
The q weight vector.
wv
wv: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
The v weight vector.
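The difference from AttentionImpl is that the query, key, and value weights are kept separate here rather than fused into a single wqkv tensor. As a hedged sketch of that relationship, the shapes, dtype, and concatenation axis below are assumptions; the wide parameter types above do accept plain NumPy arrays.

import numpy as np

# Illustrative shapes only.
hidden_dim, n_heads, head_dim = 4096, 32, 128

wq = np.random.randn(hidden_dim, n_heads * head_dim).astype(np.float32)
wk = np.random.randn(hidden_dim, n_heads * head_dim).astype(np.float32)
wv = np.random.randn(hidden_dim, n_heads * head_dim).astype(np.float32)

# AttentionImplQKV takes the three matrices separately; the fused
# AttentionImpl form would instead take their concatenation as `wqkv`.
# The axis used here is a conventional choice, not something this page specifies.
wqkv = np.concatenate([wq, wk, wv], axis=1)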
DistributedAttentionImpl
class max.nn.attention.interfaces.DistributedAttentionImpl
A generalized distributed attention interface.