Python module

manager

Abstract base class for KVCacheManager and related KV cache input types.

ClaimedSlot

class max.nn.kv_cache.manager.ClaimedSlot(seq_id, replica_idx)

A class that represents a claimed sequence slot.

It tracks the sequence ID together with the replica index that claimed it, where the replica index identifies the data-parallelism replica of the cache that claimed the slot.

Parameters:

  • seq_id (int)
  • replica_idx (int)

replica_idx

replica_idx: int

seq_id

seq_id: int
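
Since both fields are plain integers, the class behaves like a simple record. A minimal stand-in (illustrative only, not the real `max.nn.kv_cache.manager.ClaimedSlot`) can be sketched as:

```python
from dataclasses import dataclass

# Illustrative stand-in for ClaimedSlot; mirrors only the two documented fields.
@dataclass
class ClaimedSlot:
    seq_id: int       # the claimed sequence ID
    replica_idx: int  # data-parallel cache replica that claimed the slot

slot = ClaimedSlot(seq_id=7, replica_idx=1)
```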

KVCacheInputSymbols

class max.nn.kv_cache.manager.KVCacheInputSymbols

Base class for input symbols for KV cache managers.

The derived class is responsible for defining the input symbols for the specific KV cache manager. For example, here’s a derived class for a text KV cache manager:

@dataclass
class TextKVCacheInputSymbols(KVCacheInputSymbols):
    kv_blocks: TensorType
    cache_lengths: TensorType
    lookup_table: TensorType
    max_lengths: TensorType

KVCacheInputs

class max.nn.kv_cache.manager.KVCacheInputs

A base class that holds KV cache-related tensor inputs.

It is meant to be subclassed by concrete KV cache input types. For example, here’s a derived class for a text KV cache manager:

@dataclass
class RaggedKVCacheInputs(KVCacheInputs):
    blocks: Tensor
    cache_lengths: Tensor
    lookup_table: Tensor
    max_lengths: Tensor

KVCacheInputsSequence

class max.nn.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs)

KVCacheInputsSequence is a sequence of KVCacheInputs.

It is primarily used in multistep execution to represent batched KVCacheInputs.

Parameters:

kv_cache_inputs (Sequence[KVCacheInputs])

kv_cache_inputs

kv_cache_inputs: Sequence[KVCacheInputs]
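
A minimal stand-in for these classes can illustrate the idea (illustrative only; it assumes, as the name and usage suggest, that KVCacheInputsSequence subclasses KVCacheInputs, so a batched sequence can be passed wherever a single KVCacheInputs is expected):

```python
from dataclasses import dataclass
from typing import Sequence

# Illustrative stand-in for the real max.nn.kv_cache.manager.KVCacheInputs.
class KVCacheInputs:
    pass

# A sequence of inputs that is itself a KVCacheInputs, so batched (e.g.
# multistep) inputs can flow through APIs expecting a single KVCacheInputs.
@dataclass
class KVCacheInputsSequence(KVCacheInputs):
    kv_cache_inputs: Sequence[KVCacheInputs]

per_step = [KVCacheInputs(), KVCacheInputs()]
batched = KVCacheInputsSequence(kv_cache_inputs=per_step)
```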

KVCacheManager

class max.nn.kv_cache.manager.KVCacheManager(params, max_batch_size, max_seq_len, num_layers, devices, session, is_ragged=False)

The base class for KV cache managers.

It is responsible for managing the KV cache for a given model.

Parameters:

  • params (KVCacheParams) – The parameters for the KV cache manager.
  • max_batch_size (int) – The maximum batch size. This should be the total overall maximum batch size, so if data parallelism is enabled, the sum of the batch size over all replicas will be equal to this value.
  • max_seq_len (int) – The maximum sequence length.
  • num_layers (int) – The number of layers.
  • devices (Sequence[Device]) – The devices to use for the KV cache manager.
  • session (InferenceSession) – The session to use for the KV cache manager.
  • is_ragged (bool) – Whether the KV cache manager is using ragged tensors.

params

params

The parameters for the KV cache manager.

max_batch_size

max_batch_size

The maximum batch size.

max_seq_len

max_seq_len

The maximum sequence length.

num_layers

num_layers

The number of layers.

contains()

contains(request_id)

Check if the given request ID is currently active in the cache.

Parameters:

request_id (str) – The request ID to check for.

Returns:

True if the request ID is active in the cache, False otherwise.

Return type:

bool

estimated_memory_size()

abstract classmethod estimated_memory_size(params, max_batch_size, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)

Returns the estimated total memory usage of the KV cache.

Parameters:

  • params (KVCacheParams)
  • max_batch_size (int)
  • max_seq_len (int)
  • num_layers (int)
  • available_cache_memory (int)
  • devices (Sequence[Device])

Return type:

int
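
The exact accounting is implementation-specific, but a back-of-envelope sketch can show what drives the estimate (an illustrative formula, not the real implementation; the helper name and the example model shape are assumptions):

```python
# Back-of-envelope KV cache size: one K and one V entry per layer,
# per token slot, across the whole batch.
def estimate_kv_cache_bytes(
    max_batch_size: int,
    max_seq_len: int,
    num_layers: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int = 2,  # e.g. bfloat16
) -> int:
    per_token = 2 * n_kv_heads * head_dim * dtype_bytes  # K + V
    return max_batch_size * max_seq_len * num_layers * per_token

# Example: a Llama-3-8B-like shape (32 layers, 8 KV heads, head_dim 128)
# at batch 16 and 4096-token context comes to 8 GiB.
size = estimate_kv_cache_bytes(16, 4096, 32, 8, 128)
```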

external_claim()

external_claim(request_id)

Reserve a sequence ID for the given request ID.

Parameters:

request_id (str)

Return type:

None

fetch()

abstract fetch(batch, num_steps=1)

Returns the blocks and other inputs to the KV cache kernel for the given sequences and prompts.

Parameters:

  • batch (Sequence[T])
  • num_steps (int)

Return type:

Sequence[KVCacheInputs]

increment_cache_lengths()

increment_cache_lengths(kv_cache_inputs, prev_model_inputs)

Prepares the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as that would defeat the purpose of multistep execution.

It should also not update the cache lengths in the manager, since this batch is still considered in progress.

Parameters:

  • kv_cache_inputs (KVCacheInputs)
  • prev_model_inputs

Return type:

list[RaggedKVCacheInputs] | list[PaddedKVCacheInputs]
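
As a host-side illustration only (the real implementation operates on device tensors and must avoid synchronization), incrementing cache lengths between decode steps amounts to:

```python
# Illustrative sketch: after each decode step, every active sequence has
# generated more tokens, so its cached-token count advances by that amount.
def increment_cache_lengths(cache_lengths: list[int], tokens_per_step: int = 1) -> list[int]:
    return [length + tokens_per_step for length in cache_lengths]

lengths = [10, 37, 4]                       # per-sequence cached-token counts
lengths = increment_cache_lengths(lengths)  # after one decode step
```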

increment_cache_lengths_model

property increment_cache_lengths_model: Model

infer_optimal_batch_size()

abstract classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)

Returns the estimated optimal batch size for the KV cache.

Parameters:

  • params (KVCacheParams)
  • max_seq_len (int)
  • num_layers (int)
  • available_cache_memory (int)
  • devices (Sequence[Device])

Return type:

int

input_symbols()

abstract input_symbols(devices=None, num_layers=None)

Returns the input symbols for the KV cache manager.

Parameters:

  • devices (Sequence[Device] | None)
  • num_layers (int | None)

Return type:

Sequence[KVCacheInputSymbols]

release()

release(request_id)

Release the sequence associated with request_id, marking this sequence as complete. This returns the sequence ID back to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.

Parameters:

request_id (str)

Return type:

None
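
The external_claim() / contains() / release() trio forms the slot lifecycle. A toy sketch can make the bookkeeping concrete (the real KVCacheManager manages device memory; this stand-in, with an invented class name, only tracks which request IDs hold a sequence slot, with the pool size playing the role of max_batch_size):

```python
# Toy illustration of the claim/contains/release lifecycle.
class ToySlotManager:
    def __init__(self, max_batch_size: int) -> None:
        self._free = list(range(max_batch_size))  # available sequence IDs
        self._active: dict[str, int] = {}         # request_id -> seq_id

    def external_claim(self, request_id: str) -> None:
        # Reserve a sequence ID for this request.
        self._active[request_id] = self._free.pop()

    def contains(self, request_id: str) -> bool:
        return request_id in self._active

    def release(self, request_id: str) -> None:
        # Return the sequence ID to the pool for reuse by a new sequence.
        self._free.append(self._active.pop(request_id))

mgr = ToySlotManager(max_batch_size=2)
mgr.external_claim("req-0")
```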

step()

step(batch)

Commit the new tokens into the prefix cache.

This is a no-op if prefix caching is disabled.

Parameters:

batch (Sequence[T])

Return type:

None
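
As a toy sketch of what "committing to the prefix cache" means (the real cache stores reusable KV blocks keyed by content, not token sets; the function and variable names here are illustrative):

```python
# Toy illustration: after a decode step, the newly generated tokens become
# part of each sequence's cached prefix, so a later request sharing that
# prefix can reuse the cached entries instead of recomputing them.
prefix_cache: set[tuple[int, ...]] = set()

def step(batch: list[list[int]]) -> None:
    for tokens in batch:
        # Record every prefix of the sequence as reusable.
        for end in range(1, len(tokens) + 1):
            prefix_cache.add(tuple(tokens[:end]))

step([[1, 2, 3]])
```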

PaddedKVCacheInputs

class max.nn.kv_cache.manager.PaddedKVCacheInputs(k_cache, v_cache, start_pos, null_op)

PaddedKVCacheInputs holds the inputs for the KV cache when used together with padded tensors.

Parameters:

  • k_cache (Tensor)
  • v_cache (Tensor)
  • start_pos (Tensor)
  • null_op (Tensor)

k_cache

k_cache: Tensor

null_op

null_op: Tensor

start_pos

start_pos: Tensor

v_cache

v_cache: Tensor

RaggedKVCacheInputs

class max.nn.kv_cache.manager.RaggedKVCacheInputs(blocks, cache_lengths, lookup_table, max_lengths)

RaggedKVCacheInputs holds the inputs for the KV cache when used together with ragged tensors.

Parameters:

  • blocks (Tensor)
  • cache_lengths (Tensor)
  • lookup_table (Tensor)
  • max_lengths (Tensor)

blocks

blocks: Tensor

cache_lengths

cache_lengths: Tensor

lookup_table

lookup_table: Tensor

max_lengths

max_lengths: Tensor
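
To build intuition for how these fields relate, here is an illustrative sketch of a ragged layout using plain lists as stand-ins for the device tensors (the block shape and the role of max_lengths are simplified assumptions, not the exact tensor formats):

```python
# Illustrative ragged layout: sequences of different lengths share one block
# pool; lookup_table maps each sequence to its block, and cache_lengths
# records how many tokens each sequence has cached.
blocks = [["a"] * 8, ["b"] * 8, ["c"] * 8]  # pool of fixed-size cache blocks
cache_lengths = [5, 2]                       # tokens cached per sequence
lookup_table = [0, 2]                        # sequence index -> block index
max_lengths = [max(cache_lengths)]           # host-side maximum, simplified

# The cached tokens for sequence 1 live in its block, up to its cache length.
seq1_tokens = blocks[lookup_table[1]][: cache_lengths[1]]
```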
