Python module
kernels
Helper functions for wrapping custom KV cache/attention-related ops.
AttentionMaskVariant
class max.pipelines.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_MASK
CAUSAL_MASK = 'causal_mask'
NULL_MASK
NULL_MASK = 'null_mask'
TENSOR_MASK
TENSOR_MASK = 'tensor_mask'
MHAMaskConfig
class max.pipelines.nn.kernels.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')
attention_mask_variant
attention_mask_variant: AttentionMaskVariant
positional_encoding_variant
positional_encoding_variant: PositionalEncodingVariant
MHAMaskVariant
class max.pipelines.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_ALIBI_MASK
CAUSAL_ALIBI_MASK = '1'
CAUSAL_MASK
CAUSAL_MASK = '0'
NULL_MASK
NULL_MASK = '2'
PositionalEncodingVariant
class max.pipelines.nn.kernels.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALIBI_POS
ALIBI_POS = 'alibi_pos'
NO_POS
NO_POS = 'no_pos'
cross_attention_ragged()
max.pipelines.nn.kernels.cross_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, kv_input_row_offsets: TensorValue, q_max_seq_len: TensorValue) → TensorValue
Computes cross attention over the provided !mo.opaque KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input. Unlike self attention, kv_input_row_offsets represents the KV sequence length.
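For orientation, a minimal call sketch follows. It assumes the graph-building context and all inputs (kv_params, the ragged query TensorValue, both row-offset tensors, the KV collection, the layer index, and the max query length) have already been constructed elsewhere; only the call itself mirrors the documented signature.

```python
from max.pipelines.nn.kernels import MHAMaskVariant, cross_attention_ragged

# Assumed to already exist inside a max.graph graph-building context:
#   kv_params       - KVCacheParams for the model
#   q_input         - ragged query activations, flattened over the batch
#   q_row_offsets   - offsets marking where each sequence starts/ends in q_input
#   kv_collection   - ContinuousBatchingKVCacheCollection or PagedKVCacheCollection
#   layer_idx       - scalar TensorValue selecting this layer's cache
#   kv_row_offsets  - offsets describing the KV sequence lengths
#   q_max_seq_len   - scalar TensorValue with the longest query length
attn_out = cross_attention_ragged(
    kv_params=kv_params,
    input=q_input,
    input_row_offsets=q_row_offsets,
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    mask_variant=MHAMaskVariant.NULL_MASK,
    kv_input_row_offsets=kv_row_offsets,
    q_max_seq_len=q_max_seq_len,
)
```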
flash_attention()
max.pipelines.nn.kernels.flash_attention(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, attention_mask: TensorValue, valid_lengths: TensorValue) → TensorValue
Computes flash attention over the provided mo.opaque KV cache.
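A hedged sketch of the padded-path call is below; the attention mask, valid_lengths tensor, and the other inputs are assumed to have been built earlier in the graph, so only the call itself follows the documented signature.

```python
from max.pipelines.nn.kernels import flash_attention

# Assumed pre-built: kv_params, xq (padded query activations), kv_collection,
# layer_idx, attention_mask (an explicitly materialized mask TensorValue), and
# valid_lengths (per-sequence unpadded lengths).
attn_out = flash_attention(
    kv_params=kv_params,
    input=xq,
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    attention_mask=attention_mask,
    valid_lengths=valid_lengths,
)
```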
flash_attention_ragged()
max.pipelines.nn.kernels.flash_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant) → TensorValue
Computes flash (self) attention over the provided !mo.opaque KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
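A hedged sketch of the ragged self-attention call follows; the ragged activations, row offsets, and cache objects are assumed to have been produced earlier in the graph (for example by the fused QKV and RoPE kernels on this page).

```python
from max.pipelines.nn.kernels import MHAMaskVariant, flash_attention_ragged

# Assumed pre-built: kv_params, xq (ragged query activations flattened over the
# batch), input_row_offsets (marks where each sequence starts and ends in xq),
# kv_collection, and layer_idx.
attn_out = flash_attention_ragged(
    kv_params=kv_params,
    input=xq,
    input_row_offsets=input_row_offsets,
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    mask_variant=MHAMaskVariant.CAUSAL_MASK,  # causal mask materialized in-kernel
)
```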
flash_attention_with_causal_mask()
max.pipelines.nn.kernels.flash_attention_with_causal_mask(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, valid_lengths: TensorValue) → TensorValue
Computes flash attention over the provided mo.opaque KV cache. Notably, this materializes the causal mask within the kernel.
fused_qk_ragged_rope()
max.pipelines.nn.kernels.fused_qk_ragged_rope(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, freqs_cis: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Computes fused query-key attention with rotary positional encodings and ragged inputs.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.
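A sketch of the ragged RoPE call is shown below; the frequency table and the other inputs are assumed to be TensorValues already defined in the graph, and interleaved is left at its documented default. The returned TensorValue is bound to a local name for downstream use.

```python
from max.pipelines.nn.kernels import fused_qk_ragged_rope

# Assumed pre-built: kv_params, xq (ragged query activations), input_row_offsets,
# kv_collection, freqs_cis (rotary frequency table), and layer_idx.
xq_roped = fused_qk_ragged_rope(
    kv_params=kv_params,
    input=xq,
    input_row_offsets=input_row_offsets,
    kv_collection=kv_collection,
    freqs_cis=freqs_cis,
    layer_idx=layer_idx,
    interleaved=True,  # documented default
)
```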
fused_qk_rope()
max.pipelines.nn.kernels.fused_qk_rope(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, freqs_cis_2d: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Computes fused query-key attention with rotary positional encodings.
fused_qkv_matmul()
max.pipelines.nn.kernels.fused_qkv_matmul(kv_params: KVCacheParams, input: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, n_heads: int) → TensorValue
Computes fused query, key and value projections.
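A sketch of the padded (non-ragged) fused QKV projection follows; the activations, fused weight, and cache objects are assumed to have been created elsewhere in the graph, and n_heads matches the model configuration. The result name is illustrative only.

```python
from max.pipelines.nn.kernels import fused_qkv_matmul

# Assumed pre-built: kv_params, x (padded activations), wqkv (fused QKV weight),
# kv_collection, and layer_idx. n_heads is the model's attention head count.
q_out = fused_qkv_matmul(
    kv_params=kv_params,
    input=x,
    wqkv=wqkv,
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    n_heads=n_heads,
)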
fused_qkv_ragged_matmul()
max.pipelines.nn.kernels.fused_qkv_ragged_matmul(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
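A hedged call sketch for the ragged fused projection follows; the weight, activations, offsets, and cache objects are assumed to have been created earlier in the graph, the optional bias is omitted, and the result name is illustrative.

```python
from max.pipelines.nn.kernels import fused_qkv_ragged_matmul

# Assumed pre-built: kv_params, x (ragged activations flattened over the batch),
# input_row_offsets, wqkv (fused QKV weight), kv_collection, and layer_idx.
# n_heads is the model's attention head count; bias is left at its default.
q_out = fused_qkv_ragged_matmul(
    kv_params=kv_params,
    input=x,
    input_row_offsets=input_row_offsets,
    wqkv=wqkv,
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    n_heads=n_heads,
)
```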
fused_qkv_ragged_matmul_quantized()
max.pipelines.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, quantization_config: QuantizationConfig, perm_idx: TensorValue | None = None, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
matmul_kv_cache_ragged()
max.pipelines.nn.kernels.matmul_kv_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: int | integer) → None
Computes key and value projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each batch starts and ends in hidden_states.
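A sketch of the call is below; the ragged activations, offsets, projection weight, and KV collection are assumed to exist already. Note that the call returns None and that layer_idx is a plain integer here rather than a TensorValue.

```python
from max.pipelines.nn.kernels import matmul_kv_cache_ragged

# Assumed pre-built: kv_params, hidden_states (ragged activations),
# input_row_offsets, weight (fused KV projection weight), and kv_collection.
matmul_kv_cache_ragged(
    kv_params=kv_params,
    hidden_states=hidden_states,
    input_row_offsets=input_row_offsets,
    weight=weight,
    kv_collection=kv_collection,
    layer_idx=0,  # plain int, e.g. the first transformer layer
)
```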
rms_norm_key_cache()
max.pipelines.nn.kernels.rms_norm_key_cache(kv_params: KVCacheParams, kv_collection: ContinuousBatchingKVCacheCollection, gamma: TensorValue, epsilon: float | floating, layer_idx: int | integer, total_seq_len: Dim, input_row_offsets: TensorValue) → None
Computes RMSNorm on the new entries in the KVCache.
Currently, the KVCacheT class itself isn't aware of new cache entries until the cache length is incremented, which happens after the model forward pass, so input_row_offsets is used to do this bookkeeping.
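A hedged sketch of the call follows; gamma, total_seq_len, input_row_offsets, and the cache objects are assumed to exist already, and the epsilon and layer index values shown are illustrative rather than prescribed. The call returns None and only affects the newly written key-cache entries.

```python
from max.pipelines.nn.kernels import rms_norm_key_cache

# Assumed pre-built: kv_params, kv_collection, gamma (RMSNorm weight TensorValue),
# total_seq_len (a Dim), and input_row_offsets. epsilon and layer_idx are plain
# Python scalars per the documented signature.
rms_norm_key_cache(
    kv_params=kv_params,
    kv_collection=kv_collection,
    gamma=gamma,
    epsilon=1e-6,  # illustrative; use the model's configured RMSNorm epsilon
    layer_idx=0,
    total_seq_len=total_seq_len,
    input_row_offsets=input_row_offsets,
)
```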
swish_glu()
max.pipelines.nn.kernels.swish_glu(a: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b0: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b1: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
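The page gives only the signature for this op; a hedged call sketch, with all inputs assumed to exist elsewhere in the graph, looks like this:

```python
from max.pipelines.nn.kernels import swish_glu

# Sketch only: x, w0, and w1 are assumed TensorValues built elsewhere in the
# graph. The exact gating semantics are not documented on this page, so the
# argument names below are purely illustrative.
out = swish_glu(a=x, b0=w0, b1=w1)
```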