
Python module

continuous_batching_cache

Continuous Batching enabled KV cache for the Transformer leveraging the mo.opaque pattern.

ContinuousBatchingKVCache

class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCache(value: Value | Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray)

Continuous Mojo KV cache graph value.

ContinuousBatchingKVCacheCollection

class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection(value: Value | Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray)

The graph value for a view of the KV cache.

ContinuousBatchingKVCacheCollectionType

class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollectionType

The graph type for a “view” of the cache for the given sequences in the batch.

This object does not own the underlying buffers in k_cache and v_cache; it borrows them from the BlockWrappers in our ContinuousKVCacheManager. It does own the Pointer[NDBuffer[type, 3]] and the valid_lengths buffer.

ContinuousBatchingKVCacheManager

class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device], session: InferenceSession)

block_shape()

block_shape(n_sequences: int) → list[int]

Returns the shape of the KV cache blocks for the given number of sequences.

Defines the 6-dimensional shape of the cache blocks used to store key and value tensors for transformer attention. The dimensions represent: [n_sequences, 2, num_layers, max_seq_len, n_kv_heads_per_device, head_dim] where 2 represents separate storage for keys and values.

  • Parameters:

    n_sequences – Number of sequences that will be cached

  • Returns:

    List describing the shape of the cache blocks, with dimensions for sequences, key/value split, layers, sequence length, attention heads, and head dimension

  • Return type:

    list[int]
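As a rough illustration (not the library's implementation), the 6-dimensional shape described above can be sketched as a pure function of the cache parameters; the parameter values below are hypothetical:

```python
def block_shape(n_sequences, num_layers, max_seq_len,
                n_kv_heads_per_device, head_dim):
    # [n_sequences, 2, num_layers, max_seq_len, n_kv_heads_per_device, head_dim]
    # where the 2 holds separate planes for keys and values.
    return [n_sequences, 2, num_layers, max_seq_len,
            n_kv_heads_per_device, head_dim]

# Example with made-up model dimensions:
shape = block_shape(n_sequences=4, num_layers=32, max_seq_len=2048,
                    n_kv_heads_per_device=8, head_dim=128)
```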

estimated_memory_size()

classmethod estimated_memory_size(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device]) → int

Returns the estimated total memory usage of the KV cache.
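One plausible way such an estimate could be derived (a sketch under the assumption that the cache allocates the full 6-D block shape for the maximum batch on each device, at a fixed bytes-per-element; the helper name and parameters are hypothetical):

```python
from math import prod

def estimated_kv_cache_bytes(max_cache_batch_size, num_layers, max_seq_len,
                             n_kv_heads_per_device, head_dim,
                             dtype_bytes, n_devices=1):
    # Full cache block shape for the maximum batch, per device:
    # [batch, 2 (keys and values), layers, seq_len, kv_heads, head_dim]
    shape = [max_cache_batch_size, 2, num_layers, max_seq_len,
             n_kv_heads_per_device, head_dim]
    return prod(shape) * dtype_bytes * n_devices
```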

fetch()

fetch(seq_ids: List[int]) → List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]]

Fetches the KV cache state for the given sequence IDs.

This method retrieves the current cache state for a batch of sequences, including their cache lengths and lookup information. It’s used during token generation to access previously cached key/value pairs.

  • Parameters:

    seq_ids – List of sequence IDs to fetch cache state for. Each ID must be within the max_cache_batch_size and must exist in the current cache.

  • Returns:

    List of tuples for each device containing:

    • blocks: Tensor containing the KV cache blocks
    • cache_lengths: Tensor of current cache lengths for each sequence
    • lookup_table: Tensor mapping sequence IDs to cache positions
    • is_cache_empty: Boolean tensor indicating if all sequences have empty caches
  • Return type:

    List[tuple[Tensor, Tensor, Tensor, Tensor]]

  • Raises:

    ValueError – If any seq_id exceeds max_cache_batch_size or doesn’t exist in cache
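To make the bookkeeping concrete, here is a toy model of what fetch tracks per sequence: a slot assignment (the lookup table), a current cache length, and the empty flag. This is an illustrative sketch with a plain dict standing in for the manager's state, not the real implementation:

```python
def fetch_state(seq_ids, cache, max_cache_batch_size):
    # Validate as the docs describe: IDs must fit in the batch and exist.
    for sid in seq_ids:
        if sid >= max_cache_batch_size:
            raise ValueError(f"seq_id {sid} exceeds max_cache_batch_size")
        if sid not in cache:
            raise ValueError(f"seq_id {sid} does not exist in cache")
    cache_lengths = [cache[sid]["length"] for sid in seq_ids]   # per-sequence lengths
    lookup_table = [cache[sid]["slot"] for sid in seq_ids]      # ID -> cache position
    is_cache_empty = all(length == 0 for length in cache_lengths)
    return cache_lengths, lookup_table, is_cache_empty

# Hypothetical state: sequence 0 has 5 cached tokens, sequence 2 is fresh.
cache = {0: {"slot": 0, "length": 5}, 2: {"slot": 1, "length": 0}}
lengths, lookup, empty = fetch_state([0, 2], cache, max_cache_batch_size=4)
```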

increment_cache_lengths()

increment_cache_lengths(kv_cache_inputs: List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]], prev_model_inputs: tuple[max.driver.tensor.Tensor, ...]) → List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]]

Prepares cache inputs for the next token in multistep execution.

Updates the cache lengths for the next inference step without requiring device synchronization or memory copies. This is crucial for maintaining performance during multi-token generation.

  • Parameters:

    • kv_cache_inputs – Current cache state tuples (blocks, lengths, lookup, empty flag)
    • prev_model_inputs – Previous model inputs including row offsets
  • Returns:

    Updated cache input tuples with incremented lengths and is_cache_empty=False since only the first step can be context encoding
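The length update itself is simple bookkeeping; as a hedged sketch (the real method operates on device tensors without copies, which this plain-Python version does not model):

```python
def increment_cache_lengths(cache_lengths, new_tokens_per_seq):
    # After a generation step, each active sequence's cache grows by the
    # number of tokens just produced.
    updated = [length + n for length, n in zip(cache_lengths, new_tokens_per_seq)]
    # Only the first step can be context encoding, so after any step the
    # cache is no longer empty.
    is_cache_empty = False
    return updated, is_cache_empty
```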

input_symbols()

input_symbols() → List[tuple[max.graph.type.TensorType, max.graph.type.TensorType, max.graph.type.TensorType, max.graph.type.TensorType]]

Returns the expected input tensor types for fetch on each device.

Defines the tensor specifications needed by the cache implementation, including shapes and data types. This is used for graph construction and validation.

  • Returns:

    List of tuples for each device containing TensorTypes for:

    • KV cache blocks: 6D tensor for storing keys and values
    • Cache lengths: 1D tensor tracking sequence lengths
    • Lookup table: 1D tensor mapping sequence IDs to cache positions
    • Cache empty flag: Scalar boolean tensor
  • Return type:

    List[tuple[TensorType, TensorType, TensorType, TensorType]]
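The four specs per device can be sketched as (name, shape) pairs; this is an illustrative stand-in for the real TensorType objects, with "batch" marking a dynamic dimension and all names hypothetical:

```python
def input_symbols(max_cache_batch_size, num_layers, max_seq_len,
                  n_kv_heads_per_device, head_dim):
    return [
        # 6-D KV storage: [batch, K/V, layers, seq_len, kv_heads, head_dim]
        ("kv_blocks", (max_cache_batch_size, 2, num_layers, max_seq_len,
                       n_kv_heads_per_device, head_dim)),
        ("cache_lengths", ("batch",)),   # one current length per sequence
        ("lookup_table", ("batch",)),    # sequence ID -> cache position
        ("is_cache_empty", ()),          # scalar boolean flag
    ]

specs = input_symbols(4, 32, 2048, 8, 128)
```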

ContinuousBatchingKVCacheType

class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheType

Continuous Mojo KV Cache graph type.

FetchContinuousBatchingKVCacheCollection

class max.pipelines.kv_cache.continuous_batching_cache.FetchContinuousBatchingKVCacheCollection(kv_params: KVCacheParams)