Mojo struct
ContinuousBatchingKVCache
@register_passable("trivial")
struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams]
Wrapper for the ContinuousKVCache of a given layer in a transformer model.
It abstracts away the pointer indirection needed to access the ContinuousKVCache for a given batch entry.
This is the type that is passed to the KV projection and flash attention kernels.
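The indirection can be pictured with a small Python model. This is a sketch under stated assumptions, not the Mojo implementation. It assumes each request keeps a fixed cache slot for its lifetime, and that a per-step table (presumably the role of the lookup_table field) maps each batch entry to that slot. The names build_lookup_table and slot_for are hypothetical.

```python
# Hypothetical model of the batch-entry -> cache-slot indirection.
# Assumption: each request keeps one cache slot for its lifetime, and a
# per-step lookup table maps batch positions to those slots.


def build_lookup_table(active_requests, slot_of_request):
    """Return a list mapping each batch index to its request's cache slot."""
    return [slot_of_request[req] for req in active_requests]


def slot_for(lookup_table, batch_idx):
    """Resolve one batch entry to its cache slot (the indirection)."""
    return lookup_table[batch_idx]


# Requests "a", "b", and "c" were given slots 0, 1, and 2 when they arrived.
slot_of_request = {"a": 0, "b": 1, "c": 2}

# Step 1: all three requests are in the batch.
step1 = build_lookup_table(["a", "b", "c"], slot_of_request)

# Step 2: "a" has finished, so "c" moves to batch index 1 but keeps slot 2.
step2 = build_lookup_table(["b", "c"], slot_of_request)
print(step1, step2, slot_for(step2, 1))  # [0, 1, 2] [1, 2] 2
```

In this model the table is rebuilt every step while the slots stay put. A request that moves to a different batch position therefore still reads the same cached keys and values.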
Parameters
- dtype_ (DType): The dtype of the KV cache.
- kv_params_ (KVCacheStaticParams): The KV cache static parameters.
Fields
- blocks (NDBuffer[dtype_, 4, MutableAnyOrigin, blocks_shape, blocks_stride])
- cache_lengths (NDBuffer[DType.uint32, 1, MutableAnyOrigin])
- lookup_table (NDBuffer[DType.uint32, 1, MutableAnyOrigin])
- max_seq_length (UInt32)
- max_cache_length (UInt32)
Implemented traits
AnyType, ExplicitlyCopyable, ImplicitlyCopyable, KVCacheT, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
blocks_shape
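alias blocks_shape = DimList.__init__[Dim, Dim, Dim, Dim](Dim(-31337), Dim(-31337), Dim(kv_params_.num_heads), Dim(kv_params_.head_size))
Dim(-31337) is how the generated signature renders an unknown (dynamic) dimension. The first two dimensions of blocks are therefore dynamic, while the last two are fixed by kv_params_.num_heads and kv_params_.head_size.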
blocks_stride
alias blocks_stride = _strides_from_shape[blocks_shape, skip=1]()
blocks_type
alias blocks_type = NDBuffer[dtype_, 4, MutableAnyOrigin, blocks_shape, blocks_stride]
dtype
alias dtype = dtype_
kv_params
alias kv_params = kv_params_
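The following Python sketch illustrates what blocks_shape and blocks_stride describe by computing the shape and element strides of a contiguous row-major 4-D buffer. It assumes, though this reference does not say so, that the two dynamic dimensions are a number of cache blocks and a maximum sequence length. It does not model the skip=1 argument of _strides_from_shape.

```python
# Sketch of a contiguous row-major 4-D layout for blocks.
# Assumption: the two dynamic dims are (num_blocks, max_seq_len); the static
# dims come from kv_params_.num_heads and kv_params_.head_size.


def blocks_shape(num_blocks, max_seq_len, num_heads, head_size):
    return (num_blocks, max_seq_len, num_heads, head_size)


def row_major_strides(shape):
    """Element strides of a contiguous row-major array with this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


shape = blocks_shape(num_blocks=4, max_seq_len=16, num_heads=8, head_size=64)
strides = row_major_strides(shape)
print(shape, strides)  # (4, 16, 8, 64) (8192, 512, 64, 1)
```

In this model only the outermost stride depends on a dynamic dimension (max_seq_len). The stride of the second dimension is num_heads * head_size, which is known from kv_params_.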
Methods
__init__
__init__(blocks: NDBuffer[dtype_, 4, MutableAnyOrigin, blocks_shape, blocks_stride], cache_lengths: NDBuffer[DType.uint32, 1, MutableAnyOrigin], lookup_table: NDBuffer[DType.uint32, 1, MutableAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self
max_tile_size
cache_lengths_nd
cache_length
load
load[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[dtype_, width]
Returns:
SIMD[dtype_, width]
store
store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[dtype_, size])
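load and store take a batch entry (bs) rather than a block index, so every access first goes through the lookup table. The Python model below mimics this. It assumes blocks is indexed as [block, token, head, head_dim] in row-major order and that lookup_table maps a batch index to a block index. The class name KVCacheModel is hypothetical, and a Python list slice stands in for SIMD[dtype_, width].

```python
# Python model of load/store through the lookup table.
# Assumptions: blocks is row-major [block, token, head, head_dim], and
# lookup_table maps a batch index to a block index. KVCacheModel is a
# hypothetical name; a list slice stands in for SIMD[dtype_, width].


class KVCacheModel:
    def __init__(self, num_blocks, max_seq_len, num_heads, head_size, lookup_table):
        self.shape = (num_blocks, max_seq_len, num_heads, head_size)
        self.lookup_table = list(lookup_table)
        self.data = [0.0] * (num_blocks * max_seq_len * num_heads * head_size)

    def _offset(self, bs, head_idx, tok_idx, head_dim_idx):
        _, max_seq_len, num_heads, head_size = self.shape
        block = self.lookup_table[bs]  # batch entry -> block (the indirection)
        row = block * max_seq_len + tok_idx
        return (row * num_heads + head_idx) * head_size + head_dim_idx

    def load(self, width, bs, head_idx, tok_idx, head_dim_idx):
        """Return `width` consecutive elements along head_dim."""
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        return self.data[start:start + width]

    def store(self, bs, head_idx, tok_idx, head_dim_idx, val):
        """Write the elements of `val` consecutively along head_dim."""
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        self.data[start:start + len(val)] = list(val)


# Batch entry 0 lives in block 2; batch entry 1 lives in block 0.
cache = KVCacheModel(num_blocks=3, max_seq_len=4, num_heads=2, head_size=8,
                     lookup_table=[2, 0])
cache.store(bs=0, head_idx=1, tok_idx=3, head_dim_idx=4, val=[1.0, 2.0, 3.0, 4.0])
print(cache.load(4, 0, 1, 3, 4))  # [1.0, 2.0, 3.0, 4.0]
print(cache.load(4, 1, 1, 3, 4))  # [0.0, 0.0, 0.0, 0.0] (different block)
```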
empty_cache
empty_cache(self) -> Bool
Returns True if the cache length of every request is 0, and False otherwise.
Returns:
Bool
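The empty_cache condition is a single reduction over the batch. Below is a Python restatement of the documented condition. It takes the per-request cache lengths as a plain list, which is an assumption; the Mojo method reads the struct's own state.

```python
# Python restatement of empty_cache; cache_lengths is a hypothetical plain list
# of per-request cache lengths.


def empty_cache(cache_lengths):
    """True if the cache length of every request is 0, False otherwise."""
    return all(length == 0 for length in cache_lengths)


print(empty_cache([0, 0, 0]))   # True
print(empty_cache([0, 12, 0]))  # False
```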
max_prompt_length
max_prompt_length(self) -> UInt32
Returns the maximum sequence length across all batches of the current request.
Returns:
UInt32
max_context_length
max_context_length(self) -> UInt32
Returns the maximum cache length used across all batches of the current request.
Returns:
UInt32
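max_prompt_length and max_context_length are both plain maxima over the batch. The minimal Python sketch below assumes the per-request lengths are available as lists; these inputs are hypothetical. The struct's max_seq_length and max_cache_length fields likely hold such precomputed values, but that is not stated here.

```python
# Hypothetical inputs: per-request sequence lengths for the current step and
# per-request cache lengths in use. The struct returns UInt32 values instead.


def max_prompt_length(seq_lengths):
    """Maximum sequence length across all batch entries (0 if empty)."""
    return max(seq_lengths, default=0)


def max_context_length(cache_lengths):
    """Maximum cache length used across all batch entries (0 if empty)."""
    return max(cache_lengths, default=0)


print(max_prompt_length([5, 1, 9]))     # 9
print(max_context_length([40, 0, 17]))  # 40
```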
row_idx
row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32
Returns the row index of a token when the cache memory is viewed as a 2-D matrix.
Returns:
UInt32
col_idx
col_idx(self, head_idx: UInt32) -> UInt32
Returns the column index of a head when the cache memory is viewed as a 2-D matrix.
Returns:
UInt32
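Together, row_idx and col_idx define a 2-D view of blocks. The Python sketch below assumes a row-major [num_blocks, max_seq_len, num_heads, head_size] layout viewed with num_heads * head_size columns. Under that layout, row_idx(batch_idx, tok_idx) = lookup_table[batch_idx] * max_seq_len + tok_idx and col_idx(head_idx) = head_idx * head_size. Both formulas are assumptions consistent with blocks_shape, not taken from the source. The final check confirms that a matrix element and the matching 4-D element share the same flat offset.

```python
# Assumed layout: blocks is row-major [num_blocks, max_seq_len, num_heads, head_size],
# viewed as a matrix with num_heads * head_size columns (one row per cached token).


def row_idx(lookup_table, max_seq_len, batch_idx, tok_idx):
    """Row of token tok_idx of batch entry batch_idx (assumed formula)."""
    return lookup_table[batch_idx] * max_seq_len + tok_idx


def col_idx(head_size, head_idx):
    """First column belonging to head head_idx (assumed formula)."""
    return head_idx * head_size


def flat_4d(shape, block, tok, head, dim):
    """Flat offset of element (block, tok, head, dim) in the row-major buffer."""
    _, max_seq_len, num_heads, head_size = shape
    return ((block * max_seq_len + tok) * num_heads + head) * head_size + dim


shape = (3, 4, 2, 8)   # num_blocks, max_seq_len, num_heads, head_size
lookup_table = [2, 0]  # batch entry 0 -> block 2, batch entry 1 -> block 0
num_cols = shape[2] * shape[3]

r = row_idx(lookup_table, shape[1], batch_idx=0, tok_idx=3)
c = col_idx(shape[3], head_idx=1)
# Matrix element (r, c + 5) and 4-D element (2, 3, 1, 5) share one flat offset.
assert r * num_cols + (c + 5) == flat_4d(shape, 2, 3, 1, 5)
print(r, c)  # 11 8
```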
create_tma_tile
create_tma_tile[tile_m: Int, tile_n: Int, swizzle_mode: TensorMapSwizzle, *, is_k_major: Bool](self, ctx: DeviceContext) -> TMATensorTile[dtype_, tile_layout_k_major[dtype_, tile_m, tile_n, swizzle_mode]() if is_k_major else tile_layout_mn_major[dtype_, tile_n, tile_m, swizzle_mode](), _tma_desc_tile_layout[dtype_, 2, IndexList[2, DType.int64](tile_m, tile_n, Tuple[]()), is_k_major, swizzle_mode](), is_k_major]
Creates a TMA (Tensor Memory Accelerator) tile for this KV cache.
Returns:
TMATensorTile
block_paged_ptr
block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[dtype_]]
Returns:
UnsafePointer
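From its signature, block_paged_ptr appears to return a pointer to the first element of a tile that starts at start_tok_idx for one batch entry and head. The Python model below treats that pointer as a flat element offset. As an assumption not documented here, it uses tile_size only to list the offsets of the tile's tokens. Under the assumed row-major layout, those tokens are num_heads * head_size elements apart. block_paged_offset and tile_offsets are hypothetical names.

```python
# Hypothetical model: the returned pointer is treated as a flat element offset
# into a row-major [num_blocks, max_seq_len, num_heads, head_size] buffer.


def block_paged_offset(shape, lookup_table, batch_idx, start_tok_idx,
                       head_idx, head_dim_idx=0):
    _, max_seq_len, num_heads, head_size = shape
    block = lookup_table[batch_idx]  # batch entry -> block
    row = block * max_seq_len + start_tok_idx
    return (row * num_heads + head_idx) * head_size + head_dim_idx


def tile_offsets(shape, lookup_table, tile_size, batch_idx, start_tok_idx, head_idx):
    """Offsets of the first element of each token in a tile (assumed use of tile_size)."""
    base = block_paged_offset(shape, lookup_table, batch_idx, start_tok_idx, head_idx)
    token_stride = shape[2] * shape[3]  # num_heads * head_size
    return [base + t * token_stride for t in range(tile_size)]


shape = (3, 4, 2, 8)
print(tile_offsets(shape, [2, 0], tile_size=2, batch_idx=1, start_tok_idx=1, head_idx=0))
# [16, 32]
```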