
Mojo struct

PagedKVCacheCollection

struct PagedKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, scale_dtype_: DType = DType.invalid, quantization_granularity_: Int = 1]

Fields

  • scales (OptionalReg[TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout[*?, *?], MutAnyOrigin]])
  • kv_cache_scales_dynamic_shape (IndexList[4])
  • kv_cache_scales_dynamic_strides (IndexList[4])
  • blocks (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_tt_type)
  • cache_lengths (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.cache_lengths_tt_type)
  • lookup_table (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.lookup_table_tt_type)
  • max_seq_length (UInt32)
  • max_cache_length (UInt32)
  • kv_cache_dynamic_shape (IndexList[4])
  • kv_cache_dynamic_strides (IndexList[4])

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, KVCollectionT, Movable

comptime members

blocks_layout

comptime blocks_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_shape)

blocks_shape

comptime blocks_shape = IntTuple(-1, 2 if not PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.is_mla.__bool__() else 1, -1, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params)
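The shape above has six dimensions, and `blocks_layout` arranges them in row-major order. The following Python sketch (a toy model, not the MAX API) shows how row-major strides fall out of such a shape. It assumes that the two `-1` entries are dynamic dimensions (here given the hypothetical names `num_blocks` and `num_layers`) and that the last two entries are the head count and head size from `kv_params`; the page itself only shows `kv_params` for those positions, so treat those names as assumptions.

```python
def row_major_strides(shape):
    """Return element strides for a row-major (C-order) layout of `shape`."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension back to the first.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def toy_blocks_shape(num_blocks, num_layers, page_size, num_heads, head_size,
                     is_mla=False):
    """Build a concrete 6-D shape mirroring the pattern in blocks_shape.

    The second dimension is 2 unless is_mla is set, in which case it is 1,
    matching the `2 if not is_mla else 1` expression. All argument names
    are illustrative assumptions.
    """
    kv_dim = 1 if is_mla else 2
    return (num_blocks, kv_dim, num_layers, page_size, num_heads, head_size)


shape = toy_blocks_shape(num_blocks=4, num_layers=2, page_size=16,
                         num_heads=8, head_size=64)
strides = row_major_strides(shape)
print(shape)
print(strides)
```

In a row-major layout the last dimension is contiguous, so under these assumptions all elements of one head for one token sit next to each other in memory.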

blocks_tt_layout

comptime blocks_tt_layout = Layout[*?, *?]

blocks_tt_type

comptime blocks_tt_type = TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, Layout[*?, *?], MutAnyOrigin]

CacheType

comptime CacheType = PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, quantization_granularity_]

dtype

comptime dtype = dtype_

head_dim_granularity

comptime head_dim_granularity = ceildiv(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.quantization_granularity)
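As a worked example of the expression above, the Python sketch below reproduces the ceiling division of the head size by the quantization granularity. The helper names are illustrative and the input values are hypothetical; only the formula comes from this page.

```python
def ceildiv(numerator, denominator):
    """Integer division that rounds up instead of down."""
    return -(-numerator // denominator)


def head_dim_granularity(head_size, quantization_granularity):
    """Number of granularity-sized groups needed to cover one head dimension."""
    return ceildiv(head_size, quantization_granularity)


# With the default granularity of 1, the result equals head_size.
print(head_dim_granularity(128, 1))
# A head of 128 elements split into groups of 32 elements.
print(head_dim_granularity(128, 32))
# A head size that is not a multiple of the granularity rounds up.
print(head_dim_granularity(80, 32))
```

Rounding up means a head size that is not a multiple of the granularity still gets a final, partially filled group.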

kv_params

comptime kv_params = kv_params_

name_str

comptime name_str = "paged"

scale_dtype

comptime scale_dtype = scale_dtype_

scales_layout

comptime scales_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_shape)

scales_shape

comptime scales_shape = IntTuple(-1, 2 if not PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.is_mla.__bool__() else 1, -1, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].head_dim_granularity)

scales_tt_layout

comptime scales_tt_layout = Layout[*?, *?]

scales_tt_type

comptime scales_tt_type = TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout[*?, *?], MutAnyOrigin]

Methods

__init__

def __init__[scales_origin: MutOrigin, //](out self, blocks: LayoutTensor[Self.dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[Self.scale_dtype, Layout.row_major[6](), scales_origin]] = OptionalReg())

Construct from LayoutTensor params (MOGG boundary).

def __init__(out self, blocks: TileTensor[Self.dtype, Layout[*?, *?], MutAnyOrigin], cache_lengths: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], lookup_table: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[TileTensor[Self.scale_dtype, Layout[*?, *?], MutAnyOrigin]] = None)

Construct from TileTensor fields directly.

get_key_cache

def get_key_cache(self, layer_idx: Int) -> Self.CacheType

Returns:

Self.CacheType

get_value_cache

def get_value_cache(self, layer_idx: Int) -> Self.CacheType

Returns:

Self.CacheType

cache_length

def cache_length(self, bs_idx: Int) -> Int

Returns:

Int
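The page does not describe how these methods behave, so the following Python sketch is only a toy model of a paged KV cache, not the MAX implementation. It assumes that `cache_length` reads the per-sequence entry of `cache_lengths`, that `lookup_table` maps (sequence, logical page) to a physical block, and that keys and values occupy indices 0 and 1 of the second `blocks` dimension. The class name and its `load` helper are hypothetical.

```python
class ToyPagedKVCollection:
    """Toy stand-in for a paged KV cache collection (assumed semantics)."""

    KEY, VALUE = 0, 1  # Assumed positions in the second blocks dimension.

    def __init__(self, blocks, cache_lengths, lookup_table, page_size):
        # blocks[block][kv][layer][slot] holds the payload for one token.
        self.blocks = blocks
        self.cache_lengths = cache_lengths
        self.lookup_table = lookup_table
        self.page_size = page_size

    def cache_length(self, bs_idx):
        """Number of tokens already cached for one sequence in the batch."""
        return int(self.cache_lengths[bs_idx])

    def locate(self, bs_idx, tok_idx):
        """Translate a token position into (physical block, slot in page)."""
        page, slot = divmod(tok_idx, self.page_size)
        return self.lookup_table[bs_idx][page], slot

    def load(self, kv_idx, layer_idx, bs_idx, tok_idx):
        """Fetch the key or value payload for one token of one layer."""
        block, slot = self.locate(bs_idx, tok_idx)
        return self.blocks[block][kv_idx][layer_idx][slot]


page_size, num_layers, num_blocks = 2, 1, 3
blocks = [
    [[[f"{name}-b{b}-l{l}-s{s}" for s in range(page_size)]
      for l in range(num_layers)]
     for name in ("k", "v")]
    for b in range(num_blocks)
]
toy = ToyPagedKVCollection(
    blocks=blocks,
    cache_lengths=[3, 1],
    lookup_table=[[2, 0], [1, 1]],  # Sequence 0 uses block 2, then block 0.
    page_size=page_size,
)
print(toy.cache_length(0))
print(toy.load(ToyPagedKVCollection.KEY, 0, bs_idx=0, tok_idx=2))
print(toy.load(ToyPagedKVCollection.VALUE, 0, bs_idx=0, tok_idx=1))
```

The point of the indirection is that a sequence's pages need not be contiguous or ordered in the physical block pool; only the lookup table records where each logical page lives.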