Skip to main content

Mojo struct

PagedKVCacheCollection

struct PagedKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, scale_dtype_: DType = DType.invalid, quantization_granularity_: Int = 1]

Fields

  • scales (OptionalReg[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_tt_type]):
  • kv_cache_scales_dynamic_shape (IndexList[4]):
  • kv_cache_scales_dynamic_strides (IndexList[4]):
  • blocks (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_tt_type):
  • cache_lengths (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.cache_lengths_tt_type):
  • lookup_table (PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.lookup_table_tt_type):
  • max_seq_length (UInt32):
  • max_cache_length (UInt32):
  • kv_cache_dynamic_shape (IndexList[4]):
  • kv_cache_dynamic_strides (IndexList[4]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, KVCollectionT, Movable

comptime members

blocks_layout

comptime blocks_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_shape)

blocks_shape

comptime blocks_shape = IntTuple(-1, 2 if not PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.is_mla.__bool__() else 1, -1, page_size, Int[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.num_heads), Int[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size))

blocks_tt_layout

comptime blocks_tt_layout = Layout[#kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_layout.shape), [idx: __mlir_type.index] _int_to_dim(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_layout.shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])), #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_layout.stride), [idx: __mlir_type.index] _int_to_dim(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_layout.stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))]

blocks_tt_type

comptime blocks_tt_type = TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_tt_layout, MutAnyOrigin]

CacheType

comptime CacheType = PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, quantization_granularity_]

dtype

comptime dtype = dtype_

head_dim_granularity

comptime head_dim_granularity = ceildiv(Int[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size), PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType.quantization_granularity)

kv_params

comptime kv_params = kv_params_

name_str

comptime name_str = "paged"

scale_dtype

comptime scale_dtype = scale_dtype_

scales_layout

comptime scales_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_shape)

scales_shape

comptime scales_shape = IntTuple(-1, 2 if not PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.is_mla.__bool__() else 1, -1, page_size, Int[UInt](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.num_heads), PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].head_dim_granularity)

scales_tt_layout

comptime scales_tt_layout = Layout[#kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_layout.shape), [idx: __mlir_type.index] _int_to_dim(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_layout.shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])), #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_layout.stride), [idx: __mlir_type.index] _int_to_dim(PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_layout.stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))]

scales_tt_type

comptime scales_tt_type = TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_tt_layout, MutAnyOrigin]

Methods

__init__

__init__(out self, blocks: LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout.row_major[6](), MutAnyOrigin]] = None)

Construct from LayoutTensor params (MOGG boundary).

__init__(out self, blocks: TileTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_tt_layout, MutAnyOrigin], cache_lengths: TileTensor[DType.uint32, PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, quantization_granularity_].cache_lengths_tt_layout, ImmutAnyOrigin], lookup_table: TileTensor[DType.uint32, PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, page_size, PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, quantization_granularity_].lookup_table_tt_layout, ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scales_tt_type] = None)

Construct from TileTensor fields directly.

get_key_cache

get_key_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType

Returns:

PagedKVCache (the collection's CacheType)

get_value_cache

get_value_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].CacheType

Returns:

PagedKVCache (the collection's CacheType)

cache_length

cache_length(self, bs_idx: Int) -> Int

Returns:

Int

Was this page helpful?