Python module

manager

Abstract base class for KVCacheManager and related KV cache input types.

KVCacheInputSymbols

class max.nn.kv_cache.manager.KVCacheInputSymbols

Base class for input symbols for KV cache managers.

The derived class is responsible for defining the input symbols for the specific KV cache manager. For example, here’s a derived class for a text KV cache manager:

@dataclass
class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
    kv_blocks: TensorType
    cache_lengths: TensorType
    lookup_table: TensorType
    max_lengths: TensorType

KVCacheInputs

class max.nn.kv_cache.manager.KVCacheInputs

A base class that holds KV cache-related (Tensor) inputs.

It is meant to be subclassed by concrete KV cache input types. For example, here’s a derived class for a text KV cache manager:

@dataclass
class RaggedKVCacheInputs(KVCacheInputs):
    blocks: Tensor
    cache_lengths: Tensor
    lookup_table: Tensor
    max_lengths: Tensor

KVCacheInputsSequence

class max.nn.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs)

KVCacheInputsSequence is a sequence of KVCacheInputs.

It is primarily used in our multistep execution to represent batched KVCacheInputs.

Parameters:

kv_cache_inputs (Sequence[KVCacheInputs])

kv_cache_inputs

kv_cache_inputs: Sequence[KVCacheInputs]
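
For illustration, here is a minimal sketch of bundling per-step inputs into a sequence. It assumes the RaggedKVCacheInputs values have already been produced elsewhere (for example, by a manager's fetch()):

from max.nn.kv_cache.manager import KVCacheInputsSequence, RaggedKVCacheInputs

# Hypothetical per-step inputs, one RaggedKVCacheInputs per execution step.
step_inputs: list[RaggedKVCacheInputs] = ...

# Bundle them so multistep execution can pass the batch around as a single object.
batched_inputs = KVCacheInputsSequence(kv_cache_inputs=step_inputs)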

KVCacheManager

class max.nn.kv_cache.manager.KVCacheManager(params, max_batch_size, max_seq_len, num_layers, devices, session, is_ragged=False)

Parameters:

claim()

claim(n)

Claims n blocks of memory in the cache for incoming requests.

This returns a list of sequence ids, which identify each sequence’s location within the cache. These sequence ids can then be passed to the fetch function to return the ContinuousBatchingKVCacheCollection for those sequences.

Parameters:

n (int)

Return type:

list[int]
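
For illustration, a minimal sketch (kv_manager is assumed to be an instance of a concrete KVCacheManager subclass):

# Claim cache slots for two incoming requests.
seq_ids = kv_manager.claim(2)
# seq_ids is a list of ints, e.g. [0, 1], which can later be passed to fetch().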

contains()

contains(seq_id)

Returns whether seq_id is currently an active sequence in the cache.

Parameters:

seq_id (int)

Return type:

bool

estimated_memory_size()

abstract classmethod estimated_memory_size(params, max_batch_size, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)

Returns the estimated total memory usage of the kv cache.

Parameters:

Return type:

int

external_claim()

external_claim(seq_ids)

Variant of claim() where sequence ids are reserved externally.

Parameters:

seq_ids (list[int])

Return type:

None
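
A short sketch of reserving externally assigned ids (kv_manager is assumed to be a concrete manager instance, and the ids come from some external scheduler):

# Reserve sequence ids that were allocated outside of the manager.
kv_manager.external_claim([3, 4, 5])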

fetch()

abstract fetch(batch, num_steps=1)

Returns blocks and other inputs to the KV cache kernel for the given sequence ids and prompts.

Parameters:

  • batch (list[T])
  • num_steps (int)

Return type:

list[KVCacheInputs]
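
For illustration, a minimal sketch of fetching cache inputs for a claimed batch (kv_manager is assumed to be a concrete manager instance, and batch is whatever per-request context type the subclass expects):

# Fetch KV cache inputs for the current batch, planning two decode steps ahead.
kv_inputs_per_step = kv_manager.fetch(batch, num_steps=2)
# The result is a list of KVCacheInputs to be passed along with the model inputs.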

increment_cache_lengths()

increment_cache_lengths(kv_cache_inputs, prev_model_inputs)

Prepares the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as that would defeat the purpose of multistep execution.

This should also not update the cache lengths in the manager, since this batch is still considered in progress.

Parameters:

Return type:

list[RaggedKVCacheInputs] | list[PaddedKVCacheInputs]
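
As a sketch of where this fits in a multistep decode loop (kv_inputs and prev_model_inputs are assumed to come from the previous step):

# Advance the cache-length inputs for the next step without a device sync and
# without committing anything to the manager's own bookkeeping.
kv_inputs = kv_manager.increment_cache_lengths(kv_inputs, prev_model_inputs)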

infer_optimal_batch_size()

abstract classmethod infer_optimal_batch_size(params, max_seq_len, num_layers, available_cache_memory, devices, **kwargs)

Returns the estimated optimal batch size for the kv cache.

Parameters:

Return type:

int

input_symbols()

abstract input_symbols()

Returns the input symbols for the kv cache manager.

Return type:

Sequence[KVCacheInputSymbols]

num_kv_inputs()

num_kv_inputs()

Returns the default number of KV cache inputs for KV managers.

Subclasses with a different number of KV cache inputs should override this method and increment_cache_lengths.

Return type:

int

release()

release(seq_id)

Releases the provided seq_id, marking this sequence as complete. This returns seq_id to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.

Parameters:

seq_id (int)

Return type:

None

slots_remaining

property slots_remaining: set[int]

The cache slots that are still available to be claimed.
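
For example, a scheduler might consult this property as a simple admission check before claiming a new sequence (kv_manager is an assumed concrete manager instance):

# Only admit a new request if the cache still has a free slot.
if kv_manager.slots_remaining:
    seq_ids = kv_manager.claim(1)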

step()

step(batch)

Commits the new tokens into the prefix cache.

This is a no-op if prefix caching is disabled.

Parameters:

batch (list[T])

Return type:

None
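
Putting the methods together, a simplified request lifecycle might look like the following sketch. The names kv_manager, batch, and run_model are illustrative assumptions, not part of this API:

# 1. Claim cache slots for the incoming requests.
seq_ids = kv_manager.claim(len(batch))

# 2. Fetch KV cache inputs and execute the model (run_model is a hypothetical helper).
kv_inputs = kv_manager.fetch(batch)
outputs = run_model(batch, kv_inputs)

# 3. Commit the newly generated tokens (a no-op if prefix caching is disabled).
kv_manager.step(batch)

# 4. Release finished sequences so their slots can be reused.
for seq_id in seq_ids:
    kv_manager.release(seq_id)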

PaddedKVCacheInputs

class max.nn.kv_cache.manager.PaddedKVCacheInputs(k_cache, v_cache, start_pos, null_op)

PaddedKVCacheInputs holds the inputs for the KV cache when used together with padded tensors.

Parameters:

k_cache

k_cache: Tensor

null_op

null_op: Tensor

start_pos

start_pos: Tensor

v_cache

v_cache: Tensor

RaggedKVCacheInputs

class max.nn.kv_cache.manager.RaggedKVCacheInputs(blocks, cache_lengths, lookup_table, max_lengths)

RaggedKVCacheInputs holds the inputs for the KV cache when used together with ragged tensors.

Parameters:

blocks

blocks: Tensor

cache_lengths

cache_lengths: Tensor

lookup_table

lookup_table: Tensor

max_lengths

max_lengths: Tensor
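
As an illustration, the dataclass can be constructed directly from four tensors; in practice these are typically produced by a KVCacheManager's fetch() rather than built by hand (the tensor variables below are placeholders):

from max.nn.kv_cache.manager import RaggedKVCacheInputs

# blocks_t, cache_lengths_t, lookup_table_t, and max_lengths_t are placeholder Tensors.
kv_inputs = RaggedKVCacheInputs(
    blocks=blocks_t,
    cache_lengths=cache_lengths_t,
    lookup_table=lookup_table_t,
    max_lengths=max_lengths_t,
)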