Python module
continuous_batching_cache
Continuous-batching-enabled KV cache for the Transformer, leveraging the mo.opaque pattern.
ContinuousBatchingKVCache
class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCache(value: Value | Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray)
Continuous Mojo KV cache graph value.
ContinuousBatchingKVCacheCollection
class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection(value: Value | Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray)
The graph value for a view of the KV cache.
ContinuousBatchingKVCacheCollectionType
class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollectionType
The graph type for a “view” of the cache for the given sequences in the batch.
This object does not own the underlying buffers in k_cache and v_cache; it borrows them from the BlockWrappers in our ContinuousKVCacheManager. It does own the Pointer[NDBuffer[type, 3]] and the valid_lengths buffer.
ContinuousBatchingKVCacheManager
class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device], session: InferenceSession)
block_shape()
Returns the shape of the KV cache blocks for the given number of sequences.
Defines the 6-dimensional shape of the cache blocks used to store key and value tensors for transformer attention. The dimensions represent: [n_sequences, 2, num_layers, max_seq_len, n_kv_heads_per_device, head_dim] where 2 represents separate storage for keys and values.
Parameters:
n_sequences – Number of sequences that will be cached
Returns:
List describing the shape of the cache blocks with dimensions for sequences, key/value split, layers, sequence length, attention heads, and head dimension
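To make the six-dimensional layout concrete, here is a minimal sketch in plain Python. It does not call the MAX API: the helper names `sketch_block_shape` and `sketch_block_bytes`, the example sizes, and the 2-byte element width are all assumptions chosen for illustration. Only the order of the dimensions comes from the description above.

```python
from math import prod


def sketch_block_shape(n_sequences, num_layers, max_seq_len,
                       n_kv_heads_per_device, head_dim):
    """Hypothetical helper: build the 6D shape in the documented order."""
    # The literal 2 is the key/value split: index 0 for keys, 1 for values.
    return [n_sequences, 2, num_layers, max_seq_len,
            n_kv_heads_per_device, head_dim]


def sketch_block_bytes(shape, bytes_per_element):
    """Hypothetical helper: raw size of a dense tensor with this shape."""
    return prod(shape) * bytes_per_element


# Example sizes are assumptions, not defaults of any real model.
shape = sketch_block_shape(
    n_sequences=4,
    num_layers=32,
    max_seq_len=2048,
    n_kv_heads_per_device=8,
    head_dim=128,
)
# Assumes a 2-byte element type such as a 16-bit float.
total_bytes = sketch_block_bytes(shape, bytes_per_element=2)

print(shape)
print(total_bytes / 2**30, "GiB")
```

With these assumed sizes the block holds 536,870,912 elements, which is 1 GiB at two bytes per element. The figure scales linearly with n_sequences and max_seq_len, which is why both are fixed when the manager is constructed.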
estimated_memory_size()
classmethod estimated_memory_size(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device]) → int
Returns the estimated total memory usage of the kv cache.
input_symbols()
input_symbols() → List[tuple[max.graph.type.TensorType, max.graph.type.TensorType, max.graph.type.TensorType, max.graph.type.TensorType]]
Returns the expected input tensor types for fetch on each device.
Defines the tensor specifications needed by the cache implementation, including shapes and data types. This is used for graph construction and validation.
Returns:
List of tuples for each device containing TensorTypes for:
- KV cache blocks: 6D tensor for storing keys and values
- Cache lengths: 1D tensor tracking sequence lengths
- Lookup table: 1D tensor mapping sequence IDs to cache positions
- Maximum lengths: 2D tensor tracking maximum sequence and cache lengths per step.
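The structure of this return value (one four-element tuple per device) can be sketched without the MAX library. In the sketch below, `TensorSpecSketch` and `sketch_input_specs` are hypothetical stand-ins for `TensorType` and `input_symbols()`, and the symbolic dimension names are assumptions. Only the tensor ranks and their order are taken from the description above, and data types are deliberately left out because the text does not state them.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpecSketch:
    """Hypothetical stand-in for a tensor type: a name plus its dimensions."""
    name: str
    dims: tuple

    @property
    def rank(self):
        return len(self.dims)


def sketch_input_specs(num_devices):
    """Hypothetical helper: one 4-tuple of specs per device."""
    per_device = (
        # 6D blocks holding keys and values (dimension names are assumed).
        TensorSpecSketch("kv_cache_blocks",
                         ("n_sequences", 2, "num_layers", "max_seq_len",
                          "n_kv_heads_per_device", "head_dim")),
        # 1D tensor of current sequence lengths.
        TensorSpecSketch("cache_lengths", ("batch_size",)),
        # 1D tensor mapping sequence IDs to cache positions.
        TensorSpecSketch("lookup_table", ("batch_size",)),
        # 2D tensor of maximum sequence and cache lengths per step.
        TensorSpecSketch("max_lengths", ("steps", 2)),
    )
    return [per_device for _ in range(num_devices)]


specs = sketch_input_specs(num_devices=2)
ranks = [spec.rank for spec in specs[0]]
print(len(specs), ranks)
```

A graph builder would typically iterate over this list once per device and declare four graph inputs for each entry, in the same order.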
ContinuousBatchingKVCacheType
class max.pipelines.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheType
Continuous Mojo KV Cache graph type.
FetchContinuousBatchingKVCacheCollection
class max.pipelines.kv_cache.continuous_batching_cache.FetchContinuousBatchingKVCacheCollection(kv_params: KVCacheParams)