IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

PagedKVCache

struct PagedKVCache[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int, blocks_origin: MutOrigin, cache_lengths_origin: ImmutOrigin, lookup_table_origin: ImmutOrigin, scales_origin: MutOrigin, *, scale_dtype_: DType = DType.invalid, quantization_granularity_: Int = Int(1)]

The PagedKVCache is a wrapper around the KVCache blocks for a given layer. It is used to access the KVCache blocks for PagedAttention.

Note: This struct represents a 4D view of a 6D PagedKVCacheCollection tensor. The compile-time layout has UNKNOWN_VALUE for stride[0] because the actual stride depends on num_layers from the parent tensor, which is only known at runtime. This ensures offset calculations use the correct runtime strides rather than incorrect compile-time values.

Parameters

  • dtype_ (DType): The dtype of the KV cache.
  • kv_params_ (KVCacheStaticParams): The KV cache static parameters.
  • page_size (Int): The number of tokens per page.
  • blocks_origin (MutOrigin): Origin of the KV cache blocks buffer.
  • cache_lengths_origin (ImmutOrigin): Origin of the cache lengths buffer.
  • lookup_table_origin (ImmutOrigin): Origin of the lookup table buffer.
  • scales_origin (MutOrigin): Origin of the quantization scales buffer.
  • scale_dtype_ (DType): Dtype of the quantization scales. The default, DType.invalid, disables quantization.
  • quantization_granularity_ (Int): Block size used for quantization (e.g. 128).

Fields

  • blocks (PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].blocks_tt_type): The KV cache blocks for this layer.
  • cache_lengths (PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].cache_lengths_tt_type): The cache length of each request, indexed by batch index.
  • lookup_table (PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].lookup_table_tt_type): The lookup table (LUT) that maps each (batch index, logical page index) pair to a physical block index.
  • max_seq_length (UInt32): The maximum sequence length across the requests in the batch.
  • max_cache_length (UInt32): The maximum cache length across the requests in the batch.
  • scales (OptionalReg[TileTensor[PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].scale_dtype, Layout[*?, *?], scales_origin]]): The optional quantization scales.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, KVCacheT, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

blocks_layout

comptime blocks_layout = Layout(PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].blocks_shape, PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].blocks_strides)

blocks_shape

comptime blocks_shape = IntTuple(Int(-1), page_size, kv_params_.num_heads, kv_params_.head_size)

blocks_strides

comptime blocks_strides = IntTuple(Int(-1), Int(kv_params_.head_size * kv_params_.num_heads), kv_params_.head_size, Int(1))
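These strides can be read as a row-major layout over (block, token in page, head, head dim), with stride[0] filled in at runtime. Under that reading, an element's offset is the dot product of its 4D index with the strides. The following plain-Python sketch shows that arithmetic. The function names are hypothetical, and block_stride stands in for the runtime stride[0].

```python
def blocks_strides(num_heads, head_size, block_stride):
    # Row-major strides for a [num_blocks, page_size, num_heads, head_size] view.
    # block_stride (stride[0]) is only known at runtime, so it is passed in here.
    return (block_stride, num_heads * head_size, head_size, 1)


def element_offset(index, strides):
    # index = (block, tok_in_page, head, head_dim); offset = sum(index[i] * strides[i]).
    return sum(i * s for i, s in zip(index, strides))


# Contiguous blocks: stride[0] equals page_size * num_heads * head_size.
contiguous = blocks_strides(num_heads=8, head_size=128, block_stride=16 * 8 * 128)
print(contiguous, element_offset((0, 1, 0, 0), contiguous))
```

A block_stride larger than page_size * num_heads * head_size models a parent tensor whose blocks are spaced further apart. Hard-coding the contiguous value in that case would produce the incorrect compile-time stride that the note above warns against.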

blocks_tt_layout

comptime blocks_tt_layout = Layout[*?, *?]

blocks_tt_type

comptime blocks_tt_type = TileTensor[PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].dtype, Layout[*?, *?], blocks_origin]

cache_lengths_tt_layout

comptime cache_lengths_tt_layout = Layout[*?, *?]

cache_lengths_tt_type

comptime cache_lengths_tt_type = TileTensor[DType.uint32, Layout[*?, *?], cache_lengths_origin]

device_type

comptime device_type = PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_]

dtype

comptime dtype = dtype_

head_dim_granularity

comptime head_dim_granularity = ceildiv(kv_params_.head_size, PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].quantization_granularity)
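head_dim_granularity is a ceiling division: the number of quantization_granularity-sized groups needed to cover head_size elements, with any partial group rounded up. Here is a small plain-Python check of that arithmetic. The head sizes are arbitrary examples, and ceildiv is a local helper that mirrors the name used above.

```python
def ceildiv(a, b):
    # Integer ceiling division without floating point.
    return -(-a // b)


def head_dim_granularity(head_size, quantization_granularity):
    # Number of quantization groups spanning one head (the last may be partial).
    return ceildiv(head_size, quantization_granularity)


print(head_dim_granularity(576, 128))  # 4.5 groups, rounded up
```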

kv_params

comptime kv_params = kv_params_

lookup_table_tt_layout

comptime lookup_table_tt_layout = Layout[*?, *?]

lookup_table_tt_type

comptime lookup_table_tt_type = TileTensor[DType.uint32, Layout[*?, *?], lookup_table_origin]

page_size_

comptime page_size_ = page_size

quantization_enabled

comptime quantization_enabled = (scale_dtype_ != DType.invalid)

quantization_granularity

comptime quantization_granularity = quantization_granularity_

scale_dtype

comptime scale_dtype = scale_dtype_

scales_tt_layout

comptime scales_tt_layout = Layout[*?, *?]

scales_tt_type

comptime scales_tt_type = TileTensor[PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].scale_dtype, Layout[*?, *?], scales_origin]

Methods

__init__

def __init__(blocks: TileTensor[Self.dtype, Layout[*?, *?], blocks_origin], cache_lengths: TileTensor[DType.uint32, Layout[*?, *?], cache_lengths_origin], lookup_table: TileTensor[DType.uint32, Layout[*?, *?], lookup_table_origin], max_seq_length: UInt32, max_cache_length: UInt32, scales: OptionalReg[TileTensor[Self.scale_dtype, Layout[*?, *?], scales_origin]] = None) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

max_tile_size

static def max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

cache_lengths_nd

def cache_lengths_nd(self) -> Self.cache_lengths_tt_type

Returns:

Self.cache_lengths_tt_type

cache_length

def cache_length(self, batch_idx: Int) -> Int

Returns the length of the cache for a given batch index.

Returns:

Int

get_tma_row

def get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

The encoded index is physical_block * page_size + offset. This method decomposes it and returns physical_block * stride + offset where stride is the distance (in rows) between consecutive physical blocks in the flattened memory view.
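The decomposition above is a divmod by page_size followed by a rescale of the block index. Here is a plain-Python model of it. It assumes a non-negative encoded index, and the stride value of 64 rows in the example is an arbitrary illustration, not a value this page specifies.

```python
def get_tma_row(encoded_index, page_size, stride):
    # encoded_index = physical_block * page_size + offset, with 0 <= offset < page_size.
    physical_block, offset = divmod(encoded_index, page_size)
    # Consecutive physical blocks are `stride` rows apart in the flattened view.
    return physical_block * stride + offset


print(get_tma_row(3 * 16 + 5, page_size=16, stride=64))  # block 3, offset 5
```

When stride equals page_size, the mapping is the identity. Any larger stride leaves gaps of stride - page_size rows between consecutive blocks.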

Returns:

Int32

num_kv_rows

def num_kv_rows(self) -> Int

Returns the total number of virtual rows in this KV cache view.

Returns:

Int

row_idx

def row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32

Returns the row index of token tok_idx of batch entry batch_idx when the cache memory is viewed as a 2D matrix.

Returns:

UInt32

populate

def populate[BN: Int, base_alignment: Int, pair_cta: Bool = False, is_leader: Bool = True](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, Self.page_size_, pair_cta, is_leader]

Loads the num_pages block indices from the lookup table (LUT) in one shot using SIMD loads.

Computes result.rows[i] = lookup_table[batch, first_lut_idx + i] * stride + tok_in_block for all num_pages entries, using one (or a small fixed number of) aligned ld.global.v{N}.u32 loads from the lookup table row.

Invariants:

  • self.lookup_table.dim[1] is large enough that a SIMD read of num_pages uint32s starting at any valid first_lut_idx stays in bounds (see PagedKVCacheManager for the allocation-side padding).
  • base_kv_row % base_alignment == 0 holds at runtime (typically mask.start_column_alignment[...]()). For num_pages > 1, base_alignment must be at least page_size. This guarantees tok_in_block_idx == 0, so the SIMD multiply-add collapses to a multiply. Larger base_alignment values allow a wider SIMD chunk (chunk * page_size must divide base_alignment).

The per-load width chunk is the largest power of two that divides both num_pages and base_alignment / page_size, capped at 8. With base_alignment == BN (the historical contract), this matches the previous behaviour: chunk = min(num_pages & -num_pages, 8). With looser alignments (e.g. ChunkedMask providing only page_size alignment when BN > page_size), the chunk degrades to 1 (scalar loads).
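The rows formula and the chunk-width rule can be checked with a scalar plain-Python model. The sketch has two parts. populate_rows_reference computes result.rows with ordinary list indexing instead of vector loads. simd_chunk picks the per-load width: the largest power of two dividing both num_pages and base_alignment / page_size, capped at 8. Both names are hypothetical. The sketch also assumes that base_alignment is a multiple of page_size whenever num_pages > 1.

```python
def lowbit(n):
    # Largest power of two dividing n (n > 0).
    return n & -n


def simd_chunk(num_pages, base_alignment, page_size, cap=8):
    if num_pages == 1:
        return 1
    # Assumption: base_alignment is a multiple of page_size when num_pages > 1.
    assert base_alignment % page_size == 0
    # The largest power of two dividing both a and b is min(lowbit(a), lowbit(b)).
    return min(lowbit(num_pages), lowbit(base_alignment // page_size), cap)


def populate_rows_reference(lookup_table, batch, first_lut_idx, num_pages, stride, tok_in_block=0):
    # Scalar version of result.rows[i]; the real method uses aligned vector loads.
    return [lookup_table[batch][first_lut_idx + i] * stride + tok_in_block
            for i in range(num_pages)]


print(simd_chunk(8, 128, 16))  # BN = 128, page_size = 16, base_alignment == BN
print(simd_chunk(8, 16, 16))   # only page_size alignment, as with ChunkedMask
```

With base_alignment == BN and num_pages == BN / page_size, both lowbit terms are equal, so the rule reduces to min(num_pages & -num_pages, 8), matching the historical behaviour described above.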

Returns:

PagedRowIndices[BN, Self.page_size_, pair_cta, is_leader]

create_tma_tile

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[dtype_, swizzle_mode, kv_params_.head_size]()](self, ctx: DeviceContext) -> TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a TMA tile for this KV cache.

Returns:

TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile

def create_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tma_dtype: DType = PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time.

Parameters:

  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load (box width), in tma_dtype elements.
  • tile_stride (Int): Row stride in global memory, in elements. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for the shared memory access pattern. Defaults to SWIZZLE_NONE.
  • tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.
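The gather4 view treats the cache as a flat 2D matrix with tile_stride elements per row and loads tile_width elements from each of several non-contiguous rows. Here is a scalar plain-Python model of what a tile of tile_height rows contains. It ignores swizzling, which matches SWIZZLE_NONE, where the box width equals tile_width. It also assumes that a tile of tile_height rows is built from tile_height / 4 gather4 operations. The function name is hypothetical.

```python
def gather4_reference(flat, row_indices, tile_width, tile_stride):
    # flat: global memory as a flat row-major list with tile_stride elements per row.
    # row_indices: arbitrary (non-contiguous) rows; tile_height = len(row_indices).
    if len(row_indices) % 4 != 0:
        raise ValueError("tile_height must be a multiple of 4")
    tile = []
    # Assumption: each group of 4 rows corresponds to one gather4 operation.
    for group_start in range(0, len(row_indices), 4):
        for row in row_indices[group_start:group_start + 4]:
            start = row * tile_stride
            # Only the first tile_width elements of each (possibly wider) row are loaded.
            tile.append(flat[start:start + tile_width])
    return tile


memory = list(range(40))  # 5 rows of 8 elements
print(gather4_reference(memory, [4, 0, 2, 1], tile_width=4, tile_stride=8))
```

Setting tile_stride larger than tile_width, as in the example, corresponds to loading only the leading part of each global row.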

create_ragged_tma_tile

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[dtype_, swizzle_mode, kv_params_.head_size]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.

In the per-tensor rope-aware layout, each token row consists of padded_depth FP8 bytes of content followed by BK BF16 rope elements, so the total row size is padded_depth + BK * 2 bytes.

The TMA descriptor points at the rope data by offsetting blocks.ptr by padded_depth bytes and then reinterpreting the memory as BF16. The global memory stride dimension (the last dimension of gmem_shape) is the total row size expressed in BF16 units: (padded_depth + BK * 2) // 2.
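Here is a worked plain-Python version of that arithmetic. padded_depth = 512 and BK = 64 are arbitrary illustration values, and the helper names are hypothetical. The check confirms that indexing in BF16 units (row stride plus rope offset) reaches the same byte as the raw per-row byte layout. The same arithmetic applies to create_rope_gather4_tma_tile, with tile_width in place of BK.

```python
BF16_BYTES = 2  # bytes per BF16 rope element; FP8 content is 1 byte per element


def rope_row_layout(padded_depth, bk):
    # Returns (row_bytes, rope_offset_bf16, row_stride_bf16) for one token row.
    row_bytes = padded_depth + bk * BF16_BYTES
    # Offsetting by padded_depth bytes lands on a BF16 boundary only if it is even.
    assert padded_depth % BF16_BYTES == 0
    return row_bytes, padded_depth // BF16_BYTES, row_bytes // BF16_BYTES


def rope_byte_address(tok, j, padded_depth, bk):
    # Byte address of rope element j of token row tok, relative to blocks.ptr.
    row_bytes, _, _ = rope_row_layout(padded_depth, bk)
    return tok * row_bytes + padded_depth + j * BF16_BYTES


print(rope_row_layout(512, 64))  # (row bytes, rope offset in BF16, row stride in BF16)
```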

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_rope_gather4_tma_tile

def create_rope_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.

For the per-tensor rope-aware layout each token row is stored as padded_depth FP8 bytes (content) followed by BF16 rope elements. The total row width in BF16 units is (padded_depth + tile_width * 2) // 2.

This method offsets blocks.ptr by padded_depth bytes, reinterprets as BF16, and creates a gather4 TMA descriptor whose row stride is the full row width in BF16 elements.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

load

def load[width: Int, output_dtype: DType = PagedKVCache[dtype_, kv_params_, page_size, blocks_origin, cache_lengths_origin, lookup_table_origin, scales_origin, scale_dtype_=scale_dtype_, quantization_granularity_=quantization_granularity_].dtype](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Loads width elements at the given index and returns them as a SIMD vector of output_dtype.

Returns:

SIMD[output_dtype, width]

store

def store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[Self.dtype])

Stores val at the given index.

Skips the write when the LUT entry for (bs, tok_idx // page_size) is the unassigned-slot sentinel, i.e. when the resolved block_idx is outside [0, total_num_blocks). The cache manager fills LUT columns past a request's allocated block count with the sentinel value total_num_pages (see cache_manager.py's lut_table_np.fill(self._total_num_pages)) so that SIMD over-reads of the LUT row are safe. However, the sentinel value multiplied by the page stride lands one page past the end of the cache buffer. Without this guard, a store resolved through the sentinel would corrupt whatever device allocation happens to sit immediately after the KV cache.
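The guard can be modeled in plain Python. This is a sketch under simplifying assumptions, not the method itself: the cache holds one value per token rather than num_heads * head_size, the LUT is a nested list, and store_token is a hypothetical name.

```python
def store_token(cache, lut, bs, tok_idx, value, page_size, total_num_blocks):
    # Returns True if the value was written, False if the write was skipped.
    block_idx = lut[bs][tok_idx // page_size]
    if not 0 <= block_idx < total_num_blocks:
        # Unassigned-slot sentinel: block_idx * page_size would point past the buffer.
        return False
    cache[block_idx * page_size + tok_idx % page_size] = value
    return True


page_size, total_num_blocks = 4, 2
cache = [0] * (page_size * total_num_blocks)
lut = [[1, total_num_blocks]]  # request 0: page 0 -> block 1, page 1 unassigned (sentinel)
store_token(cache, lut, 0, 1, 7, page_size, total_num_blocks)  # written to cache[5]
store_token(cache, lut, 0, 5, 9, page_size, total_num_blocks)  # skipped
print(cache)
```

Without the check, the second call would compute index 2 * 4 + 1 = 9, which lies in the page directly after the 8-element buffer.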

load_scale

def load_scale[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]

Loads a quantization scale from the given index.

Returns:

SIMD[Self.scale_dtype, width]

store_scale

def store_scale(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[Self.scale_dtype])

Stores the quantization scales at the given index.

load_quantized

def load_quantized[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[Self.dtype, width]

Loads a quantized element from the given index.

Returns:

SIMD[Self.dtype, width]

empty_cache

def empty_cache(self) -> Bool

Returns True if the cache length of every request is 0, and False otherwise.

Returns:

Bool

max_prompt_length

def max_prompt_length(self) -> UInt32

Returns the maximum sequence length across all requests in the current batch.

Returns:

UInt32

max_context_length

def max_context_length(self) -> UInt32

Returns the maximum cache length across all requests in the current batch.

Returns:

UInt32

block_paged_ptr

def block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.dtype], MutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], MutAnyOrigin]

scales_block_paged_ptr

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.scale_dtype], MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Returns:

UnsafePointer[Scalar[Self.scale_dtype], MutAnyOrigin]

scales_raw_ptr

def scales_raw_ptr(self) -> UnsafePointer[Scalar[Self.scale_dtype], MutAnyOrigin]

Returns the base pointer to the scales tensor, or a dangling pointer if scales are not set.

Returns:

UnsafePointer[Scalar[Self.scale_dtype], MutAnyOrigin]