Python module
kernels
Helper functions for wrapping custom KV cache/attention-related ops.
AttentionMaskVariant
class max.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_MASK
CAUSAL_MASK = 'causal_mask'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = 'chunked_causal_mask'
NULL_MASK
NULL_MASK = 'null_mask'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = 'sliding_window_causal_mask'
TENSOR_MASK
TENSOR_MASK = 'tensor_mask'
MHAMaskConfig
class max.nn.kernels.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')
attention_mask_variant
attention_mask_variant*: AttentionMaskVariant*
positional_encoding_variant
positional_encoding_variant*: PositionalEncodingVariant*
MHAMaskVariant
class max.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_ALIBI_MASK
CAUSAL_ALIBI_MASK = '1'
CAUSAL_MASK
CAUSAL_MASK = '0'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = '3'
NULL_MASK
NULL_MASK = '2'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = '4'
PositionalEncodingVariant
class max.nn.kernels.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALIBI_POS
ALIBI_POS = 'alibi_pos'
NO_POS
NO_POS = 'no_pos'
cross_attention_ragged()
max.nn.kernels.cross_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, kv_input_row_offsets: TensorValue, q_max_seq_len: TensorValue, scale: float) → TensorValue
Computes cross attention using the !mo.opaque KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input. Because this is cross attention, kv_input_row_offsets represents the KV sequence lengths, which may differ from the Q sequence lengths.
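For illustration, a minimal NumPy sketch of the ragged-offsets convention used throughout this module (the sequence lengths here are hypothetical):

```python
import numpy as np

# Three sequences of lengths 3, 5, and 2, packed into one ragged batch.
seq_lens = np.array([3, 5, 2], dtype=np.uint32)

# batch_size + 1 offsets: [0, 3, 8, 10]; the last entry is the total
# number of tokens in the packed input.
input_row_offsets = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.uint32)

# Tokens of sequence i occupy input[input_row_offsets[i] : input_row_offsets[i + 1]].
```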
dynamic_scaled_matmul()
max.nn.kernels.dynamic_scaled_matmul(a: TensorValue, b: TensorValue, a_scales: TensorValue, b_scales: TensorValue, out_type: DType = DType.bfloat16) → TensorValue
Performs a matmul of two tensors with scaling factors. Currently, only channel-wise scaling for weights and per-token scaling for inputs are supported.
Parameters:
- a – The first tensor to multiply.
- b – The second tensor to multiply; it must be transposed.
- a_scales – The scaling factors for the first tensor.
- b_scales – The scaling factors for the second tensor.
Returns:
The result of the matmul operation.
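As a hedged sketch (not the library's prescribed usage), this pairs naturally with quantize_dynamic_scaled_float8() documented below; the weight is assumed to be pre-transposed with precomputed channel-wise scales, and all names are illustrative:

```python
from max.graph import TensorValue
from max.nn.kernels import dynamic_scaled_matmul, quantize_dynamic_scaled_float8

def scaled_linear(
    x: TensorValue, weight_t: TensorValue, weight_scales: TensorValue
) -> TensorValue:
    # Dynamically quantize the activations to fp8 (see
    # quantize_dynamic_scaled_float8 for the scaling-granularity options).
    x_fp8, x_scales = quantize_dynamic_scaled_float8(x)
    # weight_t must already be transposed; weight_scales are channel-wise.
    return dynamic_scaled_matmul(x_fp8, weight_t, x_scales, weight_scales)
```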
flare_mla_decode_ragged()
max.nn.kernels.flare_mla_decode_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, qk_rope_dim: int = 64) → TensorValue
Computes flash (self) attention using the !mo.opaque KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
flare_mla_decompress_k_cache()
max.nn.kernels.flare_mla_decompress_k_cache(kv_params: KVCacheParams, buffer_row_offsets_1d: TensorValue, cache_offsets_1d: TensorValue, buffer_length: TensorValue, weight: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, buffer_size: int) → TensorValue
This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix.
The process involves:
1. Copying buffer_length latent vectors from the key cache into a contiguous buffer (k_latent).
2. Computing k = k_latent @ weight.T to obtain the decompressed keys.
Returns:
A tensor of shape [buffer_size, weight.shape[0]] containing the decompressed keys. Note that only the first buffer_length tokens are valid.
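The underlying math, sketched in NumPy with hypothetical shapes (the real kernel gathers the latents directly out of the paged cache):

```python
import numpy as np

buffer_size, latent_dim, kv_dim = 512, 128, 576  # hypothetical sizes
buffer_length = 300                              # valid tokens this iteration

# Step 1: copy buffer_length latent vectors from the key cache into a
# contiguous buffer; rows past buffer_length are padding.
k_latent = np.zeros((buffer_size, latent_dim), dtype=np.float32)

# Step 2: up-project into KV space with a weight of shape [kv_dim, latent_dim].
weight = np.random.randn(kv_dim, latent_dim).astype(np.float32)
k = k_latent @ weight.T  # shape: [buffer_size, weight.shape[0]]

# Only the first buffer_length rows of k are valid.
k_valid = k[:buffer_length]
```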
flare_mla_prefill_plan()
max.nn.kernels.flare_mla_prefill_plan(kv_params: KVCacheParams, input_row_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, buffer_size: int, max_chunks: int = 16) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue]
This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer.
Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer_size.
For each chunk (iteration), it calculates:
1. Buffer offsets for each sequence in each chunk.
2. Cache offsets for each sequence in each chunk.
3. Total buffer lengths for each processing iteration.
A combined usage sketch appears under flare_mla_prefill_ragged() below.
flare_mla_prefill_ragged()
max.nn.kernels.flare_mla_prefill_ragged(kv_params: KVCacheParams, input: TensorValue, k: TensorValue, v: TensorValue, input_row_offsets: TensorValue, buffer_row_offsets: TensorValue, cache_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, qk_rope_dim: int = 64, prev_output: TensorValue | None = None, prev_softmax_info: TensorValue | None = None) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue]
Performs MLA prefill. In MLA prefill, the KV tensors must be decompressed, since the KV cache stores latent representations. The KV tensors are decompressed into a fixed-size buffer to avoid out-of-memory errors; if the total cache length exceeds the buffer size, the attention calculation is processed in chunks.
This kernel returns the output tensor and the softmax info tensor for the current iteration. The next iteration of the MLA prefill kernel consumes both to continue the attention calculation.
Parameters:
- kv_params – KVCacheParams
- input – Input tensor
- k – Key tensor
- v – Value tensor
- input_row_offsets – Indicates where each sequence in the batch starts and ends in input
- buffer_row_offsets – Indicates where each sequence in the batch starts and ends in the buffer
- cache_offsets – Indicates where each sequence in the batch starts and ends in the KV cache
- kv_collection – KV collection
- layer_idx – Layer index tensor
- mask_variant – Mask variant
- scale – The attention scale factor
- qk_rope_dim – QK rope dimension
- prev_output – Optional. Previous output tensor
- prev_softmax_info – Optional. Previous softmax info tensor
Returns:
- The first tensor is the output tensor for this iteration
- The second tensor is the softmax info tensor for this iteration
Return type:
A tuple of two tensors.
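A hedged sketch of the chunked prefill loop these two kernels enable. How the per-chunk rows of buffer_row_offsets and cache_offsets are indexed, and how the chunk count is obtained from the plan, are assumptions for illustration; every variable is assumed to be built elsewhere in the graph:

```python
from max.nn.kernels import flare_mla_prefill_plan, flare_mla_prefill_ragged

def chunked_mla_prefill(
    kv_params, x, k, v, input_row_offsets, kv_collection,
    layer_idx, mask_variant, scale, buffer_size, num_chunks,
):
    # Plan buffer/cache offsets for each chunk of at most buffer_size tokens.
    buffer_row_offsets, cache_offsets, buffer_lengths = flare_mla_prefill_plan(
        kv_params, input_row_offsets, kv_collection, layer_idx, buffer_size
    )
    # In practice num_chunks would be derived from buffer_lengths on the host.
    prev_output, prev_softmax_info = None, None
    for i in range(num_chunks):
        # Each iteration consumes the previous output and softmax info to
        # continue the attention calculation across chunks.
        prev_output, prev_softmax_info = flare_mla_prefill_ragged(
            kv_params, x, k, v, input_row_offsets,
            buffer_row_offsets[i], cache_offsets[i],
            kv_collection, layer_idx, mask_variant, scale,
            prev_output=prev_output, prev_softmax_info=prev_softmax_info,
        )
    return prev_output
```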
flash_attention()
max.nn.kernels.flash_attention(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, attention_mask: TensorValue, valid_lengths: TensorValue, scale: float) → TensorValue
Computes flash attention using the !mo.opaque KV cache.
flash_attention_ragged()
max.nn.kernels.flash_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, local_window_size: int = 8192) → TensorValue
Computes flash (self) attention using the !mo.opaque KV cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
flash_attention_with_causal_mask()
max.nn.kernels.flash_attention_with_causal_mask(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, valid_lengths: TensorValue, scale: float) → TensorValue
Computes flash attention using the !mo.opaque KV cache. Notably, this materializes the causal mask within the kernel.
fused_qk_ragged_rope()
max.nn.kernels.fused_qk_ragged_rope(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, freqs_cis: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Applies rotary positional embeddings (RoPE) to the query and key tensors in a fused kernel, with ragged inputs.
Parameters:
- input – Ragged input tensor of shape [batch_size * seq_len, n_heads, head_dim]
- input_row_offsets – Indicates where each sequence in the batch starts and ends in input
- freqs_cis – Tensor of shape (max_seq_len * 2, head_dim)
- layer_idx – The layer index tensor
- interleaved –
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
fused_qk_rope()
max.nn.kernels.fused_qk_rope(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, freqs_cis_2d: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Applies rotary positional embeddings (RoPE) to the query and key tensors in a fused kernel.
fused_qkv_matmul()
max.nn.kernels.fused_qkv_matmul(kv_params: KVCacheParams, input: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, n_heads: int) → TensorValue
Computes fused query, key, and value projections.
fused_qkv_ragged_matmul()
max.nn.kernels.fused_qkv_ragged_matmul(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
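These projection, RoPE, and attention kernels are typically chained within a transformer layer. A minimal sketch under the assumptions that the fused projection writes K and V into the cache and returns the query portion, and that any reshape to the [total_tokens, n_heads, head_dim] layout expected downstream is handled elsewhere; all variable names are illustrative:

```python
from max.nn.kernels import (
    MHAMaskVariant,
    flash_attention_ragged,
    fused_qk_ragged_rope,
    fused_qkv_ragged_matmul,
)

def ragged_self_attention(
    kv_params, x, input_row_offsets, wqkv, kv_collection,
    freqs_cis, layer_idx, n_heads, scale,
):
    # Fused QKV projection: K and V go into the KV cache (assumed), and the
    # query projection is returned.
    xq = fused_qkv_ragged_matmul(
        kv_params, x, input_row_offsets, wqkv, kv_collection, layer_idx, n_heads
    )
    # Apply rotary embeddings to the queries and the newly cached keys.
    xq = fused_qk_ragged_rope(
        kv_params, xq, input_row_offsets, kv_collection, freqs_cis, layer_idx
    )
    # Causal flash attention over the ragged batch.
    return flash_attention_ragged(
        kv_params, xq, input_row_offsets, kv_collection, layer_idx,
        MHAMaskVariant.CAUSAL_MASK, scale,
    )
```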
fused_qkv_ragged_matmul_quantized()
max.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, quantization_config: QuantizationConfig, perm_idx: TensorValue | None = None, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
fused_qkv_ragged_matmul_scaled_float8()
max.nn.kernels.fused_qkv_ragged_matmul_scaled_float8(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, input_scale: TensorValue, weight_scale: TensorValue, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input, applying the provided float8 input and weight scaling factors.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
grouped_matmul_ragged()
max.nn.kernels.grouped_matmul_ragged(hidden_states: TensorValue, weight: TensorValue, expert_start_indices: TensorValue, expert_ids: TensorValue, expert_usage_stats_host: TensorValue) → TensorValue
Grouped matmul used in the MoE layer.
hidden_states and expert_start_indices are used together to implement the ragged tensor: expert_start_indices indicates where each group starts and ends in hidden_states.
expert_ids gives the id of the expert for each group in hidden_states.
expert_usage_stats_host holds the maximum number of tokens assigned to any expert and the number of active experts.
A usage sketch appears under moe_create_indices() below.
kv_cache_get_max_seq_len()
max.nn.kernels.kv_cache_get_max_seq_len(kv_collection: PagedKVCacheCollection) → TensorValue
This kernel returns the maximum sequence length currently stored in the KV cache collection.
matmul_k_cache_ragged()
max.nn.kernels.matmul_k_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: int | integer) → None
Computes key projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in hidden_states.
matmul_kv_cache_ragged()
max.nn.kernels.matmul_kv_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: int | integer) → None
Computes key and value projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in hidden_states.
matmul_static_scaled_float8()
max.nn.kernels.matmul_static_scaled_float8(input: TensorValue, weight: TensorValue, input_scale: TensorValue, weight_scale: TensorValue) → TensorValue
Performs a float8 matmul using static (precomputed) scaling factors for the input and the weight.
moe_create_indices()
max.nn.kernels.moe_create_indices(topk_ids: TensorValue, num_local_experts: int) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue]
Creates indices for the MoE layer.
Parameters:
- topk_ids – The expert assignments for each token from the router.
- num_local_experts – The number of experts on this device.
Returns:
- token_expert_order: The reordered token indices, grouped by assigned expert.
- expert_start_indices: The starting index of each expert's token group in the reordered sequence.
- restore_token_order: The indices that restore the original token ordering after expert computation.
- expert_ids: The ids of the active experts selected for the tokens.
- expert_usage_stats: The maximum number of tokens assigned to any expert, and the number of active experts.
Return type:
A tuple of five tensors.
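A hedged sketch of how these outputs drive grouped_matmul_ragged(). The gather step is passed in as a stand-in callable rather than named as a specific graph op, and all tensor names are illustrative:

```python
from max.nn.kernels import grouped_matmul_ragged, moe_create_indices

def moe_ffn(hidden_states, topk_ids, expert_weight, num_local_experts, gather):
    (token_expert_order, expert_start_indices, restore_token_order,
     expert_ids, expert_usage_stats) = moe_create_indices(
        topk_ids, num_local_experts
    )
    # Group tokens by their assigned expert.
    grouped = gather(hidden_states, token_expert_order)
    # One ragged grouped matmul covers every active expert's token group.
    out = grouped_matmul_ragged(
        grouped, expert_weight, expert_start_indices,
        expert_ids, expert_usage_stats,
    )
    # Scatter results back to the original token order.
    return gather(out, restore_token_order)
```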
quantize_dynamic_scaled_float8()
max.nn.kernels.quantize_dynamic_scaled_float8(input: TensorValue, scale_ub: float = 1200.0, group_size_or_per_token: int = -1, out_type: DType = DType.float8_e4m3fn, scales_type: DType = DType.bfloat16) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue]
Dynamically quantizes the input tensor to fp8.
Parameters:
- input – The input tensor to quantize.
- scale_ub – The upper bound of the scale factor.
- group_size_or_per_token – The group size for quantization. When set to -1, the quantization is column-wise.
- out_type – The type of the output tensor.
- scales_type – The type of the scales tensor.
Returns:
The quantized tensor and the scales.
quantize_static_scaled_float8()
max.nn.kernels.quantize_static_scaled_float8(x: TensorValue, scale: TensorValue, scale_is_inverted: bool = True) → TensorValue
Quantizes the input tensor to float8 using a static (precomputed) scale.
rms_norm_key_cache()
max.nn.kernels.rms_norm_key_cache(kv_params: KVCacheParams, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, gamma: TensorValue, epsilon: float | floating, layer_idx: int | integer, total_seq_len: Dim, input_row_offsets: TensorValue, weight_offset: float | floating, rms_norm_cols: int | None = None) → None
Computes RMSNorm on the _new_ entries in the KVCache.
This function applies RMSNorm to either all dimensions or a subset of dimensions in each head of the key cache. The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match head_dim, rms_norm_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions.
Currently, the KVCacheT class itself isn't aware of new cache entries until the cache length increment that happens after the model forward pass, so input_row_offsets is used to do this bookkeeping.
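A minimal call sketch, assuming gamma spans the full head_dim (so rms_norm_cols stays unset) and that the epsilon and weight_offset values shown suit the model; all variables are assumed to be built elsewhere in the graph:

```python
from max.nn.kernels import rms_norm_key_cache

def normalize_new_keys(
    kv_params, kv_collection, gamma, layer_idx, total_seq_len, input_row_offsets
):
    # Applies RMSNorm in place to the key-cache entries added this step.
    rms_norm_key_cache(
        kv_params, kv_collection, gamma,
        epsilon=1e-6,
        layer_idx=layer_idx,
        total_seq_len=total_seq_len,
        input_row_offsets=input_row_offsets,
        weight_offset=0.0,
    )
```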
swish_glu()
max.nn.kernels.swish_glu(a: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b0: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b1: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes a Swish-gated linear unit (SwiGLU) from the input a and the two projection weights b0 and b1.
unfused_qkv_ragged_matmul_gguf_quantized()
max.nn.kernels.unfused_qkv_ragged_matmul_gguf_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, n_heads: int, q_weight: TensorValue, k_weight: TensorValue, v_weight: TensorValue, quantization_encoding_q: QuantizationEncoding, quantization_encoding_k: QuantizationEncoding, quantization_encoding_v: QuantizationEncoding, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue) → TensorValue
Computes query, key, and value projections with ragged input and GGUF-quantized weight matrices, using separate (unfused) Q, K, and V weights. A quantization encoding must be provided for each weight.
input and input_row_offsets are used together to implement the ragged tensor: input_row_offsets indicates where each sequence in the batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.