Mojo struct

ContinuousBatchingKVCacheCollection

struct ContinuousBatchingKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams]

This is a "view" of the cache for the given sequences in the batch.

This object does not own the underlying buffers in k_cache and v_cache; it borrows them from the BlockWrappers in our KVCacheManager.
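The borrowing relationship described above can be illustrated with a small conceptual model. This is a plain-Python sketch, not the Mojo API: `BlockPool`, `CacheView`, and `block_for` are hypothetical names invented for illustration, and the assumption that `lookup_table` maps a batch index to a block index, and that `cache_length` reads from `cache_lengths`, is inferred from the field names rather than stated on this page.

```python
class BlockPool:
    """Hypothetical owner of the backing storage (stands in for the cache manager)."""

    def __init__(self, num_blocks, block_size):
        # The pool allocates the blocks and is responsible for their lifetime.
        self.blocks = [[0.0] * block_size for _ in range(num_blocks)]


class CacheView:
    """Hypothetical non-owning view: it keeps references and never copies or frees."""

    def __init__(self, blocks, lookup_table, cache_lengths):
        self.blocks = blocks                # borrowed reference to the pool's storage
        self.lookup_table = lookup_table    # assumed: batch index -> block index
        self.cache_lengths = cache_lengths  # assumed: batch index -> tokens already cached

    def block_for(self, bs_idx):
        # Resolve the batch entry to the block it currently occupies.
        return self.blocks[self.lookup_table[bs_idx]]

    def cache_length(self, bs_idx):
        return self.cache_lengths[bs_idx]


pool = BlockPool(num_blocks=4, block_size=8)
view = CacheView(pool.blocks, lookup_table=[2, 0], cache_lengths=[5, 3])

# Writing through the view mutates the pool's storage, because nothing was copied.
view.block_for(0)[0] = 1.5
```

Because the view only holds references, the pool has to outlive every view created from it. That is the practical consequence of the "does not own" wording above.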

Parameters​

  • ​dtype_ (DType): The dtype of the kv-cache.
  • ​kv_params_ (KVCacheStaticParams): The kv-cache static parameters.

Fields​

  • blocks (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_tt_type)
  • cache_lengths (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType.cache_lengths_tt_type)
  • lookup_table (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType.lookup_table_tt_type)
  • max_seq_length (UInt32)
  • max_cache_length (UInt32)
  • kv_cache_dynamic_shape (IndexList[4])
  • kv_cache_dynamic_strides (IndexList[4])

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, KVCollectionT, Movable

comptime members​

blocks_layout​

comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_shape)

blocks_shape​

comptime blocks_shape = IntTuple(-1, -1, -1, -1, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params)

blocks_tt_layout​

comptime blocks_tt_layout = Layout[*?, *?]

blocks_tt_type​

comptime blocks_tt_type = TileTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout[*?, *?], MutAnyOrigin]

CacheType​

comptime CacheType = ContinuousBatchingKVCache[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params]

dtype​

comptime dtype = dtype_

kv_params​

comptime kv_params = kv_params_

name_str​

comptime name_str = "continuous_batching"

scale_dtype​

comptime scale_dtype = DType.invalid

Methods​

__init__​

__init__(out self, blocks: LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[DType.invalid, Layout.row_major[6](), MutAnyOrigin]] = None)

Construct from LayoutTensor params (MOGG boundary).

__init__(out self, blocks: TileTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout[*?, *?], MutAnyOrigin], cache_lengths: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], lookup_table: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32)

Construct from TileTensor fields directly.

get_key_cache​

get_key_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType

Returns:

ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType

get_value_cache​

get_value_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType

Returns:

ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType

cache_length​

cache_length(self, bs_idx: Int) -> Int

Returns:

Int