Mojo struct
ContinuousBatchingKVCache
@register_passable("trivial")
struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams]
Wrapper for the ContinuousKVCache of a given layer in the transformer model.
This abstracts away the pointer indirection needed to access the ContinuousKVCache for a given batch entry.
This is the type that is passed to the KV projection and flash attention kernels.
Parameters
- dtype_ (DType): The dtype of the KV cache.
- kv_params_ (KVCacheStaticParams): The KV cache static parameters.
Fields
- blocks (LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin]):
- cache_lengths (LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]):
- lookup_table (LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]):
- max_seq_length (UInt32):
- max_cache_length (UInt32):
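The pointer indirection the struct hides can be pictured with a small pure-Python model. This is a sketch, not the MAX API: the class ToyContinuousKVCache and the helper zeros are hypothetical, nested lists stand in for LayoutTensor, and reading blocks as [num_blocks, max_seq_len, num_heads, head_size] is an assumption about the two runtime (-1) extents and the two kv_params-derived extents in blocks_shape.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyContinuousKVCache:
    """Hypothetical pure-Python model of the struct's fields (not the MAX API)."""
    # Assumed logical shape of `blocks`: [num_blocks, max_seq_len, num_heads, head_size],
    # stored here as nested lists instead of a LayoutTensor.
    blocks: List[List[List[List[float]]]]
    cache_lengths: List[int]  # tokens already cached, one entry per batch entry
    lookup_table: List[int]   # batch entry -> index of the block it owns
    max_seq_length: int
    max_cache_length: int

    def block_of(self, bs: int) -> List[List[List[float]]]:
        # The indirection: batch entry -> block index (lookup_table) -> block.
        return self.blocks[self.lookup_table[bs]]


def zeros(num_blocks: int, max_seq_len: int, num_heads: int, head_size: int):
    # Hypothetical helper: a zero-filled nested-list "tensor" of the assumed shape.
    return [[[[0.0] * head_size for _ in range(num_heads)]
             for _ in range(max_seq_len)] for _ in range(num_blocks)]


cache = ToyContinuousKVCache(
    blocks=zeros(4, 8, 2, 4),
    cache_lengths=[3, 0],
    lookup_table=[2, 0],  # batch entry 0 lives in block 2, entry 1 in block 0
    max_seq_length=5,     # illustrative values only
    max_cache_length=8,
)
cache.blocks[2][0][1][3] = 7.0
print(cache.block_of(0)[0][1][3])  # 7.0
```

Because a kernel receives only this struct, it never needs to know which physical block a batch entry was assigned; the lookup table resolves that.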
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
KVCacheT,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
blocks_layout
comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape)
blocks_shape
comptime blocks_shape = IntTuple(-1, -1, Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params), Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params))
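blocks_layout is Layout.row_major over blocks_shape, so the last dimension is contiguous and every earlier stride is the product of the extents after it. The Python sketch below computes those strides and a flat offset for an illustrative shape; the concrete extents are made up, and reading the last two extents as the number of KV heads and the head size is an assumption.

```python
def row_major_strides(shape):
    """Strides of a row-major (C-order) layout: the last dimension is contiguous."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def offset(shape, idx):
    # Flat element offset of a multi-dimensional index under the row-major layout.
    return sum(i * s for i, s in zip(idx, row_major_strides(shape)))


# Illustrative shape: 4 blocks, 8 tokens per block, 2 KV heads, head size 4.
shape = (4, 8, 2, 4)
print(row_major_strides(shape))     # [64, 8, 4, 1]
print(offset(shape, (1, 2, 1, 3)))  # 64 + 16 + 4 + 3 = 87
```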
blocks_type
comptime blocks_type = LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin]
device_type
comptime device_type = ContinuousBatchingKVCache[dtype_, kv_params_]
dtype
comptime dtype = dtype_
kv_params
comptime kv_params = kv_params_
page_size_
comptime page_size_ = 0
Methods
__init__
__init__(blocks: LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self
get_type_name
static get_type_name() -> String
Returns:
String
get_device_type_name
static get_device_type_name() -> String
Returns:
String
max_tile_size
cache_lengths_nd
cache_lengths_nd(self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]
Returns:
cache_length
load
load[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, width]
Returns:
store
store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, size])
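As a rough model of load and store, the sketch below keeps the blocks in a flat Python list and resolves a batch entry to its block through lookup_table before indexing. ToyKVStore is hypothetical; the storage order [block, token, head, head_dim] and the reading of width as a count of consecutive head-dimension elements are assumptions, not something this page confirms.

```python
from typing import List


class ToyKVStore:
    """Hypothetical flat-buffer model of load/store (not the MAX implementation).

    Assumes blocks are row-major [num_blocks, max_seq_len, num_heads, head_size]
    and that a batch entry's block is found through lookup_table.
    """

    def __init__(self, num_blocks, max_seq_len, num_heads, head_size, lookup_table):
        self.shape = (num_blocks, max_seq_len, num_heads, head_size)
        self.data = [0.0] * (num_blocks * max_seq_len * num_heads * head_size)
        self.lookup_table = list(lookup_table)

    def _offset(self, bs, head_idx, tok_idx, head_dim_idx):
        _, seq, heads, hd = self.shape
        block = self.lookup_table[bs]
        # In this model the argument order (head before token) differs from storage order.
        return ((block * seq + tok_idx) * heads + head_idx) * hd + head_dim_idx

    def load(self, width, bs, head_idx, tok_idx, head_dim_idx) -> List[float]:
        # Like load[width]: `width` consecutive elements along the head dimension.
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        return self.data[start:start + width]

    def store(self, bs, head_idx, tok_idx, head_dim_idx, val: List[float]):
        # Like store(...): write a vector of values starting at head_dim_idx.
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        self.data[start:start + len(val)] = val


kv = ToyKVStore(num_blocks=3, max_seq_len=4, num_heads=2, head_size=4, lookup_table=[2, 0])
kv.store(bs=0, head_idx=1, tok_idx=3, head_dim_idx=0, val=[1.0, 2.0, 3.0, 4.0])
print(kv.load(2, bs=0, head_idx=1, tok_idx=3, head_dim_idx=2))  # [3.0, 4.0]
```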
empty_cache
empty_cache(self) -> Bool
Returns true if the cache lengths of all requests are 0, false otherwise.
Returns:
max_prompt_length
max_prompt_length(self) -> UInt32
Returns the maximum sequence length across all batch entries of the current request.
Returns:
max_context_length
max_context_length(self) -> UInt32
Returns the maximum cache length used across all batch entries of the current request.
Returns:
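The three request-level queries reduce to simple checks over the batch. In this hypothetical model (ToyRequestSummary is not a MAX type), empty_cache follows its description literally, while max_prompt_length and max_context_length just return stored fields. How the caller fills max_seq_length and max_cache_length is not specified on this page, so the maxima computed below are an illustrative assumption.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyRequestSummary:
    """Hypothetical model of the request-level queries (not the MAX implementation)."""
    cache_lengths: List[int]  # tokens already in the cache, per batch entry
    max_seq_length: int       # stored field returned by max_prompt_length()
    max_cache_length: int     # stored field returned by max_context_length()

    def empty_cache(self) -> bool:
        # True only when no batch entry has anything cached yet.
        return all(n == 0 for n in self.cache_lengths)

    def max_prompt_length(self) -> int:
        return self.max_seq_length

    def max_context_length(self) -> int:
        return self.max_cache_length


# Assumption: the caller stores maxima over the batch in the max_* fields.
new_tokens = [5, 1, 1]     # tokens per batch entry in this step (illustrative)
cache_lengths = [0, 12, 7]
summary = ToyRequestSummary(
    cache_lengths=cache_lengths,
    max_seq_length=max(new_tokens),
    max_cache_length=max(cache_lengths),
)
print(summary.empty_cache(), summary.max_prompt_length(), summary.max_context_length())
# False 5 12
```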
row_idx
row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32
Returns the row index when viewing the memory as a matrix.
Returns:
col_idx
col_idx(self, head_idx: UInt32) -> UInt32
Returns the column index when viewing the memory as a matrix.
Returns:
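row_idx and col_idx describe a 2-D view of the blocks memory. A minimal sketch, assuming rows enumerate (block, token) pairs with max_seq_len rows per block and columns enumerate (head, head_dim) pairs with head_size columns per head; the free-function signatures here are hypothetical. The last line checks that the 2-D coordinates land on the same element as the assumed 4-D row-major offset.

```python
def row_idx(lookup_table, max_seq_len, batch_idx, tok_idx):
    # Hypothetical: a batch entry's tokens occupy max_seq_len consecutive rows
    # starting at the first row of the block it owns.
    return lookup_table[batch_idx] * max_seq_len + tok_idx


def col_idx(head_size, head_idx):
    # Hypothetical: each head spans head_size consecutive columns.
    return head_idx * head_size


num_heads, head_size, max_seq_len = 2, 4, 8
lookup_table = [2, 0]
r = row_idx(lookup_table, max_seq_len, batch_idx=0, tok_idx=3)  # 2 * 8 + 3 = 19
c = col_idx(head_size, head_idx=1)                              # 1 * 4 = 4
# Flat offset: row * (num_heads * head_size) + col + head_dim_idx (here head_dim_idx = 2).
print(r, c, r * num_heads * head_size + c + 2)  # 19 4 158
```

Viewing the cache as a matrix with num_heads * head_size columns lets matmul-style kernels address a token's keys or values with plain (row, column) coordinates.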
create_tma_tile
create_tma_tile[tile_m: Int, tile_n: Int, swizzle_mode: TensorMapSwizzle, *, is_k_major: Bool](self, ctx: DeviceContext) -> TMATensorTile[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_layout_k_major[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_m, tile_n, swizzle_mode]() if is_k_major else tile_layout_mn_major[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, tile_n, tile_m, swizzle_mode](), _tma_desc_tile_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, 2, IndexList[2, DType.int64](tile_m, tile_n, Tuple[]()), is_k_major, swizzle_mode](), is_k_major]
Creates a TMA tile for this KV cache.
Returns:
TMATensorTile
block_paged_ptr
block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> LegacyUnsafePointer[Scalar[ContinuousBatchingKVCache[dtype_, kv_params_].dtype]]
Returns:
LegacyUnsafePointer
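block_paged_ptr appears to return a pointer to the start of a run of tokens for one batch entry and head. The sketch below models that pointer as a flat element offset into the same assumed [num_blocks, max_seq_len, num_heads, head_size] buffer. block_paged_offset is a hypothetical helper, and it ignores tile_size on the assumption that, with page_size_ = 0, a batch entry's tokens are contiguous within a single block, so a tile never crosses a page boundary.

```python
def block_paged_offset(lookup_table, max_seq_len, num_heads, head_size,
                       batch_idx, start_tok_idx, head_idx, head_dim_idx=0):
    """Hypothetical model: flat element offset that block_paged_ptr would point to."""
    row = lookup_table[batch_idx] * max_seq_len + start_tok_idx
    return row * (num_heads * head_size) + head_idx * head_size + head_dim_idx


num_heads, head_size, max_seq_len = 2, 4, 8
lookup_table = [2, 0]
base = block_paged_offset(lookup_table, max_seq_len, num_heads, head_size,
                          batch_idx=1, start_tok_idx=5, head_idx=0)
token_stride = num_heads * head_size  # distance between consecutive tokens of one head
print(base, [base + t * token_stride for t in range(3)])  # 40 [40, 48, 56]
```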