
Python module

kernels

Helper functions for wrapping custom KV cache and attention-related ops.

AttentionMaskVariant

class max.pipelines.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

CAUSAL_MASK

CAUSAL_MASK = 'causal_mask'

NULL_MASK

NULL_MASK = 'null_mask'

TENSOR_MASK

TENSOR_MASK = 'tensor_mask'

MHAMaskConfig

class max.pipelines.nn.kernels.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')

attention_mask_variant

attention_mask_variant: AttentionMaskVariant

positional_encoding_variant

positional_encoding_variant: PositionalEncodingVariant

MHAMaskVariant

class max.pipelines.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

CAUSAL_ALIBI_MASK

CAUSAL_ALIBI_MASK = '1'

CAUSAL_MASK

CAUSAL_MASK = '0'

NULL_MASK

NULL_MASK = '2'

PositionalEncodingVariant

class max.pipelines.nn.kernels.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

ALIBI_POS

ALIBI_POS = 'alibi_pos'

NO_POS

NO_POS = 'no_pos'

cross_attention_ragged()

max.pipelines.nn.kernels.cross_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, kv_input_row_offsets: TensorValue, q_max_seq_len: TensorValue) → TensorValue

Computes cross attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

Unlike self attention, kv_input_row_offsets represents the KV sequence length, which may differ from the Q sequence length.
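
As a rough illustration only, the call below sketches how the ragged layout maps onto this kernel. kv_params, q, q_row_offsets, kv_row_offsets, kv_collection, layer_idx, and q_max_seq_len are placeholder names for values assumed to have been built elsewhere in the graph; the keyword names follow the documented signature.

    from max.pipelines.nn.kernels import MHAMaskVariant, cross_attention_ragged

    # Ragged layout: e.g. two sequences of length 3 and 5 packed back to back,
    # with q_row_offsets = [0, 3, 8] marking where each one starts and ends.
    # The KV side has its own offsets because its sequence lengths may differ.
    attn_out = cross_attention_ragged(
        kv_params,                                # assumed KVCacheParams
        input=q,                                  # assumed ragged query TensorValue
        input_row_offsets=q_row_offsets,
        kv_collection=kv_collection,              # continuous-batching or paged collection
        layer_idx=layer_idx,                      # assumed TensorValue holding the layer index
        mask_variant=MHAMaskVariant.CAUSAL_MASK,
        kv_input_row_offsets=kv_row_offsets,
        q_max_seq_len=q_max_seq_len,
    )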

flash_attention()

max.pipelines.nn.kernels.flash_attention(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, attention_mask: TensorValue, valid_lengths: TensorValue) → TensorValue

Computes flash attention provided the mo.opaque KV Cache.
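
A minimal sketch of a call, assuming a padded (non-ragged) batch: kv_params, q, kv_collection, layer_idx, attention_mask, and valid_lengths are placeholders for values assumed to exist elsewhere in the graph.

    from max.pipelines.nn.kernels import flash_attention

    # valid_lengths gives the true (unpadded) length of each sequence in the batch;
    # the attention mask is supplied explicitly here rather than built in-kernel.
    attn_out = flash_attention(
        kv_params,
        input=q,
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        attention_mask=attention_mask,
        valid_lengths=valid_lengths,
    )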

flash_attention_ragged()

max.pipelines.nn.kernels.flash_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant) → TensorValue

Computes flash (self) attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
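
A minimal sketch, with x, row_offsets, kv_params, kv_collection, and layer_idx assumed to be graph values built elsewhere:

    from max.pipelines.nn.kernels import MHAMaskVariant, flash_attention_ragged

    # Self attention over a ragged batch: sequences of length 3 and 5 packed into
    # one tensor, with row_offsets = [0, 3, 8] marking the batch boundaries.
    attn_out = flash_attention_ragged(
        kv_params,
        input=x,
        input_row_offsets=row_offsets,
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        mask_variant=MHAMaskVariant.CAUSAL_MASK,
    )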

flash_attention_with_causal_mask()

max.pipelines.nn.kernels.flash_attention_with_causal_mask(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, valid_lengths: TensorValue) → TensorValue

Computes flash attention provided the mo.opaque KV Cache. Notably, this materializes the causal mask within the kernel.
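
A sketch of the same call shape as flash_attention, minus the explicit mask (placeholder names as above):

    from max.pipelines.nn.kernels import flash_attention_with_causal_mask

    # No attention_mask argument: the causal mask is generated inside the kernel.
    attn_out = flash_attention_with_causal_mask(
        kv_params,
        input=q,
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        valid_lengths=valid_lengths,
    )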

fused_qk_ragged_rope()

max.pipelines.nn.kernels.fused_qk_ragged_rope(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, freqs_cis: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue

Computes fused query-key attention with rotary positional encodings and ragged inputs.

input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
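
A hedged sketch; x, row_offsets, freqs_cis, kv_params, kv_collection, and layer_idx are placeholders assumed to be built elsewhere, with freqs_cis a precomputed table of rotary frequencies:

    from max.pipelines.nn.kernels import fused_qk_ragged_rope

    q_roped = fused_qk_ragged_rope(
        kv_params,
        input=x,                       # ragged: all sequences packed back to back
        input_row_offsets=row_offsets,
        kv_collection=kv_collection,
        freqs_cis=freqs_cis,
        layer_idx=layer_idx,
        interleaved=True,              # documented default
    )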

fused_qk_rope()

max.pipelines.nn.kernels.fused_qk_rope(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, freqs_cis_2d: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue

Computes fused query-key attention with rotary positional encodings.
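
The non-ragged counterpart; same placeholder convention, with freqs_cis_2d assumed to be a precomputed 2-D rotary frequency table:

    from max.pipelines.nn.kernels import fused_qk_rope

    q_roped = fused_qk_rope(
        kv_params,
        input=x,
        kv_collection=kv_collection,
        freqs_cis_2d=freqs_cis_2d,
        layer_idx=layer_idx,
        interleaved=True,              # documented default
    )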

fused_qkv_matmul()

max.pipelines.nn.kernels.fused_qkv_matmul(kv_params: KVCacheParams, input: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, n_heads: int) → TensorValue

Computes fused query, key and value projections.
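
A sketch, assuming wqkv is the concatenated Q/K/V projection weight and the other names are graph values built elsewhere:

    from max.pipelines.nn.kernels import fused_qkv_matmul

    xq = fused_qkv_matmul(
        kv_params,
        input=x,
        wqkv=wqkv,                     # assumed: concatenated Q/K/V weight
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        n_heads=n_heads,               # e.g. 32 for a Llama-style model
    )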

fused_qkv_ragged_matmul()

max.pipelines.nn.kernels.fused_qkv_ragged_matmul(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, bias: TensorValue | None = None) → TensorValue

Computes fused query, key, and value projections with ragged input.

input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

  • Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.
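
A sketch under the same placeholder convention as the other ragged kernels; the optional bias argument is simply omitted here:

    from max.pipelines.nn.kernels import fused_qkv_ragged_matmul

    xq = fused_qkv_ragged_matmul(
        kv_params,
        input=x,                       # ragged: all sequences packed back to back
        input_row_offsets=row_offsets,
        wqkv=wqkv,
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        n_heads=n_heads,
    )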

fused_qkv_ragged_matmul_quantized()

max.pipelines.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, quantization_config: QuantizationConfig, perm_idx: TensorValue | None = None, bias: TensorValue | None = None) → TensorValue

Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.

input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.

  • Raises:

    ValueError – on input shapes/dtypes that are invalid for the kernel.
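
As above, but with quantized weights; quantization_config is assumed to be a QuantizationConfig describing how wqkv_quantized is encoded, and perm_idx is left at its documented default:

    from max.pipelines.nn.kernels import fused_qkv_ragged_matmul_quantized

    xq = fused_qkv_ragged_matmul_quantized(
        kv_params,
        input=x,
        input_row_offsets=row_offsets,
        wqkv=wqkv_quantized,
        kv_collection=kv_collection,
        layer_idx=layer_idx,
        n_heads=n_heads,
        quantization_config=quantization_config,
    )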

matmul_kv_cache_ragged()

max.pipelines.nn.kernels.matmul_kv_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: int | integer) → None

Computes key and value projections with ragged input.

hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in hidden_states.
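
A sketch; note the None return and that layer_idx here is a plain integer rather than a TensorValue. hidden_states, row_offsets, and w_kv are placeholders assumed to be built elsewhere:

    from max.pipelines.nn.kernels import matmul_kv_cache_ragged

    # Returns None; the projections presumably land in kv_collection in place.
    matmul_kv_cache_ragged(
        kv_params,
        hidden_states=hidden_states,
        input_row_offsets=row_offsets,
        weight=w_kv,                   # assumed: combined K/V projection weight
        kv_collection=kv_collection,
        layer_idx=0,
    )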

rms_norm_key_cache()

max.pipelines.nn.kernels.rms_norm_key_cache(kv_params: KVCacheParams, kv_collection: ContinuousBatchingKVCacheCollection, gamma: TensorValue, epsilon: float | floating, layer_idx: int | integer, total_seq_len: Dim, input_row_offsets: TensorValue) → None

Computes RMSNorm on the new entries in the KVCache.

Currently, the KVCacheT class itself isn't aware of the new cache entries until the cache length is incremented, which happens after the model forward pass, so input_row_offsets is used to do this bookkeeping.
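
A sketch; gamma, total_seq_len, and row_offsets are placeholders (gamma being the RMSNorm scale and total_seq_len a symbolic Dim from the graph), and the epsilon value is illustrative:

    from max.pipelines.nn.kernels import rms_norm_key_cache

    # Returns None: normalizes only the newly written cache entries in place.
    rms_norm_key_cache(
        kv_params,
        kv_collection=kv_collection,
        gamma=gamma,
        epsilon=1e-6,                  # illustrative value
        layer_idx=0,
        total_seq_len=total_seq_len,
        input_row_offsets=row_offsets,
    )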

swish_glu()

max.pipelines.nn.kernels.swish_glu(a: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b0: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b1: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue

Computes swish(a @ b0.t()) * (a @ b1.t()).
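
A sketch of the SwiGLU-style gate this computes, with a, b0, and b1 assumed to be TensorValues (or array-likes accepted per the signature) built elsewhere; swish(x) = x * sigmoid(x):

    from max.pipelines.nn.kernels import swish_glu

    # Computes swish(a @ b0.t()) * (a @ b1.t()) as a single fused op.
    gated = swish_glu(a, b0, b1)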