Python module

interfaces

General interface for Attention.

AttentionImpl

class max.nn.attention.interfaces.AttentionImpl(n_heads, kv_params, wqkv, wo, scale)

A generalized attention interface, used upstream by a general Transformer. Each variation of attention is expected to be a separate subclass, for example:

  • AttentionWithRope
  • AttentionWithAlibi
  • VanillaAttentionWithCausalMask

A set of attributes is shared by all variants, but individual variants may need more. For example, a RotaryEmbedding class might be introduced for the AttentionWithRope class:

@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...
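
The subclassing pattern above can be tried without MAX installed. The sketch below is only an illustration: `AttentionBase`, `RotaryEmbeddingStub`, and `AttentionWithRopeSketch` are hypothetical stand-ins, not classes from the library, and the base class carries just two of the shared attributes. It shows that a dataclass subclass appends its variant-specific field after the inherited ones.

```python
from dataclasses import dataclass, fields


@dataclass
class AttentionBase:
    """Hypothetical stand-in for AttentionImpl (only two shared attributes)."""

    n_heads: int
    scale: float


@dataclass
class RotaryEmbeddingStub:
    """Hypothetical placeholder for a RotaryEmbedding class."""

    dim: int
    theta: float = 10000.0


@dataclass
class AttentionWithRopeSketch(AttentionBase):
    # Variant-specific attribute, declared after the inherited ones.
    rope: RotaryEmbeddingStub


attn = AttentionWithRopeSketch(
    n_heads=8, scale=0.125, rope=RotaryEmbeddingStub(dim=64)
)

# Inherited fields come first, then the fields the subclass adds.
field_names = [f.name for f in fields(attn)]
print(field_names)
```

Because dataclass fields are ordered base class first, a variant can add required attributes such as `rope` as long as the inherited fields have no defaults.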

The __call__ abstract method is expected to stay relatively consistent across variants. The **kwargs argument is exposed so that each variant can accept additional arguments. For example, a VanillaAttentionWithCausalMask class might require an attention mask:

@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: PagedKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, PagedKVCacheCollection]:
        # The mask is variant-specific, so it arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be passed downstream
        # (`op` stands in for whichever operation consumes the mask):
        op(
            attn_mask=kwargs["attn_mask"],
        )
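
As a runnable illustration of the **kwargs pattern, here is a minimal sketch that uses plain Python lists instead of tensors. `AttentionBase` and `MaskedAttentionSketch` are hypothetical stand-ins rather than MAX classes, and the "attention" is a toy that only zeroes out masked positions. The point is the contract: the abstract `__call__` keeps one signature, and the variant validates the extra argument it needs.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class AttentionBase(ABC):
    """Hypothetical stand-in for AttentionImpl."""

    n_heads: int

    @abstractmethod
    def __call__(self, x, **kwargs):
        """Shared call signature; variants read extras from kwargs."""


@dataclass
class MaskedAttentionSketch(AttentionBase):
    def __call__(self, x, **kwargs):
        # Fail early if the variant-specific argument is missing.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to MaskedAttentionSketch")
        mask = kwargs["attn_mask"]
        # Toy computation: keep unmasked positions, zero out the rest.
        return [value if keep else 0.0 for value, keep in zip(x, mask)]


attn = MaskedAttentionSketch(n_heads=2)
out = attn([1.0, 2.0, 3.0], attn_mask=[True, True, False])
print(out)
```

Calling `attn([1.0, 2.0, 3.0])` without `attn_mask` raises `ValueError`, which mirrors the check in the example above.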

Parameters:

kv_params

kv_params: KVCacheParams

KV cache parameters, including the number of KV heads, the head dimension, and the data type.

n_heads

n_heads: int

The number of attention heads.

scale

scale: float

The scale factor for the attention.
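
This page does not say which value is passed as `scale`. A common convention in scaled dot-product attention is 1 / sqrt(head_dim), and the sketch below assumes it. It is a pure-Python toy for one query and two keys, not the library's kernel, and `attention_weights` is a hypothetical helper name.

```python
import math


def attention_weights(query, keys, scale):
    """Softmax over scaled dot products of one query with each key."""
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    # Subtract the maximum before exponentiating for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


head_dim = 4
scale = 1.0 / math.sqrt(head_dim)  # assumed convention, not taken from this page

query = [1.0, 0.0, 1.0, 0.0]
keys = [
    [1.0, 0.0, 1.0, 0.0],  # aligned with the query
    [0.0, 1.0, 0.0, 1.0],  # orthogonal to the query
]
weights = attention_weights(query, keys, scale)
print(weights)
```

The weights sum to 1, and the key aligned with the query receives the larger share. A smaller `scale` flattens the distribution, while a larger one sharpens it.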

wo

wo: LinearV1

A linear layer for the output projection.

wqkv

wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
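
To make the idea of a concatenated weight concrete, the sketch below stacks small q, k, and v matrices row-wise, applies them in one matrix-vector product, and slices the result back apart. The layout (rows in q, k, v order) and the helper `matvec` are assumptions made for illustration; this page only states that `wqkv` is a concatenation.

```python
def matvec(weight, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


# Tiny illustrative weights; k and v have fewer rows than q.
wq = [[1.0, 0.0], [0.0, 1.0]]
wk = [[2.0, 0.0]]
wv = [[0.0, 3.0]]

# Assumed layout: rows of q, then k, then v.
wqkv = wq + wk + wv

x = [1.0, 2.0]
qkv = matvec(wqkv, x)  # one fused projection

# Recover the three projections by slicing at the row boundaries.
q_dim, k_dim = len(wq), len(wk)
q = qkv[:q_dim]
k = qkv[q_dim:q_dim + k_dim]
v = qkv[q_dim + k_dim:]
print(q, k, v)
```

Under this layout, the fused projection gives the same q, k, and v as three separate products, which is why a single concatenated weight can replace the separate weights.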

AttentionImplQKV

class max.nn.attention.interfaces.AttentionImplQKV(n_heads, kv_params, wq, wk, wv, wo, scale)

A generalized attention interface, used upstream by a general Transformer. Each variation of attention is expected to be a separate subclass, for example:

  • AttentionWithRope
  • AttentionWithAlibi
  • VanillaAttentionWithCausalMask

A set of attributes is shared by all variants, but individual variants may need more. For example, a RotaryEmbedding class might be introduced for the AttentionWithRope class:

@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...

The __call__ abstract method is expected to stay relatively consistent across variants. The **kwargs argument is exposed so that each variant can accept additional arguments. For example, a VanillaAttentionWithCausalMask class might require an attention mask:

@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: PagedKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, PagedKVCacheCollection]:
        # The mask is variant-specific, so it arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be passed downstream
        # (`op` stands in for whichever operation consumes the mask):
        op(
            attn_mask=kwargs["attn_mask"],
        )

Parameters:

kv_params

kv_params: KVCacheParams

KV cache parameters, including the number of KV heads, the head dimension, and the data type.

n_heads

n_heads: int

The number of attention heads.

scale

scale: float

The scale factor for the attention.

wk

wk: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]]

The k weight vector.

wo

wo: LinearV1

A linear layer for the output projection.

wq

wq: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]]

The q weight vector.

wv

wv: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]]

The v weight vector.

DistributedAttentionImpl

class max.nn.attention.interfaces.DistributedAttentionImpl

A generalized distributed attention interface.
