Mojo struct
ContinuousBatchingKVCache
@register_passable("trivial")
struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams]
Wrapper for the ContinuousKVCache of a given layer in the transformer model.
This type abstracts away the pointer indirection needed to access the ContinuousKVCache for a given batch entry.
Note: this is the type that is passed to the KV projection and flash attention kernels.
Parameters
- dtype_ (DType): The dtype of the kv-cache.
- kv_params_ (KVCacheStaticParams): The kv-cache static parameters.
Fields
- blocks (NDBuffer[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, 4, MutableAnyOrigin, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_stride])
- cache_lengths (NDBuffer[DType.uint32, 1, MutableAnyOrigin])
- lookup_table (NDBuffer[DType.uint32, 1, MutableAnyOrigin])
- max_seq_length (UInt32)
- max_cache_length (UInt32)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
KVCacheT,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
blocks_shape
alias blocks_shape = DimList.__init__[Dim, Dim, Dim, Dim](Dim(), Dim(), Dim(Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params)), Dim(Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params)))
blocks_stride
alias blocks_stride = _strides_from_shape[ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape, skip=1]()
blocks_type
alias blocks_type = NDBuffer[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, 4, MutableAnyOrigin, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_stride]
device_type
alias device_type = ContinuousBatchingKVCache[dtype_, kv_params_]
dtype
alias dtype = dtype_
kv_params
alias kv_params = kv_params_
Methods
__init__
__init__(blocks: NDBuffer[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, 4, MutableAnyOrigin, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_stride], cache_lengths: NDBuffer[DType.uint32, 1, MutableAnyOrigin], lookup_table: NDBuffer[DType.uint32, 1, MutableAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self
get_type_name
static get_type_name() -> String
Returns:
String
get_device_type_name
static get_device_type_name() -> String
Returns:
String
max_tile_size
cache_lengths_nd
cache_length
load
load[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, width]
Returns:
SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, width]
store
store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, size])
empty_cache
empty_cache(self) -> Bool
Returns True if the cache length of every request is 0, and False otherwise.
Returns:
Bool
max_prompt_length
max_prompt_length(self) -> UInt32
Returns the maximum sequence length across all batches of the current request.
Returns:
UInt32
max_context_length
max_context_length(self) -> UInt32
Returns the maximum cache length used across all batches of the current request.
Returns:
UInt32
row_idx
row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32
Returns the row index when the memory is viewed as a matrix.
Returns:
UInt32
col_idx
col_idx(self, head_idx: UInt32) -> UInt32
Returns the column index when the memory is viewed as a matrix.
Returns:
UInt32
create_tma_tile
create_tma_tile[tile_m: Int, tile_n: Int, swizzle_mode: TensorMapSwizzle, *, is_k_major: Bool](self, ctx: DeviceContext) -> TMATensorTile[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_layout_k_major[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_m, tile_n, swizzle_mode]() if is_k_major else tile_layout_mn_major[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_n, tile_m, swizzle_mode](), _tma_desc_tile_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, 2, IndexList[2, DType.int64](tile_m, tile_n, Tuple[]()), is_k_major, swizzle_mode](), is_k_major]
Creates a TMA tile for this KV cache.
Returns:
TMATensorTile
block_paged_ptr
block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[ContinuousBatchingKVCache[dtype_, kv_params_].dtype]]
Returns:
UnsafePointer
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!