PagedKVCacheManager
class max.kv_cache.paged_cache.cache_manager.PagedKVCacheManager(params, max_batch_size, max_seq_len, num_layers, devices, session, available_cache_memory, zmq_endpoint_base=None, page_size=128, enable_runtime_checks=False)
Paged KVCache manager with data and tensor parallelism support.
# Claim sequence IDs for two requests on different replicas
kv_manager.external_claim(ctx1.request_id, replica_idx=0)
kv_manager.external_claim(ctx2.request_id, replica_idx=1)
# Allocate blocks for these requests
kv_manager.maybe_reserve(ctx1, num_steps=10)
kv_manager.maybe_reserve(ctx2, num_steps=10)
# Get KVCache inputs to feed to graph
kv_cache_inputs = kv_manager.fetch([ctx1, ctx2], num_steps=10)
# Run model...
# Update requests with newly generated tokens
ctx1.update(42)
ctx2.update(42)
# Commit newly written blocks to prefix cache
kv_manager.step([ctx1, ctx2])
# Release metadata and KV blocks for these requests
kv_manager.release(ctx1.request_id)
kv_manager.release(ctx2.request_id)

Initialize the multi-device paged KV cache manager.
-
Parameters:
-
- params (KVCacheParams) – KV cache parameters including data parallelism settings
- max_batch_size (int) – The maximum number of active requests that the manager should support. Note that this is the global maximum batch size across all devices, so when data parallelism is enabled, this would be split across all replicas of the cache.
- max_seq_len (int) – Maximum sequence length
- num_layers (int) – Number of model layers
- devices (Sequence[Device]) – The devices to use for the KV cache manager. If data parallelism is enabled, the devices will be split into params.data_parallel_degree groups.
- session (InferenceSession) – Inference session
- available_cache_memory (int) – Total cache memory across all devices
- page_size (int) – Page size in tokens
- enable_runtime_checks (bool) – Whether to enable runtime checks
- zmq_endpoint_base (str | None)
contains()
contains(request_id)
estimated_memory_size()
classmethod estimated_memory_size(params, max_batch_size, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
Estimated memory size for the DPPagedKVCacheManager.
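The actual estimate is derived from the KVCacheParams; as a rough mental model, the cache budget divides into fixed-size pages whose cost scales with layer count, KV heads, head dimension, and dtype width. The helper names and model dimensions below are illustrative assumptions, not the real API:

```python
# Hypothetical sketch of paged KV cache memory math; the real classmethod
# reads these dimensions from KVCacheParams.
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    # Each token stores one key and one value vector per layer per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

def pages_that_fit(available_cache_memory, page_size, per_token):
    # A page holds page_size tokens; count whole pages within the budget.
    bytes_per_page = page_size * per_token
    return available_cache_memory // bytes_per_page

per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
print(per_token)                                             # 131072 bytes
print(pages_that_fit(8 * 1024**3, page_size=128, per_token=per_token))  # 512
```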
external_claim()
external_claim(request_id, replica_idx=None)
Reserve a sequence ID for the given request ID.
fetch()
fetch(batch, num_steps=1)
Fetch KV cache blocks for a batch of requests.
-
Parameters:
-
- batch (Sequence[TextGenerationContext]) – Batch of requests
- num_steps (int) – Number of steps to fetch
-
Return type:
free_blocks_pct
property free_blocks_pct: float
get_data_parallel_splits()
get_data_parallel_splits(context_batch)
Constructs splits for the data parallel execution.
-
Parameters:
-
context_batch (Sequence[TextGenerationContext]) – Sequence of requests. This must already be ordered by replica index (so contexts that are on the same replica are adjacent in the batch, and the replica must be in order).
-
Returns:
-
An int64 tensor with shape (self.num_replicas + 1) that contains the number of requests on each device: [0, num_requests_on_replica_0, num_requests_on_replica_1, …]
-
Return type:
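The documented layout can be modeled in plain Python. The function below is an illustrative stand-in (the real method returns an int64 tensor), assuming the batch is already ordered by replica index as the method requires:

```python
# Build the documented splits layout: a list of length num_replicas + 1
# whose entries are [0, n_0, n_1, ...], where n_i is the number of
# requests assigned to replica i.
def data_parallel_splits(replica_of_each_request, num_replicas):
    counts = [0] * (num_replicas + 1)
    for r in replica_of_each_request:
        counts[r + 1] += 1
    return counts

# Three requests on replica 0, one on replica 1, out of two replicas:
print(data_parallel_splits([0, 0, 0, 1], num_replicas=2))  # [0, 3, 1]
```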
get_or_recommend_replica()
get_or_recommend_replica(context)
Return idx of the replica that should be used for the given request.
-
Parameters:
-
context (TextGenerationContext)
-
Return type:
get_replica()
get_replica(context)
-
Parameters:
-
context (TextGenerationContext)
-
Return type:
get_req_blocks()
get_req_blocks(request_id)
host_committed_block_pct
property host_committed_block_pct: float
increment_cache_lengths()
increment_cache_lengths(kv_cache_inputs, prev_model_inputs)
Prepares cache inputs for the next token in multistep execution, with support for multiple replicas.
Updates the cache lengths for the next inference step without requiring device synchronization or memory copies. This is crucial for maintaining performance during multi-token generation.
-
Parameters:
-
- kv_cache_inputs (Sequence[RaggedKVCacheInputs]) – Current cache state tuples (blocks, lengths, lookup, max_lengths)
- prev_model_inputs (Any) – Previous model inputs including row offsets
-
Returns:
-
Updated cache input tuples with incremented lengths.
-
Return type:
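A minimal stand-in for the bookkeeping this method performs: after each generation step, every active sequence's cached length grows by the number of new tokens, with no device synchronization or copies. The tuple mirrors the documented (blocks, lengths, lookup, max_lengths) cache-state shape, but uses plain Python values rather than real tensors:

```python
# Illustrative model only: real cache inputs are device tensors.
def increment_lengths(cache_inputs, tokens_per_step):
    blocks, lengths, lookup, max_lengths = cache_inputs
    # Only the per-sequence cache lengths change between steps.
    new_lengths = [n + tokens_per_step for n in lengths]
    return (blocks, new_lengths, lookup, max_lengths)

state = (["b0", "b1"], [10, 7], {"r1": 0, "r2": 1}, [2048, 2048])
state = increment_lengths(state, tokens_per_step=1)
print(state[1])  # [11, 8]
```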
infer_optimal_batch_size()
classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
input_symbols()
input_symbols(devices=None, num_layers=None)
-
Parameters:
-
Return type:
max_supported_sequence_length()
classmethod max_supported_sequence_length(params, num_layers, memory_available)
Return the maximum sequence length supported across all replicas.
This queries each data-parallel replica’s tensor-parallel cache manager
for its per-replica maximum supported sequence length under the provided
memory budget, then returns the minimum of those values. Each per-replica
value is already rounded down to the nearest multiple of params.page_size,
so the result is likewise page-aligned and safe for all replicas.
-
Parameters:
-
- params (KVCacheParams) – KV cache configuration parameters.
- num_layers (int) – Number of transformer layers contributing KV per token.
- memory_available (int) – Total cache memory budget in bytes.
-
Returns:
-
The maximum supported sequence length in tokens (a multiple of params.page_size) that all replicas can support.
-
Return type:
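The documented reduction can be sketched as follows. The per-replica limits here are passed in directly as an illustrative assumption; the real classmethod derives each replica's limit from KVCacheParams and the memory budget:

```python
def min_page_aligned_length(per_replica_max_tokens, page_size):
    # Round each replica's limit down to a page boundary, then take the
    # minimum so the result is safe for every replica.
    aligned = [(m // page_size) * page_size for m in per_replica_max_tokens]
    return min(aligned)

print(min_page_aligned_length([4100, 8200], page_size=128))  # 4096
```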
maybe_reserve()
maybe_reserve(data, num_steps=1)
Prepares blocks for a request prior to a subsequent fetch call.
Reuses blocks from the prefix cache and allocates new blocks for the request.
If a request is reserved, it is guaranteed not to run out of memory in a
subsequent call to fetch.
-
Parameters:
-
- data (TextGenerationContext) – The text generation context for the request. The request ID must already be assigned to a replica via external_claim.
- num_steps (int) – The number of steps to reserve blocks for. Default: 1.
-
Returns:
-
True if the request was successfully reserved; False otherwise.
-
Return type:
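The contract above can be illustrated with a toy page allocator: reserve enough pages up front for num_steps tokens so that a later fetch cannot run out of memory. The page math and names here are illustrative, not the real manager's internals:

```python
import math

def maybe_reserve(free_pages, current_len, allocated_pages, num_steps, page_size):
    # Pages needed to hold the sequence after num_steps more tokens.
    pages_needed = math.ceil((current_len + num_steps) / page_size)
    extra = max(0, pages_needed - allocated_pages)
    if extra > free_pages:
        return False, free_pages, allocated_pages  # reservation refused
    return True, free_pages - extra, allocated_pages + extra

# 120 tokens cached in 1 page; 10 more steps cross a page boundary,
# so one extra page is taken from the free pool.
ok, free, alloc = maybe_reserve(
    free_pages=2, current_len=120, allocated_pages=1, num_steps=10, page_size=128
)
print(ok, free, alloc)  # True 1 2
```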
metrics
property metrics: KVCacheMetrics
num_free_blocks
property num_free_blocks: int
The number of free blocks.
release()
release(request_id)
-
Parameters:
-
request_id (RequestID)
-
Return type:
-
None
reset_metrics()
reset_metrics()
-
Return type:
-
None
reset_prefix_cache()
reset_prefix_cache()
-
Return type:
-
None
step()
step(batch)
-
Parameters:
-
batch (Sequence[TextGenerationContext])
-
Return type:
-
None
total_num_host_pages
property total_num_host_pages: int
used_blocks_pct
property used_blocks_pct: float