
Mojo struct

ContinuousBatchingKVCacheManager

struct ContinuousBatchingKVCacheManager[type: DType, kv_params: KVCacheStaticParams]

Manages a batch-split KV cache across multiple user sessions.

Each request is assigned a seq_id, which is associated with a set of buffers to store the key and value projections per layer.

The order of calls for an active request is expected to be:

  • claim -- assign blocks to the sequence and give it a unique id
  • step -- commit context encoding projections
  • foreach token generated:
    • fetch -- retrieve blocks based on a seq_id
    • step -- commit token generation projections
  • release -- mark blocks as not in use
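The lifecycle above can be traced with a toy, pure-Python model of the bookkeeping. This is not the Mojo API. The names ToyKVCacheManager and ToyCollection are hypothetical, and the model tracks only token counts, not key/value tensors.

Several behaviors in the model are assumptions that loosely mirror the struct's fields:

  • the capacity check against max_batch_size
  • the max_seq_len bound
  • the counter-based seq_ids, echoing seq_id_counter and cache_lengths

Because step takes a collection as an argument, the sketch also assumes that the context-encoding step gets its collection from a fetch. The list above does not show that fetch explicitly.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCollection:
    # Hypothetical stand-in for the KV collection returned by fetch.
    seq_ids: list


@dataclass
class ToyKVCacheManager:
    # Hypothetical pure-Python model: tracks token counts, not K/V tensors.
    max_batch_size: int
    max_seq_len: int
    seq_id_counter: int = 0
    cache_lengths: dict = field(default_factory=dict)  # seq_id -> committed tokens

    def claim(self, batch_size):
        # Assumed rule: at most max_batch_size live sequences at once.
        if len(self.cache_lengths) + batch_size > self.max_batch_size:
            raise RuntimeError("not enough free blocks")
        seq_ids = []
        for _ in range(batch_size):
            self.cache_lengths[self.seq_id_counter] = 0
            seq_ids.append(self.seq_id_counter)
            self.seq_id_counter += 1
        return seq_ids

    def fetch(self, seq_ids):
        # Reject ids that were never claimed or were already released.
        for s in seq_ids:
            if s not in self.cache_lengths:
                raise KeyError(f"seq_id {s} has no assigned blocks")
        return ToyCollection(list(seq_ids))

    def step(self, valid_lengths, inflight_cache):
        # Commit the new tokens written for each sequence in the collection.
        for s, n in zip(inflight_cache.seq_ids, valid_lengths):
            if self.cache_lengths[s] + n > self.max_seq_len:
                raise ValueError(f"seq_id {s} would exceed max_seq_len")
            self.cache_lengths[s] += n

    def release(self, seq_id):
        # Forget the sequence so its capacity can be claimed again.
        del self.cache_lengths[seq_id]


mgr = ToyKVCacheManager(max_batch_size=4, max_seq_len=32)
ids = mgr.claim(2)                    # two new sequences
mgr.step([5, 3], mgr.fetch(ids))      # context encoding: prompts of 5 and 3 tokens
for _ in range(4):                    # four token-generation steps
    mgr.step([1, 1], mgr.fetch(ids))
lengths = dict(mgr.cache_lengths)     # snapshot before release
for s in ids:
    mgr.release(s)
```

After the loop, sequence 0 has 9 committed tokens (5 prompt plus 4 generated) and sequence 1 has 7. The release calls then remove both sequences, so a later fetch of either id fails.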

TODO: this is not currently thread-safe; make it so.

Aliases

  • CollectionType = ContinuousBatchingKVCacheCollection[type, kv_params]
  • BlocksType = NDBuffer[type, 6]

Fields

  • blocks_buf (Tensor[type, 6])
  • blocks_nd_buf (NDBuffer[type, 6])
  • cache_lengths (Dict[Int, Int])
  • num_blocks (Int)
  • max_batch_size (Int)
  • max_seq_len (Int)
  • seq_id_counter (Int)
  • inflight_batch (Optional[_ContinuousBatchingInflightBatchHandle])
  • num_layers (Int)
  • this_device (Device)
  • other_device (Device)

Implemented traits

AnyType, KVCacheManagerT, UnknownDestructibility

Methods

__init__

__init__(out self, max_batch_size: Int, max_seq_len: Int, num_layers: Int, other_device: Device, this_device: Device)

claim

claim(mut self, batch_size: Int) -> List[Int]

Assigns batch_size blocks to incoming requests.

This returns a List of seq_ids, which can be passed to fetch to retrieve a KVCacheCollection containing those sequences.
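A minimal sketch of claim's id bookkeeping follows. It is written as a hypothetical free function rather than the Mojo method.

The sketch rests on assumptions suggested by the seq_id_counter and cache_lengths fields, which this page does not confirm:

  • ids come from a monotonically increasing counter
  • each newly claimed id starts with zero committed tokens
  • claim refuses to exceed max_batch_size live sequences

```python
def claim(batch_size, seq_id_counter, cache_lengths, max_batch_size):
    """Hypothetical model of claim; returns (new seq_ids, updated counter)."""
    # Assumed capacity rule: at most max_batch_size live sequences.
    if len(cache_lengths) + batch_size > max_batch_size:
        raise RuntimeError("not enough free blocks for this batch")
    seq_ids = list(range(seq_id_counter, seq_id_counter + batch_size))
    for s in seq_ids:
        cache_lengths[s] = 0  # nothing committed yet for a fresh sequence
    return seq_ids, seq_id_counter + batch_size


lengths = {}
ids, counter = claim(3, 0, lengths, max_batch_size=4)
print(ids)    # [0, 1, 2]
more, counter = claim(1, counter, lengths, max_batch_size=4)
print(more)   # [3]
```

With four live sequences, a further claim(1, ...) call would raise, because this toy never frees capacity on its own.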

fetch

fetch[collection_t: KVCollectionT](mut self, seq_ids: List[Int]) -> collection_t

Retrieves the pre-assigned blocks for the given seq_ids.

If any of the seq_ids is not valid (e.g., has no assigned blocks), an error is raised.
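The validation rule can be sketched with a hypothetical Python function that returns plain (seq_id, committed length) pairs in place of a KVCacheCollection.

The sketch rests on two assumptions that are not taken from the Mojo source:

  • "valid" means the seq_id has been claimed and not yet released, so it is present in a cache_lengths-style dict
  • every id is checked up front, so a single bad id fails the whole call

```python
def fetch(seq_ids, cache_lengths):
    """Hypothetical model of fetch's validation: all ids must be claimed."""
    # Validate every id before building anything, so one bad id fails the call.
    missing = [s for s in seq_ids if s not in cache_lengths]
    if missing:
        raise KeyError(f"seq_ids without assigned blocks: {missing}")
    # Stand-in for the KVCacheCollection: each sequence's committed length.
    return [(s, cache_lengths[s]) for s in seq_ids]


cache_lengths = {0: 5, 1: 3}
print(fetch([1, 0], cache_lengths))   # [(1, 3), (0, 5)]
try:
    fetch([0, 7], cache_lengths)      # 7 was never claimed
except KeyError as err:
    print("error:", err)
```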

step

step[collection_t: KVCollectionT](mut self, valid_lengths: List[Int], owned inflight_cache: collection_t)

Commits changes to the ContiguousKVCache blocks.

This records that a KV projection step has occurred and that the values in these buffers have been written. The new tokens are noted in the blocks, and the valid_length counter is updated.
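The commit can be sketched with a hypothetical Python function that adds each sequence's valid length to its committed token count.

The sketch differs from the Mojo method in two ways:

  • it takes the in-flight seq_ids explicitly, whereas the Mojo step receives them through the owned inflight_cache collection
  • the max_seq_len check is an assumption, not documented behavior

```python
def step(valid_lengths, inflight_seq_ids, cache_lengths, max_seq_len):
    """Hypothetical model of step: commit newly written tokens per sequence."""
    if len(valid_lengths) != len(inflight_seq_ids):
        raise ValueError("one valid length is expected per in-flight sequence")
    # Check every bound first so a failing call leaves cache_lengths unchanged.
    for s, n in zip(inflight_seq_ids, valid_lengths):
        if cache_lengths[s] + n > max_seq_len:
            raise ValueError(f"seq_id {s} would exceed max_seq_len")
    for s, n in zip(inflight_seq_ids, valid_lengths):
        cache_lengths[s] += n


cache_lengths = {0: 0, 1: 0}
step([5, 3], [0, 1], cache_lengths, max_seq_len=8)   # context encoding
step([1, 1], [0, 1], cache_lengths, max_seq_len=8)   # one generated token each
print(cache_lengths)   # {0: 6, 1: 4}
```

During context encoding, each valid length is the prompt length. During token generation, it is typically 1 per sequence.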

release

release(mut self, seq_id: Int)

Marks seq_id as no longer needed; its blocks are returned to the pool.
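Returning blocks to the pool can be pictured with a hypothetical free-list model.

ToyBlockPool and its one-slot-per-sequence layout are illustrative assumptions. They are not the struct's actual layout, which stores its blocks in a single 6-D buffer. The model also assumes that seq_ids are never reused while block slots are.

```python
from collections import deque


class ToyBlockPool:
    """Hypothetical pool: each live seq_id owns exactly one block slot."""

    def __init__(self, num_blocks):
        self.free = deque(range(num_blocks))  # slots not owned by any sequence
        self.slot_of = {}                     # seq_id -> block slot
        self.next_id = 0

    def claim(self, batch_size):
        if batch_size > len(self.free):
            raise RuntimeError("pool exhausted")
        ids = []
        for _ in range(batch_size):
            self.slot_of[self.next_id] = self.free.popleft()
            ids.append(self.next_id)
            self.next_id += 1
        return ids

    def release(self, seq_id):
        # Return the sequence's slot to the pool; the seq_id itself is retired.
        self.free.append(self.slot_of.pop(seq_id))


pool = ToyBlockPool(num_blocks=2)
a, b = pool.claim(2)      # pool is now full
pool.release(a)           # slot 0 goes back to the pool
(c,) = pool.claim(1)      # new seq_id 2 reuses slot 0
print(c, pool.slot_of)    # 2 {1: 1, 2: 0}
```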