Python module
manager
Defines KVCacheManager, the abstract base class for KV cache managers.
ClaimedSlot
class max.nn.kv_cache.manager.ClaimedSlot(seq_id, replica_idx)
A class that represents a claimed sequence slot.
It tracks the sequence ID and the index of the cache replica that claimed it, where the replica refers to the data-parallelism replica of the cache.
replica_idx
replica_idx: int
seq_id
seq_id: int
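For illustration, a slot claimed on the first data-parallel replica for sequence 7 could look like this (a sketch only; the values are arbitrary):

from max.nn.kv_cache.manager import ClaimedSlot

slot = ClaimedSlot(seq_id=7, replica_idx=0)
# replica_idx identifies the data-parallel cache replica that owns the slot.
assert slot.seq_id == 7 and slot.replica_idx == 0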
KVCacheInputSymbols
class max.nn.kv_cache.manager.KVCacheInputSymbols
Base class for input symbols for KV cache managers.
The derived class is responsible for defining the input symbols for the specific KV cache manager. For example, here’s a derived class for a text KV cache manager:
@dataclass
class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
    kv_blocks: TensorType
    cache_lengths: TensorType
    lookup_table: TensorType
    max_lengths: TensorType
KVCacheInputs
class max.nn.kv_cache.manager.KVCacheInputs
A base class that holds KV cache related (Tensor) inputs.
It is meant to be subclassed by concrete KV cache input types. For example, here’s a derived class for a text KV cache manager:
@dataclass
class RaggedKVCacheInputs(KVCacheInputs):
    blocks: Tensor
    cache_lengths: Tensor
    lookup_table: Tensor
    max_lengths: Tensor
KVCacheInputsSequence
class max.nn.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs)
KVCacheInputsSequence is a sequence of KVCacheInputs.
It is primarily used in our multistep execution to represent batched KVCacheInputs.
Parameters:
- kv_cache_inputs (Sequence[KVCacheInputs])
kv_cache_inputs
kv_cache_inputs: Sequence[KVCacheInputs]
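For example, per-step ragged inputs can be wrapped into a single sequence for multistep execution. This is a sketch only: blocks, cache_lengths, lookup_table, max_lengths, and num_steps are placeholders for values produced elsewhere (for example, by a concrete cache manager):

from max.nn.kv_cache.manager import KVCacheInputsSequence, RaggedKVCacheInputs

# One RaggedKVCacheInputs per execution step, batched into a single sequence.
step_inputs = [
    RaggedKVCacheInputs(
        blocks=blocks,
        cache_lengths=cache_lengths,
        lookup_table=lookup_table,
        max_lengths=max_lengths,
    )
    for _ in range(num_steps)
]
kv_inputs = KVCacheInputsSequence(kv_cache_inputs=step_inputs)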
KVCacheManager
class max.nn.kv_cache.manager.KVCacheManager(params, max_batch_size, max_seq_len, num_layers, devices, session, is_ragged=False)
The base class for KV cache managers.
It is responsible for managing the KV cache for a given model.
Parameters:
- params (KVCacheParams) – The parameters for the KV cache manager.
- max_batch_size (int) – The maximum batch size. This is the total maximum batch size across all replicas: if data parallelism is enabled, the per-replica batch sizes sum to this value.
- max_seq_len (int) – The maximum sequence length.
- num_layers (int) – The number of layers.
- devices (Sequence[Device]) – The devices to use for the KV cache manager.
- session (InferenceSession) – The session to use for the KV cache manager.
- is_ragged (bool) – Whether the KV cache manager is using ragged tensors.
params
params
The parameters for the KV cache manager.
max_batch_size
max_batch_size
The maximum batch size.
max_seq_len
max_seq_len
The maximum sequence length.
num_layers
num_layers
The number of layers.
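A concrete manager subclasses this base class and implements the abstract methods documented below. The following skeleton is a sketch only: the class name is hypothetical and the method bodies are elided:

from max.nn.kv_cache.manager import KVCacheManager

class MyTextKVCacheManager(KVCacheManager):
    @classmethod
    def estimated_memory_size(cls, params, max_batch_size, max_seq_len,
                              num_layers, available_cache_memory, devices,
                              **kwargs):
        ...  # estimate the total bytes the KV cache will need

    @classmethod
    def infer_optimal_batch_size(cls, params, max_seq_len, num_layers,
                                 available_cache_memory, devices, **kwargs):
        ...  # choose a batch size that fits in available_cache_memory

    def input_symbols(self, devices=None, num_layers=None):
        ...  # return the KVCacheInputSymbols used to build the graph

    def fetch(self, batch, num_steps=1):
        ...  # return KV cache inputs for the active sequences in batch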
contains()
contains(request_id)
Check if the given request ID is currently active in the cache.
estimated_memory_size()
abstract classmethod estimated_memory_size(params, max_batch_size, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
Returns the estimated total memory usage of the KV cache.
external_claim()
external_claim(request_id)
Reserve a sequence ID for the given request ID.
Parameters:
- request_id (str)
Return type:
None
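Together with contains() and release(), this forms the basic request lifecycle. A minimal sketch, assuming manager is an already-constructed concrete KVCacheManager:

request_id = "request-0"

manager.external_claim(request_id)   # reserve a sequence ID for this request
assert manager.contains(request_id)  # the request is now active in the cache

# ... fetch()/step() run here while tokens are generated for this request ...

manager.release(request_id)          # return the sequence ID to the free pool
assert not manager.contains(request_id)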
fetch()
abstract fetch(batch, num_steps=1)
Returns blocks and other inputs to the KV cache kernel for the given sequence IDs and prompts.
increment_cache_lengths()
increment_cache_lengths(kv_cache_inputs, prev_model_inputs)
Prepare the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as this would defeat the purpose of multistep execution.
This should also not update the cache lengths in our manager, since this batch is still considered in progress.
Parameters:
- kv_cache_inputs (list[RaggedKVCacheInputs] | list[PaddedKVCacheInputs])
- prev_model_inputs (Any)
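A rough sketch of how fetch(), increment_cache_lengths(), and step() fit together in a multistep decode loop. Here manager, batch, and run_model are placeholders for a concrete manager, an in-flight batch, and whatever call executes the compiled model:

num_steps = 4

# Fetch the KV cache inputs once for the whole multistep window.
kv_inputs = manager.fetch(batch, num_steps=num_steps)

for _ in range(num_steps):
    # run_model is a placeholder that executes the compiled model and
    # returns the model inputs it used for this step.
    model_inputs = run_model(batch, kv_inputs)
    # Bump the cache-length inputs for the next step without a device
    # sync and without updating the manager's own bookkeeping.
    kv_inputs = manager.increment_cache_lengths(kv_inputs, model_inputs)

# After the window completes, commit the new tokens into the prefix
# cache (a no-op if prefix caching is disabled).
manager.step(batch)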
increment_cache_lengths_model
property increment_cache_lengths_model: Model
infer_optimal_batch_size()
abstract classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
Returns the estimated optimal batch size for the KV cache.
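These estimates can be consulted before constructing a manager, to budget cache memory and pick a batch size. A sketch, assuming MyTextKVCacheManager is a concrete subclass (hypothetical, as in the skeleton above) and that params, devices, and available_cache_memory are already defined:

estimated_bytes = MyTextKVCacheManager.estimated_memory_size(
    params=params,
    max_batch_size=32,
    max_seq_len=2048,
    num_layers=24,
    available_cache_memory=available_cache_memory,
    devices=devices,
)

optimal_batch_size = MyTextKVCacheManager.infer_optimal_batch_size(
    params=params,
    max_seq_len=2048,
    num_layers=24,
    available_cache_memory=available_cache_memory,
    devices=devices,
)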
input_symbols()
abstract input_symbols(devices=None, num_layers=None)
Returns the input symbols for the KV cache manager.
release()
release(request_id)
Release the sequence associated with request_id, marking this sequence as complete.
This returns the sequence ID back to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.
Parameters:
- request_id (str)
Return type:
None
step()
step(batch)
Commit the new tokens into the prefix cache.
This is a no-op if prefix caching is disabled.
Parameters:
- batch (Sequence[T])
Return type:
None
PaddedKVCacheInputs
class max.nn.kv_cache.manager.PaddedKVCacheInputs(k_cache, v_cache, start_pos, null_op)
PaddedKVCacheInputs is a class that holds the inputs for the KV cache when used together with padded tensors.
k_cache
k_cache: Tensor
null_op
null_op: Tensor
start_pos
start_pos: Tensor
v_cache
v_cache: Tensor
RaggedKVCacheInputs
class max.nn.kv_cache.manager.RaggedKVCacheInputs(blocks, cache_lengths, lookup_table, max_lengths)
RaggedKVCacheInputs is a class that holds the inputs for the KV cache when used together with ragged tensors.
blocks
blocks: Tensor
cache_lengths
cache_lengths: Tensor
lookup_table
lookup_table: Tensor
max_lengths
max_lengths: Tensor