Mojo struct
ContinuousBatchingKVCacheManager
struct ContinuousBatchingKVCacheManager[type: DType, kv_params: KVCacheStaticParams]
Manages a batch-split KV cache across multiple user sessions.
Each request is assigned a seq_id, which is associated with a set of buffers to store the key and value projections per layer.
The order of calls for an active request is expected to be:
- claim -- assign blocks to the sequence and give it a unique id
- step -- commit context encoding projections
- for each token generated:
  - fetch -- retrieve blocks based on a seq_id
  - step -- commit token generation projections
- release -- mark blocks as not in use
TODO: this is not currently thread-safe; make it so.
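The call order above can be sketched with a toy Python model. This is not the Mojo API: the class `ToyKVCacheManager`, the one-block-per-sequence layout, and the `seq_ids` argument to `step` are assumptions made for illustration (the real `step` takes the in-flight cache collection instead).

```python
class ToyKVCacheManager:
    """Toy model of the claim/step/fetch/release lifecycle (hypothetical, not the Mojo struct)."""

    def __init__(self, max_batch_size):
        # Assumption: one block per concurrent sequence.
        self.free_blocks = list(range(max_batch_size))
        self.block_of = {}       # seq_id -> block index
        self.cache_lengths = {}  # seq_id -> tokens committed so far
        self.seq_id_counter = 0

    def claim(self, batch_size):
        if batch_size > len(self.free_blocks):
            raise RuntimeError("not enough free blocks")
        seq_ids = []
        for _ in range(batch_size):
            seq_id = self.seq_id_counter
            self.seq_id_counter += 1
            self.block_of[seq_id] = self.free_blocks.pop()
            self.cache_lengths[seq_id] = 0
            seq_ids.append(seq_id)
        return seq_ids

    def fetch(self, seq_ids):
        # Any seq_id without an assigned block is an error.
        missing = [s for s in seq_ids if s not in self.block_of]
        if missing:
            raise KeyError(f"unknown seq_ids: {missing}")
        return [self.block_of[s] for s in seq_ids]

    def step(self, seq_ids, valid_lengths):
        # Record how many new tokens each sequence wrote.
        for seq_id, new_tokens in zip(seq_ids, valid_lengths):
            self.cache_lengths[seq_id] += new_tokens

    def release(self, seq_id):
        # Return the block to the pool and forget the sequence.
        self.free_blocks.append(self.block_of.pop(seq_id))
        del self.cache_lengths[seq_id]


mgr = ToyKVCacheManager(max_batch_size=4)
ids = mgr.claim(2)               # claim
mgr.step(ids, [5, 3])            # step: context encoding
for _ in range(2):               # for each token generated
    blocks = mgr.fetch(ids)      # fetch
    mgr.step(ids, [1, 1])        # step: token generation
lengths = dict(mgr.cache_lengths)
for seq_id in ids:
    mgr.release(seq_id)          # release
```

After `release`, the sequence's block is back in the free pool and a later `fetch` for that seq_id fails, which mirrors the documented behavior for invalid seq_ids.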
Aliases
- CollectionType = ContinuousBatchingKVCacheCollection[type, kv_params]
- BlocksType = NDBuffer[type, 6]
Fields
- blocks_buf (Tensor[type, 6])
- blocks_nd_buf (NDBuffer[type, 6])
- cache_lengths (Dict[Int, Int])
- num_blocks (Int)
- max_batch_size (Int)
- max_seq_len (Int)
- seq_id_counter (Int)
- inflight_batch (Optional[_ContinuousBatchingInflightBatchHandle])
- num_layers (Int)
- this_device (Device)
- other_device (Device)
Implemented traits
AnyType, KVCacheManagerT, UnknownDestructibility
Methods
__init__
__init__(out self, max_batch_size: Int, max_seq_len: Int, num_layers: Int, other_device: Device, this_device: Device)
claim
claim(mut self, batch_size: Int) -> List[Int]
Assign batch_size blocks for incoming requests.
This returns a List of seq_ids, which can be passed to fetch to retrieve a KVCacheCollection containing those sequences.
fetch
fetch[collection_t: KVCollectionT](mut self, seq_ids: List[Int]) -> collection_t
Retrieves the pre-assigned blocks for the given seq_ids.
If any of the seq_ids is not valid (e.g. it has no assigned blocks), an error is raised.
step
step[collection_t: KVCollectionT](mut self, valid_lengths: List[Int], owned inflight_cache: collection_t)
Commits changes to the ContiguousKVCache blocks.
This is used to note that a KV projection step has occurred and the values in these buffers have been written to. We note the new tokens in the blocks and update the valid_length counter.
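As a rough illustration of the length accounting, the Python sketch below advances a per-sequence counter by the number of newly written tokens. The helper `commit_step`, the reading of valid_lengths as "new tokens per sequence", and the max_seq_len check are all assumptions; the source does not say how the Mojo implementation validates or stores these values.

```python
def commit_step(cache_lengths, seq_ids, valid_lengths, max_seq_len):
    """Return updated per-sequence lengths after one KV projection step (hypothetical helper)."""
    if len(seq_ids) != len(valid_lengths):
        raise ValueError("one valid length is required per sequence")
    updated = dict(cache_lengths)  # leave the caller's dict untouched
    for seq_id, new_tokens in zip(seq_ids, valid_lengths):
        total = updated[seq_id] + new_tokens
        # Assumption: a sequence may never grow past max_seq_len.
        if total > max_seq_len:
            raise ValueError(f"seq_id {seq_id} would exceed max_seq_len")
        updated[seq_id] = total
    return updated


lengths = {7: 0, 8: 0}
# Context encoding: prompts of 12 and 4 tokens.
lengths = commit_step(lengths, [7, 8], [12, 4], max_seq_len=16)
# Token generation: one new token per sequence.
lengths = commit_step(lengths, [7, 8], [1, 1], max_seq_len=16)
```

Context encoding commits the whole prompt at once, while each generation step commits a single token per sequence, so the same bookkeeping serves both phases.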
release
release(mut self, seq_id: Int)
Marks seq_id as no longer necessary; its blocks are returned to the pool.