Python module

cache_manager

PagedKVCacheManager

class max.kv_cache.paged_cache.cache_manager.PagedKVCacheManager(params, total_num_pages, devices, session, total_num_host_pages=0, enable_runtime_checks=False)

Paged KVCache manager with data and tensor parallelism support.

Example usage:

# Claim requests on their replicas. `kv_manager` is a PagedKVCacheManager
# and ctx1/ctx2 are TextGenerationContext objects for two requests.
kv_manager.claim(ctx1.request_id, replica_idx=0)
kv_manager.claim(ctx2.request_id, replica_idx=1)

# Allocate blocks for these requests
kv_manager.alloc(ctx1, num_steps=10)
kv_manager.alloc(ctx2, num_steps=10)

# Get KVCache inputs to feed to graph
kv_cache_inputs = kv_manager.get_runtime_inputs([ctx1, ctx2], num_steps=10)

# Run model...
# Update requests with newly generated tokens
ctx1.update(42)
ctx2.update(42)

# Commit newly written blocks to prefix cache
kv_manager.step([ctx1, ctx2])

# Release metadata and KV blocks for these requests
kv_manager.release(ctx1.request_id)
kv_manager.release(ctx2.request_id)

Initialize the multi-device paged KV cache manager.

Parameters:

  • params (KVCacheParams) – KV cache parameters, including data parallelism settings.
  • total_num_pages (int) – The total number of pages to allocate for the device KV cache.
  • devices (Sequence[Device]) – The devices to use for the KV cache manager. If data parallelism is enabled, the devices will be split into params.data_parallel_degree groups.
  • session (InferenceSession) – Inference session.
  • total_num_host_pages (int) – The number of pages to allocate in host memory. Default: 0.
  • enable_runtime_checks (bool) – Whether to enable runtime checks. Default: False.

alloc()

alloc(data, num_steps=1)

Allocates blocks for a request to run for N steps.

This method allocates blocks needed by a request to run for N steps. When prefix caching is enabled, some of the allocated blocks may be retrieved from the prefix cache.

Parameters:

  • data (TextGenerationContext) – The text generation context for the request. The request ID must already be assigned to a replica via claim.
  • num_steps (int) – The number of steps to reserve blocks for. Default: 1.

Raises:

  • InsufficientBlocksError – If there are insufficient free blocks to satisfy the allocation.

Return type:

None
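
A minimal sketch of allocating with a fallback, not from the library's docs: it assumes kv_manager is a PagedKVCacheManager, ctx is a TextGenerationContext already claimed via claim, and that InsufficientBlocksError is importable alongside the manager.

# Hypothetical setup: `kv_manager` and `ctx` as described above.
try:
    kv_manager.alloc(ctx, num_steps=8)
except InsufficientBlocksError:
    # Fall back to a single step so the request can still make progress.
    kv_manager.alloc(ctx, num_steps=1)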

claim()

claim(request_id, replica_idx=None)

Reserve a sequence ID for the given request ID.

Parameters:

  • request_id (RequestID) – The request ID to claim.
  • replica_idx (int | None) – The index of the replica to assign the request to. Default: None.

Return type:

None

contains()

contains(request_id)

Return whether the given request ID is currently claimed by this manager.

Parameters:

request_id (RequestID)

Return type:

bool

device_tensors

property device_tensors: list[list[Tensor]]

free_blocks_pct

property free_blocks_pct: float

get_or_recommend_replica()

get_or_recommend_replica(context)

Return the index of the replica that should be used for the given request.

Parameters:

context (TextGenerationContext)

Return type:

int
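
A short sketch of pairing this with claim for data-parallel scheduling, assuming kv_manager and ctx as in the class-level example:

# Ask the manager which replica should host this request, then claim it.
replica_idx = kv_manager.get_or_recommend_replica(ctx)
kv_manager.claim(ctx.request_id, replica_idx=replica_idx)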

get_pct_used_blocks_after_allocation()

get_pct_used_blocks_after_allocation(ctx, num_steps=1)

Get the percentage of blocks used after allocating for a request.

Parameters:

  • ctx (TextGenerationContext) – The request context containing sequence information and token indices.
  • num_steps (int) – Number of additional steps to allocate blocks for. Default: 1.

Returns:

The percentage of total blocks used after allocating for the request.

Return type:

float
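
This is useful for admission control before calling alloc. A hedged sketch, assuming kv_manager and ctx as in the class-level example and that the returned value is a fraction in [0, 1]; the 0.95 threshold is illustrative, not a library default:

# Admit the request only if the allocation keeps block usage under 95%.
if kv_manager.get_pct_used_blocks_after_allocation(ctx, num_steps=10) < 0.95:
    kv_manager.alloc(ctx, num_steps=10)
else:
    ...  # defer the request until blocks are released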

get_replica()

get_replica(request_id)

Return the index of the replica that the given request ID is assigned to.

Parameters:

request_id (RequestID)

Return type:

int

get_req_blocks()

get_req_blocks(request_id)

Return the IDs of the blocks currently allocated to the given request.

Parameters:

request_id (RequestID)

Return type:

list[int]

get_runtime_inputs()

get_runtime_inputs(batch, num_steps=1)

Get the graph inputs for a batch of requests.

This method raises a RuntimeError if any request in the batch does not already have enough blocks allocated to run for the given number of steps.

Parameters:

  • batch (Sequence[TextGenerationContext]) – The batch of requests to build inputs for.
  • num_steps (int) – The number of steps to run for. Default: 1.

Return type:

list[RaggedKVCacheInputs]

get_symbolic_inputs()

get_symbolic_inputs(devices=None, num_layers=None)

Parameters:

  • devices (Sequence[Device] | None) – Default: None.
  • num_layers (int | None) – Default: None.

Return type:

Sequence[PagedCacheInputSymbols]

host_committed_block_pct

property host_committed_block_pct: float

increment_cache_lengths()

increment_cache_lengths(kv_cache_inputs, prev_model_inputs)

Prepares cache inputs for the next token in multistep execution.

Updates the cache lengths for the next inference step without requiring device synchronization or memory copies. This is crucial for maintaining performance during multi-token generation.

Parameters:

  • kv_cache_inputs (Sequence[RaggedKVCacheInputs]) – Current cache state tuples (blocks, lengths, lookup, max_lengths)
  • prev_model_inputs (Any) – Previous model inputs including row offsets

Returns:

Updated cache input tuples with incremented lengths.

Return type:

Sequence[RaggedKVCacheInputs]
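
A sketch of where this fits in a multistep decode loop; model execution is elided, and num_steps, kv_cache_inputs, and prev_model_inputs are assumed to come from the surrounding serving code:

# Reuse cache inputs across decode steps without device synchronization.
for _ in range(num_steps):
    # ... run one model step with kv_cache_inputs and prev_model_inputs ...
    kv_cache_inputs = kv_manager.increment_cache_lengths(
        kv_cache_inputs, prev_model_inputs
    )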

infer_optimal_batch_size()

classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)

Parameters:

  • params (KVCacheParams)
  • max_seq_len (int)
  • num_layers (int)
  • available_cache_memory (int)
  • devices (Sequence[Device])

Return type:

int
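
A usage sketch with illustrative argument values; params and devices are assumed to come from your model configuration:

# Estimate how many requests fit in the available KV cache memory.
batch_size = PagedKVCacheManager.infer_optimal_batch_size(
    params=params,
    max_seq_len=4096,  # illustrative
    num_layers=32,     # illustrative
    available_cache_memory=8 * 1024**3,  # 8 GiB, illustrative
    devices=devices,
)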

metrics

property metrics: KVCacheMetrics

num_free_blocks

property num_free_blocks: int

The number of free blocks currently available.

release()

release(request_id)

Release the metadata and KV blocks assigned to the given request ID.

Parameters:

request_id (RequestID)

Return type:

None

reset_metrics()

reset_metrics()

Reset the KV cache metrics.

Return type:

None

reset_prefix_cache()

reset_prefix_cache()

Reset the prefix cache.

Return type:

None

step()

step(batch)

Commit the newly written blocks for a batch of requests to the prefix cache.

Parameters:

batch (Sequence[TextGenerationContext])

Return type:

None

total_num_host_pages

property total_num_host_pages: int

used_blocks_pct

property used_blocks_pct: float
