Mojo struct
ContiguousKVCacheManager
struct ContiguousKVCacheManager[type: DType, kv_params: KVCacheStaticParams]
Manages a batch-split KV cache across multiple user sessions.
Each request is assigned a seq_id, which is associated with a set of buffers to store the key and value projections per layer.
The order of calls for an active request is expected to be:
- claim -- assign blocks to the sequence and give it a unique id
- step -- commit context encoding projections
- for each token generated:
  - fetch -- retrieve blocks based on a seq_id
  - step -- commit token generation projections
- release -- mark blocks as not in use
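The call order above can be sketched with a toy model. This is a Python stand-in, not the Mojo API: ToyKVCacheManager, its fields, and the error types are hypothetical, it stores no tensors, and its step takes seq_ids directly instead of an inflight cache collection. It also assumes each valid length is the number of new tokens written for that sequence.

```python
class ToyKVCacheManager:
    """Hypothetical stand-in that tracks bookkeeping only, no KV tensors."""

    def __init__(self, max_batch_size):
        self.free_blocks = list(range(max_batch_size))  # pool of block indices
        self.block_of = {}        # seq_id -> assigned block index
        self.cache_lengths = {}   # seq_id -> number of committed tokens
        self.seq_id_counter = 0

    def claim(self, batch_size):
        # Assign one block per incoming request and hand out fresh seq_ids.
        if batch_size > len(self.free_blocks):
            raise ValueError("not enough free blocks")
        seq_ids = []
        for _ in range(batch_size):
            seq_id = self.seq_id_counter
            self.seq_id_counter += 1
            self.block_of[seq_id] = self.free_blocks.pop()
            self.cache_lengths[seq_id] = 0
            seq_ids.append(seq_id)
        return seq_ids

    def fetch(self, seq_ids):
        # Look up the pre-assigned blocks; unknown seq_ids are an error.
        for seq_id in seq_ids:
            if seq_id not in self.block_of:
                raise KeyError(f"seq_id {seq_id} has no assigned block")
        return [self.block_of[seq_id] for seq_id in seq_ids]

    def step(self, seq_ids, valid_lengths):
        # Commit: record how many new tokens were written per sequence
        # (assumed meaning of valid_lengths).
        for seq_id, new_tokens in zip(seq_ids, valid_lengths):
            self.cache_lengths[seq_id] += new_tokens

    def release(self, seq_id):
        # Return the block to the pool and forget the sequence.
        self.free_blocks.append(self.block_of.pop(seq_id))
        del self.cache_lengths[seq_id]


manager = ToyKVCacheManager(max_batch_size=2)
seq_ids = manager.claim(2)
manager.step(seq_ids, [5, 3])        # context encoding: prompts of 5 and 3 tokens
for _ in range(4):                   # four generated tokens per sequence
    manager.fetch(seq_ids)
    manager.step(seq_ids, [1, 1])    # token generation: one new token each
lengths_after_generation = dict(manager.cache_lengths)
for seq_id in seq_ids:
    manager.release(seq_id)
```

After the loop each sequence has its prompt length plus four generated tokens committed, and releasing both seq_ids puts both blocks back in the pool.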
TODO: this is not currently thread-safe; make it so.
Aliases
CollectionType = ContiguousKVCacheCollection[type, kv_params]
Fields
- blocks_buf (Tensor[type, 6])
- num_blocks (Int)
- max_batch_size (Int)
- max_seq_len (Int)
- active_seq_ids (List[Int])
- cache_lengths (List[Int])
- seq_id_counter (Int)
- num_layers (Int)
- other_device (Device)
- this_device (Device)
- cache_lengths_tensor_host (DeviceTensor)
- cache_lengths_tensor_dev (DeviceTensor)
Implemented traits
AnyType, KVCacheManagerT, UnknownDestructibility
Methods
__init__
__init__(out self, max_batch_size: Int, max_seq_len: Int, num_layers: Int, mut other_device: Device, mut this_device: Device)
claim
claim(mut self, batch_size: Int) -> List[Int]
Assign batch_size blocks for incoming requests.
This returns a List of seq_ids, which can be passed to fetch to retrieve the KVCollection for the given batch.
fetch
fetch[collection_t: KVCollectionT](mut self, seq_ids: List[Int]) -> collection_t
Retrieves the pre-assigned blocks for the given seq_ids.
If any of the seq_ids is not valid (e.g. it has no assigned blocks), an error is raised.
step
step[collection_t: KVCollectionT](mut self, valid_lengths: List[Int], owned inflight_cache: collection_t)
Commits changes to the ContiguousKVCache blocks.
This is used to note that a KV projection step has occurred and that the values in these buffers have been written. We record the new tokens in the blocks and update the valid_length counter.
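The length bookkeeping that a commit implies can be sketched in isolation. This is a Python illustration, not the Mojo implementation: commit_step is a hypothetical helper, it assumes each entry of valid_lengths is the number of new tokens written for that sequence, and the max_seq_len bound check is an assumption that the source does not describe.

```python
def commit_step(cache_lengths, valid_lengths, max_seq_len):
    """Return the updated per-sequence lengths after one commit (hypothetical helper)."""
    if len(cache_lengths) != len(valid_lengths):
        raise ValueError("one valid length is required per sequence")
    updated = []
    for current, new_tokens in zip(cache_lengths, valid_lengths):
        total = current + new_tokens
        # Assumed guard: a sequence cannot outgrow its preallocated buffer.
        if total > max_seq_len:
            raise ValueError("sequence would exceed max_seq_len")
        updated.append(total)
    return updated


# Context encoding commits the whole prompt; token generation commits one token.
lengths = commit_step([0, 0], [5, 3], max_seq_len=16)
lengths = commit_step(lengths, [1, 1], max_seq_len=16)
```

Under these assumptions the first call records the prompt lengths and each later call advances every sequence by the tokens generated in that step.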
release
release(mut self, seq_id: Int)
Marks seq_id as no longer necessary; its blocks are returned to the pool.