Mojo trait

KVCacheT

Trait for different KVCache types and implementations.

Represents a single (key or value) cache.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type

Indicates the type used on accelerator devices.

dtype

comptime dtype

The data type of the cache elements.

kv_params

comptime kv_params

page_size_

comptime page_size_

quantization_enabled

comptime quantization_enabled = False

Whether quantization is enabled for this cache type. Defaults to False.

quantization_granularity

comptime quantization_granularity = 1

The number of elements that share a single quantization scale. Defaults to 1.

scale_dtype

comptime scale_dtype = DType.invalid

The data type of the quantization scales. Defaults to DType.invalid when quantization is disabled.

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • take (_Self): The value to move.

Returns:

_Self

cache_lengths_nd

cache_lengths_nd(self: _Self) -> TileTensor[DType.uint32, Layout[RuntimeInt[DType.int64], ComptimeInt[1]], ImmutAnyOrigin]

Returns the cache lengths as a TileTensor.

Returns:

TileTensor

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

load

load[width: Int, output_dtype: DType = _Self.dtype](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads an element from the given index.

Returns:

SIMD

store

store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype, val.size])

Stores an element at the given index.

store_scale

store_scale(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[_Self.scale_dtype, scales.size])

Stores the quantization scales at the given index.

load_scale

load_scale[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Loads the quantization scales from the given index.

Returns:

SIMD

load_quantized

load_quantized[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD
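The trait does not specify how callers combine load_quantized and load_scale, but a typical per-group scheme multiplies each quantized value by the scale for its group, where group membership is determined by quantization_granularity. A hypothetical Python sketch (the dequantize helper and its inputs are illustrative, not part of the trait):

```python
def dequantize(quantized, scales, granularity):
    # Hypothetical: each group of `granularity` consecutive elements
    # shares one scale, as suggested by quantization_granularity.
    return [q * scales[i // granularity] for i, q in enumerate(quantized)]

# Four quantized values, two groups of two, one scale per group.
print(dequantize([2, 4, 6, 8], [0.5, 2.0], granularity=2))  # [1.0, 2.0, 12.0, 16.0]
```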

empty_cache

empty_cache(self: _Self) -> Bool

Returns true if the cache lengths for all requests are 0, false otherwise.

Returns:

Bool
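The semantics above reduce to a check over the per-request cache lengths; a minimal Python sketch, assuming a hypothetical list of per-request lengths like the one returned by cache_lengths_nd:

```python
def empty_cache(cache_lengths):
    # True only when every request in the batch has an empty cache.
    return all(length == 0 for length in cache_lengths)

print(empty_cache([0, 0, 0]))  # True
print(empty_cache([0, 5, 0]))  # False
```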

max_prompt_length

max_prompt_length(self: _Self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]

Returns a pointer to the KVCache block at the given index.

Paged KVCache implementations must have a block_size that is a multiple of, and greater than, the layout's first dimension.

Returns:

UnsafePointer

scales_block_paged_ptr

scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer

scales_raw_ptr

scales_raw_ptr(self: _Self) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns the base pointer to the scales tensor.

For PagedKVCache with quantization enabled, this returns the raw base pointer of the scales TileTensor. For caches without quantization, returns a null pointer.

Returns:

UnsafePointer

max_tile_size

static max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

num_kv_rows

num_kv_rows(self: _Self) -> Int

Returns the total number of virtual rows in this KV cache view.

For paged caches this accounts for the paging stride: (total_blocks - 1) * stride + page_size.

Returns:

Int
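The paged-cache formula above can be sketched directly in Python (total_blocks, stride, and page_size are illustrative parameter names taken from the formula, not trait members):

```python
def num_kv_rows(total_blocks, stride, page_size):
    # Virtual row count for a paged cache: blocks are spaced `stride`
    # rows apart, and only the final block contributes `page_size` rows.
    return (total_blocks - 1) * stride + page_size

# When the stride equals the page size, the view is dense:
print(num_kv_rows(total_blocks=4, stride=16, page_size=16))  # 64
```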

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the cache memory as a flat 2D matrix.

Returns:

UInt32

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode]()]

Creates a TMA tile for this KV cache. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, Int[UInt](_Self.kv_params.head_size)]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BN, BK]

Creates a TMA tile for this KV cache. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile

create_rope_tma_tile

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode]()]

Creates a BF16 TMA tile for the rope portion of the KV cache.

For the per-tensor rope-aware layout, each token row in the KV cache is stored as padded_depth FP8 bytes (content) followed by BK BF16 elements (rope). This method returns a TMA descriptor that points at the rope data starting at byte offset padded_depth within each row, reinterpreted as BF16.

Returns:

TMATensorTile
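The per-tensor rope-aware row layout described above can be modeled with simple byte arithmetic, assuming 1 byte per FP8 element and 2 bytes per BF16 element (the helper names here are illustrative, not part of the trait):

```python
def rope_byte_offset(padded_depth):
    # Each token row stores `padded_depth` FP8 content bytes first;
    # the BF16 rope elements begin immediately after them.
    return padded_depth  # FP8 elements are 1 byte each

def row_size_bytes(padded_depth, bk):
    # Content bytes plus BK rope elements at 2 bytes per BF16 element.
    return padded_depth + 2 * bk

print(rope_byte_offset(128))    # 128
print(row_size_bytes(128, 64))  # 256
```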

create_gather4_tma_tile

create_gather4_tma_tile[row_width: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 2, IndexList(4, row_width, __list_literal__=Tuple()), IndexList(1, row_width, __list_literal__=Tuple())]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, row_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction.

Parameters:

  • row_width (Int): Number of elements per row (innermost dimension).
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile: A TMATensorTile with tile_shape=(4, row_width) and desc_shape=(1, row_width).
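A gather4 operation over the flat [num_kv_rows, row_width] view can be modeled in Python: each instruction pulls 4 non-contiguous rows, addressed by row index, out of the flat cache (the gather4 helper is an illustrative model, not a trait method):

```python
def gather4(flat_cache, row_width, row_indices):
    # Model one TMA gather4 instruction: fetch 4 non-contiguous rows
    # from a cache viewed as a flat [num_kv_rows, row_width] matrix.
    assert len(row_indices) == 4
    return [flat_cache[r * row_width:(r + 1) * row_width] for r in row_indices]

cache = list(range(32))  # 8 rows of 4 elements each
print(gather4(cache, 4, [0, 2, 5, 7]))
```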

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", and DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly constructs a copy of self; a convenience method for Self(copy=self) when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.
