
Mojo trait

KVCacheT

Trait for different KVCache types and implementations.

Represents a single (key or value) cache.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial

A flag (often compiler generated) to indicate whether the implementation of __copyinit__ is trivial.

The implementation of __copyinit__ is considered to be trivial if:

  • The struct has a compiler-generated trivial __copyinit__ and all its fields have a trivial __copyinit__ method.

In practice, it means the value can be copied by copying the bits from one location to another without side effects.

__del__is_trivial

comptime __del__is_trivial

A flag (often compiler generated) to indicate whether the implementation of __del__ is trivial.

The implementation of __del__ is considered to be trivial if:

  • The struct has a compiler-generated trivial destructor and all its fields have a trivial __del__ method.

In practice, it means that __del__ can be treated as a no-op.

__moveinit__is_trivial

comptime __moveinit__is_trivial

A flag (often compiler generated) to indicate whether the implementation of __moveinit__ is trivial.

The implementation of __moveinit__ is considered to be trivial if:

  • The struct has a compiler-generated __moveinit__ and all its fields have a trivial __moveinit__ method.

In practice, it means the value can be moved by moving the bits from one location to another without side effects.

device_type

comptime device_type

Indicates the type used on accelerator devices.

dtype

comptime dtype

kv_params

comptime kv_params

page_size_

comptime page_size_

quantization_enabled

comptime quantization_enabled = False

scale_dtype

comptime scale_dtype = DType.invalid

Required methods

__copyinit__

__copyinit__(out self: _Self, existing: _Self, /)

Create a new instance of the value by copying an existing one.

Args:

  • existing (_Self): The value to copy.

Returns:

_Self

__moveinit__

__moveinit__(out self: _Self, deinit existing: _Self, /)

Create a new instance of the value by moving the value of another.

Args:

  • existing (_Self): The value to move.

Returns:

_Self

cache_lengths_nd

cache_lengths_nd(self: _Self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]

Returns the cache lengths as a LayoutTensor.

Returns:

LayoutTensor

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

load

load[width: Int, output_dtype: DType = _Self.dtype](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads width elements starting at the given index and returns them as a SIMD vector of output_dtype.

Returns:

SIMD

store

store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype, size])

Stores an element at the given index.
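The following plain-Python model sketches what load and store index into. It is not Mojo and not the MAX implementation. The class ToyKVCache, its flat row-major [batch, max_seq_len, num_heads, head_dim] layout, and the bounds checks are assumptions made for illustration. A Python list slice stands in for SIMD[dtype, width], and the output_dtype conversion is not modeled.

```python
# Plain-Python model of KVCacheT.load/store semantics.
# ASSUMPTION: a hypothetical contiguous cache laid out row-major as
# [batch, max_seq_len, num_heads, head_dim]; real implementations may differ.


class ToyKVCache:
    def __init__(self, batch, max_seq_len, num_heads, head_dim):
        self.shape = (batch, max_seq_len, num_heads, head_dim)
        self.data = [0.0] * (batch * max_seq_len * num_heads * head_dim)

    def _offset(self, bs, head_idx, tok_idx, head_dim_idx):
        # Flat offset of element [bs, tok_idx, head_idx, head_dim_idx].
        _, seq, heads, dim = self.shape
        return ((bs * seq + tok_idx) * heads + head_idx) * dim + head_dim_idx

    def load(self, width, bs, head_idx, tok_idx, head_dim_idx):
        # Returns `width` consecutive elements along head_dim, like SIMD[dtype, width].
        if head_dim_idx + width > self.shape[3]:
            raise IndexError("vector load would run past head_dim")
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        return self.data[start:start + width]

    def store(self, bs, head_idx, tok_idx, head_dim_idx, val):
        # `val` plays the role of SIMD[dtype, size]; its length is the vector size.
        if head_dim_idx + len(val) > self.shape[3]:
            raise IndexError("vector store would run past head_dim")
        start = self._offset(bs, head_idx, tok_idx, head_dim_idx)
        self.data[start:start + len(val)] = list(val)


cache = ToyKVCache(batch=2, max_seq_len=4, num_heads=2, head_dim=8)
cache.store(1, 0, 3, 4, [1.0, 2.0, 3.0, 4.0])
print(cache.load(4, bs=1, head_idx=0, tok_idx=3, head_dim_idx=4))  # [1.0, 2.0, 3.0, 4.0]
```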

store_scale

store_scale(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[_Self.scale_dtype, size])

Stores the quantization scales at the given index.

load_scale

load_scale[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Loads the quantization scales from the given index.

Returns:

SIMD

load_quantized

load_quantized[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD
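Here is a plain-Python sketch of how a load_quantized read and a load_scale read at the same index might be combined to recover real values. The scheme is a common pattern chosen for illustration, not necessarily what any KVCacheT implementation uses: symmetric int8 codes, one scale per GROUP consecutive head_dim elements, and dequantization as quantized * scale. The names quantize, dequantize, and GROUP are hypothetical.

```python
# ASSUMED scheme for illustration only (not taken from any KVCacheT implementation):
# symmetric int8 codes, one scale per GROUP consecutive head_dim elements,
# and dequantization as a plain multiply: real = quantized * scale.

GROUP = 4


def quantize(values, group=GROUP):
    """Return (int8-range codes, per-group scales) for a single head_dim row."""
    quantized, scales = [], []
    for start in range(0, len(values), group):
        chunk = values[start:start + group]
        amax = max(abs(v) for v in chunk)
        scale = amax / 127 if amax > 0 else 1.0
        scales.append(scale)
        # Round to the nearest code and clamp to the symmetric int8 range.
        quantized.extend(max(-127, min(127, round(v / scale))) for v in chunk)
    return quantized, scales


def dequantize(quantized, scales, head_dim_idx, width, group=GROUP):
    """Pair a load_quantized-style read with the matching load_scale-style read."""
    return [quantized[i] * scales[i // group]
            for i in range(head_dim_idx, head_dim_idx + width)]


row = [0.5, -1.0, 0.25, 1.0, 2.0, -2.0, 0.0, 1.0]
q, s = quantize(row)
print(q)                       # int8-range codes
print(dequantize(q, s, 4, 4))  # close to row[4:8]
```

Each dequantized value is within half a quantization step (scale / 2) of the original value.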

empty_cache

empty_cache(self: _Self) -> Bool

Returns True if the cache length of every request is 0, and False otherwise.

Returns:

Bool

max_prompt_length

max_prompt_length(self: _Self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32
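Taken literally, empty_cache, max_prompt_length, and max_context_length are reductions over per-request values. The plain-Python sketch below assumes two hypothetical per-request lists. prompt_lengths holds the tokens in the current request step, and cache_lengths holds the tokens already in the cache. The sketch reads "maximum cache length used" as the maximum of cache_lengths. An implementation may define these quantities differently.

```python
# Plain-Python model; prompt_lengths and cache_lengths are hypothetical inputs.


def empty_cache(cache_lengths):
    # True only if no request has any tokens cached yet.
    return all(length == 0 for length in cache_lengths)


def max_prompt_length(prompt_lengths):
    # Longest new sequence among the requests in the batch.
    return max(prompt_lengths)


def max_context_length(cache_lengths):
    # Largest cache length among the requests in the batch.
    return max(cache_lengths)


# First step: nothing is cached yet.
prefill_cache = [0, 0, 0]
prefill_prompts = [5, 12, 3]

# Later step: one new token per request, earlier tokens are cached.
decode_cache = [5, 12, 3]
decode_prompts = [1, 1, 1]

print(empty_cache(prefill_cache), max_prompt_length(prefill_prompts))  # True 12
print(empty_cache(decode_cache), max_context_length(decode_cache))     # False 12
```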

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]

Returns a pointer to the KVCache block at the given index.

Paged KVCache implementations must have a block_size that is a multiple of, and greater than, the layout's first dimension.

Returns:

UnsafePointer
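Here is one way to read the block_size constraint, sketched in plain Python. Suppose a paged cache stores each page separately. Then a pointer from block_paged_ptr can address tile_size contiguous tokens only when the tile does not straddle a page boundary. Assume tiles start at multiples of tile_size, and treat tile_size as the layout's first dimension and the page size as block_size. Under those assumptions, every tile stays within one page whenever the page size is a multiple of tile_size. The helper names are hypothetical.

```python
# Plain-Python check; page_size stands in for block_size and tile_size for the
# layout's first dimension (both assumptions). Tiles start at multiples of tile_size.


def tile_in_one_page(page_size, tile_size, start_tok_idx):
    """True if tokens [start, start + tile_size) all fall in the same page."""
    first_page = start_tok_idx // page_size
    last_page = (start_tok_idx + tile_size - 1) // page_size
    return first_page == last_page


def all_tiles_in_one_page(page_size, tile_size, num_tokens):
    """Check every tile-aligned start position below num_tokens."""
    return all(tile_in_one_page(page_size, tile_size, start)
               for start in range(0, num_tokens, tile_size))


print(all_tiles_in_one_page(128, 64, 1024))  # True: 128 is a multiple of 64
print(all_tiles_in_one_page(96, 64, 1024))   # False: tokens 64..127 cross into page 1
```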

scales_block_paged_ptr

scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer

max_tile_size

static max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when the memory is viewed as a matrix.

Returns:

UInt32
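Here is a plain-Python sketch of what "the row index when the memory is viewed as a matrix" can mean, with one token per row. Both layouts are assumptions for illustration. The first is a contiguous cache with max_seq_len rows per batch entry. The second is a paged cache whose hypothetical lookup_table maps each request's logical pages to physical pages, in the style of paged attention.

```python
# Hypothetical layouts for illustration; one token per row in both.


def contiguous_row_idx(max_seq_len, batch_idx, start_tok_idx):
    # Each batch entry owns max_seq_len consecutive rows.
    return batch_idx * max_seq_len + start_tok_idx


def paged_row_idx(lookup_table, page_size, batch_idx, start_tok_idx):
    # lookup_table[b][p] is the physical page holding logical page p of request b.
    logical_page, offset = divmod(start_tok_idx, page_size)
    physical_page = lookup_table[batch_idx][logical_page]
    return physical_page * page_size + offset


lookup_table = [[3, 0], [1, 2]]  # request 0 uses pages 3 then 0; request 1 uses 1 then 2
print(contiguous_row_idx(8, 1, 3))           # 11
print(paged_row_idx(lookup_table, 4, 0, 5))  # 1: logical page 1 -> physical page 0, offset 1
print(paged_row_idx(lookup_table, 4, 1, 2))  # 6: physical page 1, offset 2
```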

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int.__init__[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, _split_last_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]

Creates a TMA tile for this KV cache. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int.__init__[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]

Creates a TMA tile for this KV cache. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile
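The reason extra rows must be masked, rather than merely given zero weight, comes from IEEE 754 arithmetic: 0 × NaN is NaN. The plain-Python sketch below assumes that rows past the valid length may hold arbitrary data, including NaN. It models the MMA reduction as a weighted sum over rows. weighted_sum and valid_rows are illustrative names, not part of any API.

```python
import math


def weighted_sum(weights, rows):
    # Reduction over rows, like the reduction dimension of an MMA:
    # out[j] = sum_i weights[i] * rows[i][j]
    cols = len(rows[0])
    return [sum(w * row[j] for w, row in zip(weights, rows)) for j in range(cols)]


valid_rows = 2
rows = [[1.0, 2.0], [3.0, 4.0], [math.nan, math.nan]]  # third row: out-of-bounds garbage
weights = [0.5, 0.5, 0.0]                              # zero weight for the padded row

# Zero weight alone does not help, because 0.0 * nan is nan.
unmasked = weighted_sum(weights, rows)

# Masking the data itself (zeroing rows past valid_rows) keeps the output finite.
masked_rows = [row if i < valid_rows else [0.0] * len(row) for i, row in enumerate(rows)]
masked = weighted_sum(weights, masked_rows)

print(unmasked)  # [nan, nan]
print(masked)    # [2.0, 3.0]
```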

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

get_device_type_name

static get_device_type_name() -> String

Gets device_type's name. For example, because DeviceBuffer's device_type is UnsafePointer, DeviceBuffer[DType.float32]'s get_device_type_name() should return something like "UnsafePointer[Scalar[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The device type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self.

Returns:

_Self: A copy of this value.
