Python module
manager
Abstract base class for KV cache managers.
KVCacheInputSymbols
class max.nn.kv_cache.manager.KVCacheInputSymbols
Base class for input symbols for KV cache managers.
The derived class is responsible for defining the input symbols for the specific KV cache manager. For example, here’s a derived class for a text KV cache manager:
@dataclass
class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
    kv_blocks: TensorType
    cache_lengths: TensorType
    lookup_table: TensorType
    max_lengths: TensorType
KVCacheInputs
class max.nn.kv_cache.manager.KVCacheInputs
A base class that holds KV cache related (Tensor) inputs.
It is meant to be subclassed by concrete KV cache input types. For example, here’s a derived class for a text KV cache manager:
@dataclass
class RaggedKVCacheInputs(KVCacheInputs):
    blocks: Tensor
    cache_lengths: Tensor
    lookup_table: Tensor
    max_lengths: Tensor
KVCacheInputsSequence
class max.nn.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs)
KVCacheInputsSequence is a sequence of KVCacheInputs.
It is primarily used in our multistep execution to represent batched KVCacheInputs.
Parameters:
kv_cache_inputs (Sequence[KVCacheInputs])
kv_cache_inputs
kv_cache_inputs: Sequence[KVCacheInputs]
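As a rough illustration of how such a container groups several KV cache inputs, here is a minimal sketch. It assumes nothing about the real classes beyond the single kv_cache_inputs field documented above. ToyKVCacheInputs and ToyKVCacheInputsSequence are hypothetical stand-ins, not the MAX types, and plain Python lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ToyKVCacheInputs:
    """Hypothetical stand-in for one set of KV cache inputs (lists, not tensors)."""

    cache_lengths: list[int]


@dataclass
class ToyKVCacheInputsSequence:
    """Hypothetical stand-in that groups several sets of KV cache inputs."""

    kv_cache_inputs: Sequence[ToyKVCacheInputs]


# Group two sets of inputs so they can be passed around as a single object.
grouped = ToyKVCacheInputsSequence(
    kv_cache_inputs=[
        ToyKVCacheInputs(cache_lengths=[1, 2]),
        ToyKVCacheInputs(cache_lengths=[3]),
    ]
)

# Consumers iterate over the wrapped inputs in order.
all_lengths = [inp.cache_lengths for inp in grouped.kv_cache_inputs]
```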
KVCacheManager
class max.nn.kv_cache.manager.KVCacheManager(params, max_batch_size, max_seq_len, num_layers, devices, session, is_ragged=False)
Parameters:
- params (KVCacheParams)
- max_batch_size (int)
- max_seq_len (int)
- num_layers (int)
- devices (Sequence[Device])
- session (InferenceSession)
- is_ragged (bool)
claim()
claim(n)
Claims n blocks of memory in the cache for incoming requests.
This returns a list of sequence ids, which identify each sequence’s location within the cache. These sequence ids can then be passed to the fetch function to return the ContinuousBatchingKVCacheCollection for those sequences.
contains()
contains(seq_id)
estimated_memory_size()
abstract classmethod estimated_memory_size(params, max_batch_size, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
Returns the estimated total memory usage of the kv cache.
external_claim()
external_claim(seq_ids)
Variant of claim() where sequence ids are reserved externally.
fetch()
abstract fetch(batch, num_steps=1)
Returns blocks and other inputs to kv cache kernel for given sequence ids and prompts.
increment_cache_lengths()
increment_cache_lengths(kv_cache_inputs, prev_model_inputs)
Prepare the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as this would defeat the purpose of multistep execution.
This should also not update the cache lengths in the manager, because the batch is still considered in progress.
Parameters:
- kv_cache_inputs (list[RaggedKVCacheInputs] | list[PaddedKVCacheInputs])
- prev_model_inputs (Any)
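The idea of advancing cache lengths for the next step without touching the manager's own bookkeeping can be sketched as follows. This is a conceptual illustration only. ToyRaggedInputs and toy_increment_cache_lengths are hypothetical names, plain Python lists stand in for device tensors, and the per-sequence token counts are assumed to be known from the previous step. A real implementation would do this arithmetic on the device to avoid a synchronization.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyRaggedInputs:
    """Hypothetical stand-in for per-step KV cache inputs (lists, not tensors)."""

    cache_lengths: list[int]


def toy_increment_cache_lengths(
    inputs: ToyRaggedInputs, tokens_processed: list[int]
) -> ToyRaggedInputs:
    """Return new inputs whose cache lengths include the previous step's tokens."""
    if len(inputs.cache_lengths) != len(tokens_processed):
        raise ValueError("expected one token count per sequence")
    new_lengths = [
        cached + new
        for cached, new in zip(inputs.cache_lengths, tokens_processed)
    ]
    # Build a fresh object so the original inputs are left unchanged.
    return replace(inputs, cache_lengths=new_lengths)


# Lengths the (hypothetical) manager has committed so far; never modified here.
committed_lengths = {0: 5, 1: 9}

step0 = ToyRaggedInputs(cache_lengths=[5, 9])
step1 = toy_increment_cache_lengths(step0, tokens_processed=[3, 1])
```

In this sketch the committed lengths stay as they were, which mirrors the requirement that the batch is still treated as in progress until it is committed elsewhere.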
infer_optimal_batch_size()
abstract classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)
Returns the estimated optimal batch size for the kv cache.
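For intuition about what a memory estimate and a batch-size estimate involve, here is a back-of-envelope sketch. It is not the library's calculation. The formula assumes a dense cache that stores one key and one value vector per layer, per KV head, per token. n_kv_heads, head_dim, and dtype_bytes are hypothetical plain-integer parameters introduced for illustration.

```python
def toy_kv_cache_bytes(
    num_layers: int,
    max_batch_size: int,
    max_seq_len: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int,
) -> int:
    """Estimate the bytes needed for a dense KV cache (assumed layout)."""
    # Factor of 2 accounts for storing both keys and values.
    per_token = 2 * num_layers * n_kv_heads * head_dim * dtype_bytes
    return per_token * max_seq_len * max_batch_size


def toy_optimal_batch_size(
    available_cache_memory: int,
    num_layers: int,
    max_seq_len: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int,
) -> int:
    """Estimate how many full-length sequences fit in the memory budget."""
    per_sequence = toy_kv_cache_bytes(
        num_layers, 1, max_seq_len, n_kv_heads, head_dim, dtype_bytes
    )
    # Clamp to at least 1 so the sketch always returns a usable batch size.
    return max(1, available_cache_memory // per_sequence)


total_bytes = toy_kv_cache_bytes(
    num_layers=32,
    max_batch_size=4,
    max_seq_len=2048,
    n_kv_heads=8,
    head_dim=128,
    dtype_bytes=2,
)
batch_size = toy_optimal_batch_size(
    available_cache_memory=2**30,
    num_layers=32,
    max_seq_len=2048,
    n_kv_heads=8,
    head_dim=128,
    dtype_bytes=2,
)
```

With these example numbers, each full-length sequence needs 256 MiB, so a 1 GiB budget holds four sequences.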
input_symbols()
abstract input_symbols()
Returns the input symbols for the kv cache manager.
num_kv_inputs()
num_kv_inputs()
Returns the default number of KV cache inputs for KV managers.
Subclasses with a different number of KV cache inputs should override this method and increment_cache_lengths.
release()
release(seq_id)
Releases the provided seq_id, marking this sequence as complete.
This returns the seq_id to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.
Parameters:
seq_id (int)
Return type:
None
slots_remaining
The outstanding cache slots available.
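The slot bookkeeping described by claim(), external_claim(), contains(), release(), and slots_remaining can be pictured with a toy pool of sequence ids. This is a minimal sketch under stated assumptions, not the MAX implementation. ToySlotPool is a hypothetical class, the error handling is invented for illustration, and returning the set of free ids from slots_remaining is an assumption, since this page does not state the property's type.

```python
class ToySlotPool:
    """Hypothetical stand-in for a manager's sequence id bookkeeping."""

    def __init__(self, max_batch_size: int) -> None:
        self.available: set[int] = set(range(max_batch_size))
        self.active: set[int] = set()

    def claim(self, n: int) -> list[int]:
        """Reserve n free sequence ids chosen by the pool."""
        if n > len(self.available):
            raise ValueError("not enough free cache slots")
        seq_ids = [self.available.pop() for _ in range(n)]
        self.active.update(seq_ids)
        return seq_ids

    def external_claim(self, seq_ids: list[int]) -> None:
        """Reserve specific sequence ids chosen by the caller."""
        for seq_id in seq_ids:
            if seq_id not in self.available:
                raise ValueError(f"seq_id {seq_id} is not free")
        for seq_id in seq_ids:
            self.available.remove(seq_id)
            self.active.add(seq_id)

    def contains(self, seq_id: int) -> bool:
        """Report whether seq_id is currently claimed."""
        return seq_id in self.active

    def release(self, seq_id: int) -> None:
        """Return seq_id to the free pool so it can be claimed again."""
        if seq_id not in self.active:
            raise ValueError(f"seq_id {seq_id} is not claimed")
        self.active.remove(seq_id)
        self.available.add(seq_id)

    @property
    def slots_remaining(self) -> set[int]:
        """Free sequence ids (assumed return type)."""
        return self.available


pool = ToySlotPool(max_batch_size=4)
claimed = pool.claim(2)   # the pool picks two free ids
pool.release(claimed[0])  # the first id becomes reusable
```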
step()
step(batch)
Commit the new tokens into the prefix cache.
This is a no-op if prefix caching is disabled.
Parameters:
batch (list[T])
Return type:
None
PaddedKVCacheInputs
class max.nn.kv_cache.manager.PaddedKVCacheInputs(k_cache, v_cache, start_pos, null_op)
PaddedKVCacheInputs is a class that holds the inputs for KV cache when used together with padded tensors.
k_cache
k_cache: Tensor
null_op
null_op: Tensor
start_pos
start_pos: Tensor
v_cache
v_cache: Tensor
RaggedKVCacheInputs
class max.nn.kv_cache.manager.RaggedKVCacheInputs(blocks, cache_lengths, lookup_table, max_lengths)
RaggedKVCacheInputs is a class that holds the inputs for KV cache when used together with ragged tensors.
blocks
blocks: Tensor
cache_lengths
cache_lengths: Tensor
lookup_table
lookup_table: Tensor
max_lengths
max_lengths: Tensor