
Mojo struct

PagedKVCache

struct PagedKVCache[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, scale_dtype_: DType = DType.invalid, quantization_granularity_: Int = 1]

PagedKVCache is a wrapper around the KV cache blocks for a given layer, providing access to those blocks for PagedAttention.

Note: This struct represents a 4D view of a 6D PagedKVCacheCollection tensor. The compile-time layout has UNKNOWN_VALUE for stride[0] because the actual stride depends on num_layers from the parent tensor, which is only known at runtime. This ensures offset calculations use the correct runtime strides rather than incorrect compile-time values.
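The paged addressing scheme behind this struct can be illustrated with a short Python sketch (a simplified model with hypothetical names; the real struct operates on TileTensors with compile-time layouts):

```python
# Minimal sketch of paged KV-cache addressing: a logical token position is
# split into a logical block index (resolved through a per-batch lookup
# table of physical block IDs) and an offset within that block.

PAGE_SIZE = 16  # corresponds to the `page_size` parameter

def resolve(lookup_table, batch_idx, tok_idx):
    """Map (batch, token) to (physical block, offset within block)."""
    logical_block = tok_idx // PAGE_SIZE
    physical_block = lookup_table[batch_idx][logical_block]
    offset = tok_idx % PAGE_SIZE
    return physical_block, offset

# Example: batch 0's tokens live in physical blocks 7, 2, 9 (in that order).
lut = [[7, 2, 9]]
print(resolve(lut, 0, 0))   # token 0  -> (7, 0)
print(resolve(lut, 0, 17))  # token 17 -> (2, 1)
```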

Parameters​

  • dtype_ (DType): The dtype of the KV cache.
  • kv_params_ (KVCacheStaticParams): The KV cache static parameters.
  • page_size (Int): The number of tokens stored per page.
  • scale_dtype_ (DType): The dtype of the quantization scales (if quantization is enabled).
  • quantization_granularity_ (Int): The block size used for quantization (e.g. 128).

Fields​

  • blocks (PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_tt_type): The tensor of KV cache blocks for this layer.
  • cache_lengths (PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].cache_lengths_tt_type): The per-request cache lengths.
  • lookup_table (PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].lookup_table_tt_type): The table mapping logical block indices to physical blocks.
  • max_seq_length (UInt32): The maximum sequence length across the current batch.
  • max_cache_length (UInt32): The maximum cache length across the current batch.
  • scales (OptionalReg[TileTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout[*?, *?], MutAnyOrigin]]): The optional quantization scales tensor.

Implemented traits​

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, KVCacheT, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

blocks_layout​

comptime blocks_layout = Layout(PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_shape, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].blocks_strides)

blocks_shape​

comptime blocks_shape = IntTuple(-1, page_size, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params)

blocks_strides​

comptime blocks_strides = IntTuple(-1, (PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params * PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params), PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params, 1)

blocks_tt_layout​

comptime blocks_tt_layout = Layout[*?, *?]

blocks_tt_type​

comptime blocks_tt_type = TileTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, Layout[*?, *?], MutAnyOrigin]

cache_lengths_tt_layout​

comptime cache_lengths_tt_layout = Layout[*?, *?]

cache_lengths_tt_type​

comptime cache_lengths_tt_type = TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]

device_type​

comptime device_type = PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_]

dtype​

comptime dtype = dtype_

head_dim_granularity​

comptime head_dim_granularity = ceildiv(PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].quantization_granularity)

kv_params​

comptime kv_params = kv_params_

lookup_table_tt_layout​

comptime lookup_table_tt_layout = Layout[*?, *?]

lookup_table_tt_type​

comptime lookup_table_tt_type = TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin]

page_size_​

comptime page_size_ = page_size

quantization_enabled​

comptime quantization_enabled = (scale_dtype_ != DType.invalid)

quantization_granularity​

comptime quantization_granularity = quantization_granularity_

scale_dtype​

comptime scale_dtype = scale_dtype_

scales_tt_layout​

comptime scales_tt_layout = Layout[*?, *?]

scales_tt_type​

comptime scales_tt_type = TileTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout[*?, *?], MutAnyOrigin]

Methods​

__init__​

__init__(blocks: TileTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, Layout[*?, *?], MutAnyOrigin], cache_lengths: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], lookup_table: TileTensor[DType.uint32, Layout[*?, *?], ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[TileTensor[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, Layout[*?, *?], MutAnyOrigin]] = None) -> Self

get_type_name​

static get_type_name() -> String

Returns:

String

max_tile_size​

static max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

cache_lengths_nd​

cache_lengths_nd(self) -> PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].cache_lengths_tt_type

Returns:

PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].cache_lengths_tt_type

cache_length​

cache_length(self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

get_tma_row​

get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

The encoded index is physical_block * page_size + offset. This method decomposes it and returns physical_block * stride + offset where stride is the distance (in rows) between consecutive physical blocks in the flattened memory view.

Returns:

Int32
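The decode/re-encode step above is simple integer arithmetic, sketched here in Python (a simplified model; the page size and inter-block row stride values are illustrative, not taken from the source):

```python
PAGE_SIZE = 16   # `page_size` parameter (illustrative value)
STRIDE = 64      # rows between consecutive physical blocks in the flat view (illustrative)

def get_tma_row(encoded_index):
    """Decode physical_block * PAGE_SIZE + offset, then re-encode
    against the flattened memory view's block stride."""
    physical_block = encoded_index // PAGE_SIZE
    offset = encoded_index % PAGE_SIZE
    return physical_block * STRIDE + offset

# Block 3, offset 5: encoded as 3*16 + 5 = 53; physical row 3*64 + 5 = 197.
print(get_tma_row(53))  # 197
```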

num_kv_rows​

num_kv_rows(self) -> Int

Returns the total number of virtual rows in this KV cache view.

Returns:

Int

row_idx​

row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32

Returns the row idx when viewing the memory as a matrix.

Returns:

UInt32

populate​

populate[BN: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].page_size_, pair_cta, is_leader]

Loads the num_pages block indices from the lookup table with a single SIMD operation.

Computes `result.rows[i] = lookup_table[batch, first_lut_idx + i] * stride + tok_in_block` for all num_pages entries using one (or a small fixed number of) aligned `ld.global.v{N}.u32` loads from the lookup table row.

Invariants:

  • self.lookup_table.dim[1] is large enough that a SIMD read of num_pages uint32s starting at any valid first_lut_idx stays in bounds (see PagedKVCacheManager for the allocation-side padding).
  • base_kv_row is BN-aligned for num_pages > 1 (every mask shipped with fa4/depth512/sm90 satisfies this). The first LUT index is then a multiple of num_pages, giving up to chunk * 4-byte alignment on the vector load. For num_pages == 1 the load is a scalar and alignment is irrelevant.

The per-load width chunk is the largest power of two that divides num_pages, capped at 8 β€” this keeps both the load width and the remaining chunk offsets aligned for the common num_pages in {1, 2, 4, 8, 16} cases and falls back to smaller widths when num_pages has a factor like 3 (e.g. BN=192, page_size=16 gives num_pages=12, chunk=4).

Returns:

PagedRowIndices[BN, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].page_size_, pair_cta, is_leader]
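The chunk-width rule described above can be written out as a small Python sketch (a model of the rule only, not the actual implementation):

```python
def chunk_width(num_pages):
    """Largest power of two that divides num_pages, capped at 8.

    num_pages & -num_pages isolates the lowest set bit, which is
    exactly the largest power of two dividing num_pages."""
    return min(num_pages & -num_pages, 8)

# Common cases load in one full-width chunk; num_pages with an odd
# factor (e.g. BN=192, page_size=16 -> num_pages=12) fall back to
# a smaller width: 12 = 4 * 3, so chunk = 4.
for n in (1, 2, 4, 8, 16, 12):
    print(n, chunk_width(n))
```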

create_tma_tile​

create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, swizzle_mode, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size]()](self, ctx: DeviceContext) -> TMATensorTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, 3, _padded_shape[3, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a TMA tile for this KV cache.

Returns:

TMATensorTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, 3, _padded_shape[3, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile​

create_gather4_tma_tile[*, tile_height: Int = 4, tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tma_dtype: DType = PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time.

Parameters:

  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load (box width) in tma_dtype elements.
  • tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
  • tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.

create_ragged_tma_tile​

create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, swizzle_mode, PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].kv_params.head_size]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile​

create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.

In the per-tensor rope-aware layout, each token row is laid out as padded_depth FP8 bytes (content) followed by BK BF16 elements (rope), for a total row size of padded_depth + BK * 2 bytes.

The TMA descriptor points at the rope data by offsetting blocks.ptr by padded_depth bytes, then reinterpreting as BF16. The global memory stride dimension (last dim of gmem_shape) is the total row size expressed in BF16 units: (padded_depth + BK * 2) // 2.

Returns:

TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
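The row-size arithmetic above can be checked with a quick Python sketch (the padded_depth and BK values are illustrative, not taken from the source):

```python
def rope_row_stride_bf16(padded_depth, BK):
    """Total row size in BF16 units for the per-tensor rope-aware layout:
    padded_depth FP8 content bytes plus BK BF16 rope elements (2 bytes each),
    expressed in BF16 (2-byte) units."""
    row_bytes = padded_depth + BK * 2
    assert row_bytes % 2 == 0, "row must be expressible in whole BF16 units"
    return row_bytes // 2

# e.g. 512 FP8 content bytes and 64 BF16 rope elements per token row:
# (512 + 64*2) // 2 = 320 BF16 units of global-memory stride.
print(rope_row_stride_bf16(512, 64))  # 320
```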

create_rope_gather4_tma_tile​

create_rope_gather4_tma_tile[*, tile_height: Int = 4, tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.

For the per-tensor rope-aware layout each token row is stored as padded_depth FP8 bytes (content) followed by BF16 rope elements. The total row width in BF16 units is (padded_depth + tile_width * 2) // 2.

This method offsets blocks.ptr by padded_depth bytes, reinterprets as BF16, and creates a gather4 TMA descriptor whose row stride is the full row width in BF16 elements.

Returns:

TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

load​

load[width: Int, output_dtype: DType = PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads an element from the given index.

Returns:

SIMD[output_dtype, width]

store​

store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype])

Stores an element at the given index.

load_scale​

load_scale[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, width]

Loads a quantization scale from the given index.

Returns:

SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype, width]

store_scale​

store_scale(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype])

Stores the quantization scales at the given index.

load_quantized​

load_quantized[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype, width]

empty_cache​

empty_cache(self) -> Bool

Returns true if the cache lengths for all requests are 0, false otherwise.

Returns:

Bool

max_prompt_length​

max_prompt_length(self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length​

max_context_length(self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32

block_paged_ptr​

block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype], MutAnyOrigin]

Returns:

UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].dtype], MutAnyOrigin]

scales_block_paged_ptr​

scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype], MutAnyOrigin]

scales_raw_ptr​

scales_raw_ptr(self) -> UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype], MutAnyOrigin]

Returns the base pointer to the scales tensor, or a dangling pointer if scales are not set.

Returns:

UnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size, scale_dtype_, quantization_granularity_].scale_dtype], MutAnyOrigin]