IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

ContinuousBatchingKVCache

struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams, blocks_origin: MutOrigin, cache_lengths_origin: ImmutOrigin, lookup_table_origin: ImmutOrigin]

Wrapper for the ContinuousKVCache of a given layer in the transformer model.

This abstracts away the pointer indirection needed to access the ContinuousKVCache for a given batch entry.

Note: this is the type that is passed to the KV projection and flash attention kernels.

Parameters

  • dtype_ (DType): The dtype of the kv-cache.
  • kv_params_ (KVCacheStaticParams): The kv-cache static parameters.
  • blocks_origin (MutOrigin): Origin of the KV cache blocks buffer.
  • cache_lengths_origin (ImmutOrigin): Origin of the cache lengths buffer.
  • lookup_table_origin (ImmutOrigin): Origin of the lookup table buffer.

Fields

  • blocks (ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].blocks_tt_type):
  • cache_lengths (ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].cache_lengths_tt_type):
  • lookup_table (ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].lookup_table_tt_type):
  • max_seq_length (UInt32):
  • max_cache_length (UInt32):

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, KVCacheT, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

blocks_layout​

comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].blocks_shape)

blocks_shape​

comptime blocks_shape = IntTuple(Int(-1), Int(-1), kv_params_, kv_params_)

blocks_tt_layout​

comptime blocks_tt_layout = Layout[*?, *?]

blocks_tt_type​

comptime blocks_tt_type = TileTensor[ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].dtype, Layout[*?, *?], blocks_origin]

cache_lengths_tt_layout​

comptime cache_lengths_tt_layout = Layout[*?, *?]

cache_lengths_tt_type​

comptime cache_lengths_tt_type = TileTensor[DType.uint32, Layout[*?, *?], cache_lengths_origin]

device_type​

comptime device_type = ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin]

dtype​

comptime dtype = dtype_

kv_params​

comptime kv_params = kv_params_

lookup_table_tt_layout​

comptime lookup_table_tt_layout = Layout[*?, *?]

lookup_table_tt_type​

comptime lookup_table_tt_type = TileTensor[DType.uint32, Layout[*?, *?], lookup_table_origin]

page_size_​

comptime page_size_ = 0

quantization_enabled​

comptime quantization_enabled = False

quantization_granularity​

comptime quantization_granularity = 1

scale_dtype​

comptime scale_dtype = DType.float32

Methods

__init__​

def __init__(blocks: TileTensor[Self.dtype, Layout[*?, *?], blocks_origin], cache_lengths: TileTensor[DType.uint32, Layout[*?, *?], cache_lengths_origin], lookup_table: TileTensor[DType.uint32, Layout[*?, *?], lookup_table_origin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self

get_type_name​

static def get_type_name() -> String

Returns:

String

max_tile_size​

static def max_tile_size() -> Int

Returns the maximum tile size for the KVCache.

Returns:

Int

cache_lengths_nd​

def cache_lengths_nd(self) -> Self.cache_lengths_tt_type

Returns:

Self.cache_lengths_tt_type

cache_length​

def cache_length(self, batch_idx: Int) -> Int

Returns:

Int

load​

def load[width: Int, output_dtype: DType = ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].dtype](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[output_dtype, width]

Returns:

SIMD[output_dtype, width]

store​

def store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[Self.dtype])

load_scale​

def load_scale[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[DType.float32, width]

Loads a quantization scale from the given index.

Note: ContinuousBatchingKVCache does not support KVCache quantization.

Returns:

SIMD[DType.float32, width]

store_scale​

def store_scale(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, scales: SIMD[DType.float32])

Stores the quantization scales at the given index.

Note: ContinuousBatchingKVCache does not support KVCache quantization.

load_quantized​

def load_quantized[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[Self.dtype, width]

Loads a quantized element from the given index.

Note: ContinuousBatchingKVCache does not support KVCache quantization.

Returns:

SIMD[Self.dtype, width]

empty_cache​

def empty_cache(self) -> Bool

Returns true if the cache length of every request is 0, and false otherwise.

Returns:

Bool

max_prompt_length​

def max_prompt_length(self) -> UInt32

Returns the maximum sequence length across all batches of the current request.

Returns:

UInt32

max_context_length​

def max_context_length(self) -> UInt32

Returns the maximum cache length used across all batches of the current request.

Returns:

UInt32

get_tma_row​

def get_tma_row(self, encoded_index: Int32) -> Int32

Converts an encoded sparse index to a physical TMA row.

For non-paged caches, the encoded index is already the row, so this is an identity operation.

Returns:

Int32

num_kv_rows​

def num_kv_rows(self) -> Int

Returns the total number of virtual rows in this KV cache view.

Returns:

Int

row_idx​

def row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32

Returns the row index when the memory is viewed as a matrix.

Returns:

UInt32

create_tma_tile​

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[dtype_, swizzle_mode, kv_params_.head_size]()](self, ctx: DeviceContext) -> TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

Creates a TMA tile for this KV cache.

Returns:

TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile​

def create_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tma_dtype: DType = ContinuousBatchingKVCache[dtype_, kv_params_, blocks_origin, cache_lengths_origin, lookup_table_origin].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Creates a 2D TMA gather4 descriptor for this KV cache.

The descriptor views the KV cache as a flat 2D matrix of [num_kv_rows, tile_width] and is configured for gather4 operations that load 4 non-contiguous rows per TMA instruction. The box width is derived from the swizzle mode; for SWIZZLE_NONE it equals tile_width.

When tma_dtype differs from Self.dtype, the underlying data pointer is bitcast to tma_dtype at descriptor creation time.

Parameters:

  • tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
  • tile_width (Int): Number of elements per row to load (box width) in tma_dtype elements.
  • tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
  • swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
  • tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
  • l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.

Args:

  • ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.

create_ragged_tma_tile​

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int = padded_depth[dtype_, swizzle_mode, kv_params_.head_size]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile​

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Not supported for ContinuousBatchingKVCache.

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_rope_gather4_tma_tile​

def create_rope_gather4_tma_tile[*, tile_height: Int = Int(4), tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

Not supported for ContinuousBatchingKVCache.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

block_paged_ptr​

def block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.dtype], MutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], MutAnyOrigin]

scales_block_paged_ptr​

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Float32, MutAnyOrigin]

Returns a pointer to the scales block at the requested indices.

Note: ContinuousBatchingKVCache does not support KVCache quantization. This function returns a dangling pointer.

Returns:

UnsafePointer[Float32, MutAnyOrigin]

scales_raw_ptr​

def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns a dangling pointer. ContinuousBatchingKVCache does not support quantization.

Returns:

UnsafePointer[Float32, MutAnyOrigin]