Mojo struct
ContinuousBatchingKVCacheCollection
struct ContinuousBatchingKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams]
This is a "view" of the cache for the given sequences in the batch.
This object does not own the underlying buffers in k_cache and v_cache; it borrows them from the BlockWrappers in the KVCacheManager.
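The "view" relationship described above can be illustrated with a small conceptual sketch. This is plain Python, not the Mojo implementation: `ToyKVView` and `manager_blocks` are hypothetical names, and the reading of `lookup_table` as a map from batch index to block index is an assumption, since this page does not describe it.

```python
from dataclasses import dataclass


@dataclass
class ToyKVView:
    """Hypothetical stand-in for a non-owning cache view."""

    blocks: list         # shared reference; storage is owned elsewhere
    lookup_table: list   # assumed meaning: batch index -> block index
    cache_lengths: list  # assumed meaning: batch index -> tokens cached

    def block_for(self, batch_idx):
        # Resolve a sequence in the batch to its backing block.
        return self.blocks[self.lookup_table[batch_idx]]


# The owner (standing in for the cache manager) holds four blocks.
manager_blocks = [[] for _ in range(4)]

# A view over two sequences that live in blocks 2 and 0.
view = ToyKVView(blocks=manager_blocks, lookup_table=[2, 0], cache_lengths=[0, 0])

# A write through the owner is visible through the view: nothing was copied.
manager_blocks[2].append("tok")
same_storage = view.block_for(0) is manager_blocks[2]
seen_through_view = view.block_for(0)
```

Because the view only holds references, it is cheap to copy and move, but it must not outlive the storage it points at.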
Parameters
- dtype_ (DType): The dtype of the kv-cache.
- kv_params_ (KVCacheStaticParams): The kv-cache static parameters.
Fields
- blocks (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_tt_type)
- cache_lengths (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType.cache_lengths_tt_type)
- lookup_table (ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType.lookup_table_tt_type)
- max_seq_length (UInt32)
- max_cache_length (UInt32)
- kv_cache_dynamic_shape (IndexList[4])
- kv_cache_dynamic_strides (IndexList[4])
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, KVCollectionT, Movable
comptime members
blocks_layout
comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_shape)
blocks_shape
comptime blocks_shape = IntTuple(-1, -1, -1, -1, Int[UInt](kv_params_.num_heads), Int[UInt](kv_params_.head_size))
blocks_tt_layout
comptime blocks_tt_layout = Layout[#kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx]) != -1) else RuntimeInt[DType.int64])), #kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx]) != -1) else RuntimeInt[DType.int64]))]
blocks_tt_type
comptime blocks_tt_type = TileTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout[#kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx]) != -1) else RuntimeInt[DType.int64])), #kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx]) != -1) else RuntimeInt[DType.int64]))], MutAnyOrigin]
CacheType
comptime CacheType = ContinuousBatchingKVCache[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params]
dtype
comptime dtype = dtype_
kv_params
comptime kv_params = kv_params_
name_str
comptime name_str = "continuous_batching"
scale_dtype
comptime scale_dtype = DType.invalid
Methods
__init__
__init__(out self, blocks: LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[DType.invalid, Layout.row_major[6](), MutAnyOrigin]] = None)
Construct from LayoutTensor params (MOGG boundary).
__init__(out self, blocks: TileTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout[#kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.shape)[idx]) != -1) else RuntimeInt[DType.int64])), #kgen.param_list.reduce(#kgen.param_list.tabulate(len[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: KGENParamList[CoordLike], VA: KGENParamList[CoordLike], idx: __mlir_type.index] #kgen.param_list.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx])] if (Int[IntTuple](product_each(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout.stride)[idx]) != -1) else RuntimeInt[DType.int64]))], MutAnyOrigin], cache_lengths: TileTensor[DType.uint32, Layout[RuntimeInt[DType.int64], ComptimeInt[1]], ImmutAnyOrigin], lookup_table: TileTensor[DType.uint32, Layout[RuntimeInt[DType.int64], ComptimeInt[1]], ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32)
Construct from TileTensor fields directly.
get_key_cache
get_key_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType
Returns:
ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType
get_value_cache
get_value_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType
Returns:
ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType
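One way to picture how a per-layer key or value cache could be carved out of the 6-D row-major `blocks` tensor is the following Python sketch. It is not the library's code. The axis order `(num_blocks, 2, num_layers, max_seq_len, num_heads, head_size)` is an assumption (this page only fixes the last two dimensions), and `row_major_strides` and `layer_view` are hypothetical helpers. The 4-D shape and strides it returns play the role that `kv_cache_dynamic_shape` and `kv_cache_dynamic_strides` appear to play in the struct.

```python
def row_major_strides(shape):
    """Element strides for a densely packed row-major tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def layer_view(shape, kv_idx, layer_idx):
    """Describe a 4-D view for one layer (kv_idx 0 = key, 1 = value, assumed)."""
    strides = row_major_strides(shape)
    # Fix the key/value axis and the layer axis; what remains is 4-D.
    base_offset = kv_idx * strides[1] + layer_idx * strides[2]
    view_shape = (shape[0], shape[3], shape[4], shape[5])
    view_strides = (strides[0], strides[3], strides[4], strides[5])
    return base_offset, view_shape, view_strides


# Assumed axis order:
# (num_blocks, 2, num_layers, max_seq_len, num_heads, head_size)
blocks_shape = (4, 2, 3, 16, 8, 64)

key_offset, key_shape, key_strides = layer_view(blocks_shape, kv_idx=0, layer_idx=1)
value_offset, value_shape, value_strides = layer_view(blocks_shape, kv_idx=1, layer_idx=1)
```

Under this assumed layout, the key and value views for a layer share the same shape and strides and differ only in their base offset, so no data needs to be copied to produce them.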
cache_length