
Python module

interfaces

General interface for Attention.

AttentionImpl

class max.pipelines.nn.attention.interfaces.AttentionImpl(n_heads: int, kv_params: KVCacheParams, layer_idx: TensorValue, wqkv: TensorValue, wo: Linear)

A generalized attention interface that will be used upstream by a general Transformer. We expect a separate subclass for each variation of attention:

  • AttentionWithRope
  • AttentionWithAlibi
  • VanillaAttentionWithCausalMask

The variants share a set of common attributes, but an individual variant may need more. For example, AttentionWithRope might add an OptimizedRotaryEmbedding attribute:

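```python
from dataclasses import dataclass

from max.pipelines.nn.attention.interfaces import AttentionImpl

@dataclass
class AttentionWithRope(AttentionImpl):
    rope: OptimizedRotaryEmbedding  # hypothetical variant-specific class
    ...
```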

We expect the __call__ abstract method to remain relatively consistent across variants. However, it exposes **kwargs, so each variant can accept additional arguments. For example, a VanillaAttentionWithCausalMask class might require an attention mask:

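```python
from __future__ import annotations

from dataclasses import dataclass

from max.pipelines.nn.attention.interfaces import AttentionImpl

# TensorValue, TensorValueLike, and ContinuousBatchingKVCacheCollection
# are assumed to be imported from the MAX packages.


@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        # The variant-specific mask arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be used downstream; `op` stands in
        # for whichever attention op the variant calls:
        op(
            attn_mask=kwargs["attn_mask"],
        )
        ...
```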
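The **kwargs contract above can be exercised without MAX. The following is a minimal NumPy sketch, not the MAX implementation: ToyAttentionImpl, ToyVanillaAttentionWithCausalMask, and causal_mask are hypothetical names, the input is used directly as q, k, and v (no projections, no KV cache), and the mask is assumed to be additive (0 keeps a position, -inf blocks it).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyAttentionImpl(ABC):
    """Hypothetical stand-in for the shared attributes of an attention interface."""

    n_heads: int
    head_dim: int

    @abstractmethod
    def __call__(self, x: np.ndarray, **kwargs) -> np.ndarray: ...


@dataclass
class ToyVanillaAttentionWithCausalMask(ToyAttentionImpl):
    def __call__(self, x: np.ndarray, **kwargs) -> np.ndarray:
        # The variant-specific argument arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")
        mask = kwargs["attn_mask"]  # (seq_len, seq_len) additive mask

        seq_len = x.shape[0]
        # Split the hidden dimension into heads: (n_heads, seq_len, head_dim).
        heads = x.reshape(seq_len, self.n_heads, self.head_dim).transpose(1, 0, 2)
        scores = heads @ heads.transpose(0, 2, 1) / np.sqrt(self.head_dim)
        scores = scores + mask  # broadcast the mask over all heads
        # Numerically stable softmax over the key axis; -inf entries become 0.
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out = weights @ heads
        # Merge heads back into (seq_len, n_heads * head_dim).
        return out.transpose(1, 0, 2).reshape(seq_len, self.n_heads * self.head_dim)


def causal_mask(seq_len: int) -> np.ndarray:
    # -inf strictly above the diagonal, so position i only sees positions <= i.
    return np.triu(np.full((seq_len, seq_len), -np.inf), k=1)


attn = ToyVanillaAttentionWithCausalMask(n_heads=2, head_dim=4)
x = np.random.default_rng(0).standard_normal((5, 8))
out = attn(x, attn_mask=causal_mask(5))
print(out.shape)  # (5, 8)
```

Calling attn(x) without attn_mask raises ValueError, which is how a variant can enforce its extra argument while keeping the shared call signature.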

kv_params

kv_params: KVCacheParams

KV cache parameters, including the number of KV heads, the head dimension, and the data type.

layer_idx

layer_idx: TensorValue

The layer number associated with this Attention block.

n_heads

n_heads: int

The number of attention heads.

wo

wo: Linear

A linear layer for the output projection.

wqkv

wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
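To illustrate what a fused wqkv weight implies, the following NumPy sketch splits the output of a single fused projection back into per-head q, k, and v. It assumes an (out_features, in_features) weight layout with q, k, and v stacked in that order along the output axis, and fewer KV heads than query heads (grouped-query attention); the actual MAX layout may differ. split_fused_qkv is a hypothetical helper.

```python
import numpy as np


def split_fused_qkv(qkv: np.ndarray, n_heads: int, n_kv_heads: int, head_dim: int):
    """Split a fused (seq_len, q_dim + 2 * kv_dim) projection into q, k and v.

    Assumes the fused weight stacks q, then k, then v along the output axis.
    """
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    q, k, v = np.split(qkv, [q_dim, q_dim + kv_dim], axis=-1)
    seq_len = qkv.shape[0]
    return (
        q.reshape(seq_len, n_heads, head_dim),
        k.reshape(seq_len, n_kv_heads, head_dim),
        v.reshape(seq_len, n_kv_heads, head_dim),
    )


n_heads, n_kv_heads, head_dim, hidden = 8, 2, 16, 128
rng = np.random.default_rng(0)
wqkv = rng.standard_normal(((n_heads + 2 * n_kv_heads) * head_dim, hidden))
x = rng.standard_normal((10, hidden))

# A single matmul produces q, k and v together; the split is pure slicing.
q, k, v = split_fused_qkv(x @ wqkv.T, n_heads, n_kv_heads, head_dim)
print(q.shape, k.shape, v.shape)  # (10, 8, 16) (10, 2, 16) (10, 2, 16)
```

Fusing the three weights trades one larger matmul for three smaller ones; the cost is that callers must agree on the stacking order.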

AttentionImplQKV

class max.pipelines.nn.attention.interfaces.AttentionImplQKV(n_heads: int, kv_params: KVCacheParams, layer_idx: int, wq: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, wk: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, wv: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, wo: Linear)

A generalized attention interface that will be used upstream by a general Transformer. We expect a separate subclass for each variation of attention:

  • AttentionWithRope
  • AttentionWithAlibi
  • VanillaAttentionWithCausalMask

The variants share a set of common attributes, but an individual variant may need more. For example, AttentionWithRope might add an OptimizedRotaryEmbedding attribute:

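```python
from dataclasses import dataclass

from max.pipelines.nn.attention.interfaces import AttentionImpl

@dataclass
class AttentionWithRope(AttentionImpl):
    rope: OptimizedRotaryEmbedding  # hypothetical variant-specific class
    ...
```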

We expect the __call__ abstract method to remain relatively consistent across variants. However, it exposes **kwargs, so each variant can accept additional arguments. For example, a VanillaAttentionWithCausalMask class might require an attention mask:

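```python
from __future__ import annotations

from dataclasses import dataclass

from max.pipelines.nn.attention.interfaces import AttentionImpl

# TensorValue, TensorValueLike, and ContinuousBatchingKVCacheCollection
# are assumed to be imported from the MAX packages.


@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        # The variant-specific mask arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be used downstream; `op` stands in
        # for whichever attention op the variant calls:
        op(
            attn_mask=kwargs["attn_mask"],
        )
        ...
```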

kv_params

kv_params: KVCacheParams

KV cache parameters, including the number of KV heads, the head dimension, and the data type.

layer_idx

layer_idx: int

The layer number associated with this Attention block.

n_heads

n_heads: int

The number of attention heads.

wk

wk: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

The k weight vector.

wo

wo: Linear

A linear layer for the output projection.

wq

wq: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

The q weight vector.

wv

wv: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

The v weight vector.

AttentionImplV2

class max.pipelines.nn.attention.interfaces.AttentionImplV2(num_attention_heads: int, num_key_value_heads: int, hidden_size: int, kv_params: KVCacheParams, layer_idx: int, dtype: DType = DType.float32, device: DeviceRef = cpu:0, linear_cls: type[LinearV2] = LinearV2)

A generalized attention interface that will be used upstream by a general Transformer. We expect a separate subclass for each variation of attention:

  • AttentionWithRope
  • AttentionWithAlibi
  • VanillaAttentionWithCausalMask

AttentionImplV2 will replace AttentionImpl as we roll out changes to the Layer API.

The variants share a set of common attributes, but an individual variant may need more. For example, AttentionWithRope might add an OptimizedRotaryEmbedding attribute:

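```python
from dataclasses import dataclass

from max.pipelines.nn.attention.interfaces import AttentionImplV2

@dataclass
class AttentionWithRope(AttentionImplV2):
    rope: OptimizedRotaryEmbedding  # hypothetical variant-specific class
    ...
```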

We expect the __call__ abstract method to remain relatively consistent across variants. However, it exposes **kwargs, so each variant can accept additional arguments. For example, a VanillaAttentionWithCausalMask class might require an attention mask:

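```python
from __future__ import annotations

from max.pipelines.nn.attention.interfaces import AttentionImplV2

# TensorValue, TensorValueLike, and ContinuousBatchingKVCacheCollection
# are assumed to be imported from the MAX packages.


class VanillaAttentionWithCausalMask(AttentionImplV2):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        # The variant-specific mask arrives through **kwargs.
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # The attention mask can then be used downstream; `op` stands in
        # for whichever attention op the variant calls:
        op(
            attn_mask=kwargs["attn_mask"],
        )
        ...
```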

Initializes the attention layer.

  • Parameters:

    • num_attention_heads – The number of attention heads.
    • num_key_value_heads – The number of key/value heads.
    • hidden_size – The dimension of the hidden states.
    • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
    • layer_idx – The layer number associated with this Attention block.
    • dtype – DType of the
    • device – Device on which to place the weights and run the computation.
    • linear_cls – Linear class to use for the output dense layer.
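The constructor parameters above are related to one another. The sketch below derives the sizes they imply under one common convention, head_dim = hidden_size // num_attention_heads. This convention is an assumption, because kv_params carries its own head dimension. AttentionDims is a hypothetical helper and not part of MAX.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionDims:
    """Hypothetical helper: sizes implied by grouped-query attention parameters."""

    num_attention_heads: int
    num_key_value_heads: int
    hidden_size: int

    def __post_init__(self) -> None:
        # Both divisions below must be exact for the layout to make sense.
        if self.hidden_size % self.num_attention_heads:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        if self.num_attention_heads % self.num_key_value_heads:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")

    @property
    def head_dim(self) -> int:
        # Assumed convention: the hidden size is split evenly across query heads.
        return self.hidden_size // self.num_attention_heads

    @property
    def kv_dim(self) -> int:
        # Width of the k (or v) projection output.
        return self.num_key_value_heads * self.head_dim

    @property
    def group_size(self) -> int:
        # Number of query heads that share one key/value head.
        return self.num_attention_heads // self.num_key_value_heads

    @property
    def fused_qkv_out(self) -> int:
        # Output width of a fused q/k/v projection.
        return self.num_attention_heads * self.head_dim + 2 * self.kv_dim


dims = AttentionDims(num_attention_heads=32, num_key_value_heads=8, hidden_size=4096)
print(dims.head_dim, dims.group_size, dims.fused_qkv_out)  # 128 4 6144
```

When num_key_value_heads equals num_attention_heads, group_size is 1 and the layout reduces to standard multi-head attention.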

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.
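Because wqkv is a property here, the fused weight can be derived on demand from separately stored q, k, and v weights. The NumPy sketch below shows that pattern and checks that one matmul against the concatenated weight matches three separate projections. ToySeparateQKV is hypothetical, and the (out_features, in_features) weight layout is an assumption.

```python
import numpy as np


class ToySeparateQKV:
    """Hypothetical stand-in: stores q, k, v weights separately and exposes
    the fused weight as a read-only property, mirroring `wqkv`."""

    def __init__(self, wq: np.ndarray, wk: np.ndarray, wv: np.ndarray) -> None:
        self.wq, self.wk, self.wv = wq, wk, wv

    @property
    def wqkv(self) -> np.ndarray:
        # Stack along the output-feature axis (axis 0 for (out, in) weights).
        return np.concatenate([self.wq, self.wk, self.wv], axis=0)


rng = np.random.default_rng(0)
hidden, q_dim, kv_dim = 64, 64, 16
layer = ToySeparateQKV(
    rng.standard_normal((q_dim, hidden)),
    rng.standard_normal((kv_dim, hidden)),
    rng.standard_normal((kv_dim, hidden)),
)
x = rng.standard_normal((3, hidden))

fused = x @ layer.wqkv.T
separate = np.concatenate([x @ layer.wq.T, x @ layer.wk.T, x @ layer.wv.T], axis=-1)
print(layer.wqkv.shape, np.allclose(fused, separate))  # (96, 64) True
```

A property keeps a single source of truth (the separate weights) at the cost of rebuilding the concatenation on each access.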

DistributedAttentionImpl

class max.pipelines.nn.attention.interfaces.DistributedAttentionImpl

A generalized distributed attention interface.