Mojo trait

KVCacheT

Trait for different KVCache types and implementations.

Represents a single (key or value) cache.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type

Indicates the type used on accelerator devices.

dtype

comptime dtype

kv_params

comptime kv_params

page_size_

comptime page_size_

quantization_enabled

comptime quantization_enabled = False

quantization_granularity

comptime quantization_granularity = 1

scale_dtype

comptime scale_dtype = DType.invalid

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • take (_Self): The value to move.

Returns:

_Self

cache_lengths_nd

cache_lengths_nd(self: _Self) -> TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]

Returns the cache lengths as a TileTensor.

Returns:

TileTensor

cache_length

cache_length(self: _Self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

load

load[width: Int, output_dtype: DType = _Self.dtype](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads an element from the given index.

Returns:

SIMD

store

store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype, val.size])

Stores an element at the given index.

store_scale

store_scale(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[_Self.scale_dtype, scales.size])

Stores the quantization scales at the given index.

load_scale

load_scale[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]

Loads the quantization scales from the given index.

Returns:

SIMD

load_quantized

load_quantized[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD

empty_cache

empty_cache(self: _Self) -> Bool

Returns true if the cache length of every request is 0, false otherwise.

Returns:

Bool
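The check described above can be modeled in plain Python. This is only a sketch of the documented behavior, not the Mojo implementation: the `cache_lengths` list stands in for the values returned by cache_lengths_nd, and the function name is reused purely for illustration.

```python
def empty_cache(cache_lengths):
    """Model of the documented check: True only if every length is 0.

    `cache_lengths` is a hypothetical list of per-request cache lengths,
    standing in for the tensor returned by cache_lengths_nd.
    """
    return all(length == 0 for length in cache_lengths)


# Every request starts from an empty cache.
print(empty_cache([0, 0, 0]))
# One request already holds 3 cached tokens.
print(empty_cache([0, 3, 0]))
```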

max_prompt_length

max_prompt_length(self: _Self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length

max_context_length(self: _Self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32

block_paged_ptr

block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.dtype], MutAnyOrigin]

Returns a pointer to the KVCache block at the given index.

Paged KVCache implementations must have a block_size that is a multiple of, and greater than, the layout's first dimension.

Returns:

UnsafePointer

scales_block_paged_ptr

scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer

scales_raw_ptr

scales_raw_ptr(self: _Self) -> UnsafePointer[Scalar[_Self.scale_dtype], MutAnyOrigin]

Returns the base pointer to the scales tensor.

For PagedKVCache with quantization enabled, this returns the raw base pointer of the scales TileTensor. For caches without quantization, returns a null pointer.

Returns:

UnsafePointer

max_tile_size

static max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

num_kv_rows

num_kv_rows(self: _Self) -> Int

Returns the total number of virtual rows in this KV cache view.

For paged caches this accounts for the paging stride: (total_blocks - 1) * stride + page_size.

Returns:

Int
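The paged formula above can be checked with a small plain-Python sketch. It assumes that stride is the number of rows between the starts of consecutive blocks; the source does not define the term, so treat that reading, and the parameter names, as illustrative.

```python
def num_kv_rows_paged(total_blocks: int, stride: int, page_size: int) -> int:
    """Evaluate (total_blocks - 1) * stride + page_size.

    Assumption: `stride` is the row distance between the starts of two
    consecutive blocks, so only the last block contributes `page_size`
    rows instead of a full stride.
    """
    if total_blocks <= 0:
        raise ValueError("total_blocks must be positive")
    return (total_blocks - 1) * stride + page_size


# 4 blocks, 128 rows between block starts, 64 rows per page:
# 3 * 128 + 64 = 448 virtual rows.
print(num_kv_rows_paged(4, 128, 64))
```

When stride equals page_size, the expression reduces to total_blocks * page_size, which is the densely packed case.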

row_idx

row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when viewing the memory as a matrix.

Returns:

UInt32

get_tma_row

get_tma_row(self: _Self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

For paged caches the encoded index is physical_block * page_size + offset and this method returns physical_block * stride + offset. Non-paged caches return the encoded index unchanged.

Returns:

Int32
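The paged conversion described above can be sketched in plain Python. The sketch assumes 0 <= offset < page_size, so that the encoded index can be split back into its two parts with an integer division; the parameter names are illustrative and this is not the Mojo implementation.

```python
def get_tma_row(encoded_index: int, page_size: int, stride: int) -> int:
    """Map physical_block * page_size + offset to physical_block * stride + offset.

    Assumption: the offset is always smaller than page_size, so divmod
    recovers the block number and the offset unambiguously.
    """
    physical_block, offset = divmod(encoded_index, page_size)
    return physical_block * stride + offset


# Block 3, offset 5, with page_size 64: encoded index is 3 * 64 + 5 = 197.
# With a stride of 128 rows the physical row is 3 * 128 + 5 = 389.
print(get_tma_row(197, page_size=64, stride=128))
```

With stride equal to page_size the function returns its input, which matches the non-paged behavior of returning the encoded index unchanged.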

create_tma_tile

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, _Self.kv_params.head_size]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode]()]

Creates a TMA tile for this KV cache. This is useful for k-major MMA operations where we don't need to mask any extra rows.

Returns:

TMATensorTile

create_ragged_tma_tile

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, _Self.kv_params.head_size]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]

Creates a TMA tile for this KV cache. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.

Returns:

RaggedTMA3DTile

create_rope_tma_tile

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode]()]

Creates a BF16 TMA tile for the rope portion of the KV cache.

For the per-tensor rope-aware layout, each token row in the KV cache is stored as padded_depth FP8 bytes (content) followed by BK BF16 elements (rope). This method returns a TMA descriptor that points at the rope data starting at byte offset padded_depth within each row, reinterpreted as BF16.

Returns:

TMATensorTile
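The row layout described above can be made concrete with a little byte arithmetic. The plain-Python sketch below assumes that rows are packed back to back with no extra padding after the rope data, which the source does not state; the function and parameter names are illustrative.

```python
FP8_BYTES = 1   # one byte per FP8 content element
BF16_BYTES = 2  # two bytes per BF16 rope element


def rope_byte_range(row: int, padded_depth: int, bk: int) -> tuple[int, int]:
    """Return the [start, end) byte range of the rope data in a given row.

    Assumption: each row occupies exactly padded_depth FP8 bytes followed
    by bk BF16 elements, with no trailing padding between rows.
    """
    row_bytes = padded_depth * FP8_BYTES + bk * BF16_BYTES
    start = row * row_bytes + padded_depth * FP8_BYTES
    return start, start + bk * BF16_BYTES


# 512 content bytes followed by 64 BF16 rope elements: 640 bytes per row.
print(rope_byte_range(0, padded_depth=512, bk=64))  # rope of row 0
print(rope_byte_range(2, padded_depth=512, bk=64))  # rope of row 2
```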

create_gather4_tma_tile

create_gather4_tma_tile[*, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tma_dtype: DType = _Self.dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=Tuple())]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

The tile_height parameter records the full tile height (e.g. 64 rows) in the returned TMATensorTile.tile_shape. The hardware descriptor shape stays (1, box_width) as required by TMA gather4.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time. This allows, for example, creating an INT64/SWIZZLE_NONE descriptor over FP8 data for linear SMEM layout.

Parameters:

  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load (box width) in tma_dtype elements.
  • tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
  • tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile: A TMATensorTile with box width derived from the swizzle mode.
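Two pieces of arithmetic follow from the description above: a tile of tile_height rows is covered by tile_height / 4 gather4 loads, and viewing the data through a wider tma_dtype shrinks the row width in elements. The plain-Python sketch below illustrates both; the helper names are hypothetical and nothing here touches the actual TMA descriptor.

```python
def gather4_groups(row_indices):
    """Split row indices into groups of 4, one group per gather4 load.

    Mirrors the documented constraint that tile_height is a multiple of 4.
    """
    if len(row_indices) % 4 != 0:
        raise ValueError("number of rows must be a multiple of 4")
    return [tuple(row_indices[i:i + 4]) for i in range(0, len(row_indices), 4)]


def reinterpreted_width(num_elements: int, src_bytes: int, dst_bytes: int) -> int:
    """Row width in destination elements after a bitcast to a wider dtype."""
    total_bytes = num_elements * src_bytes
    if total_bytes % dst_bytes != 0:
        raise ValueError("row size is not a multiple of the destination element size")
    return total_bytes // dst_bytes


# Eight non-contiguous rows need two gather4 loads.
print(gather4_groups([3, 17, 42, 8, 100, 5, 64, 9]))
# A row of 512 FP8 bytes viewed as INT64 is 64 elements wide.
print(reinterpreted_width(512, src_bytes=1, dst_bytes=8))
```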

create_rope_gather4_tma_tile

create_rope_gather4_tma_tile[*, tile_height: Int = 4, tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=Tuple()), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=Tuple())]

Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.

For the per-tensor rope-aware layout each token row is stored as padded_depth FP8 bytes (content) followed by BF16 rope elements. This method offsets the base pointer by padded_depth bytes, reinterprets as BF16, and creates a gather4 TMA descriptor with tile_width BF16 elements per row.

Parameters:

  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4.
  • tile_width (Int): Number of BF16 elements per row in global memory.
  • padded_depth (Int): Byte offset from row start to the rope data.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile: A BF16 TMATensorTile configured for gather4.

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly construct a copy of self, a convenience method for Self(copy=self) when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.
