Python class

KVCacheParams

class max.nn.kv_cache.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, kv_connector=None, kv_connector_config=None, host_kvcache_swap_space_gb=None, page_size=128, is_mla=False, num_q_heads=None, data_parallel_degree=1, n_kv_heads_per_device=0, num_q_heads_per_device=None, kvcache_quant_config=None, num_eagle_speculative_tokens=0)

Bases: KVCacheParamInterface

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings and memory management.

Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • num_layers (int)
  • devices (Sequence[DeviceRef])
  • enable_prefix_caching (bool)
  • kv_connector (KVConnectorType | None)
  • kv_connector_config (Any)
  • host_kvcache_swap_space_gb (float | None)
  • page_size (int)
  • is_mla (bool)
  • num_q_heads (int | None)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
  • num_q_heads_per_device (int | None)
  • kvcache_quant_config (KVCacheQuantizationConfig | None)
  • num_eagle_speculative_tokens (int)
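
As a usage sketch, the following constructs a configuration for a hypothetical 32-layer model served on a single GPU. The import paths and the DeviceRef.GPU(0) constructor are assumptions about the MAX package layout, and every field value is illustrative rather than prescriptive:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.kv_cache import KVCacheParams

# Illustrative values: a 32-layer model with 8 KV heads
# (grouped-query attention) and 128-dim heads on one GPU.
params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0)],
    enable_prefix_caching=True,
    page_size=128,
)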

allocate_buffers()

allocate_buffers(total_num_pages)

Allocates the buffers for the KV cache.

Parameters:

total_num_pages (int)

Return type:

list[KVCacheBuffer]
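
A brief usage sketch, continuing the params object constructed above; the page count is arbitrary:

# Hypothetical sizing: 1024 pages at page_size=128 covers up to
# 131,072 cached tokens across the configured devices.
buffers = params.allocate_buffers(total_num_pages=1024)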

bytes_per_block

property bytes_per_block: int

Returns the number of bytes per cache block.

When the tensor parallel degree is greater than 1, each block is sharded across the devices in the tensor parallel group. This property returns the total memory needed to store a block across those devices, including the memory needed for scales when quantization is enabled.

Returns:

The number of bytes per cache block.
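
The exact block layout is library-defined, but for an unquantized cache the footprint can be estimated from the fields above, assuming one key and one value entry per layer per token. A back-of-the-envelope sketch, not the actual implementation:

# Rough estimate only; the real property also accounts for
# quantization scales and the library's actual memory layout.
def estimate_bytes_per_block(p: KVCacheParams, dtype_size_bytes: int = 2) -> int:
    return 2 * p.num_layers * p.page_size * p.n_kv_heads * p.head_dim * dtype_size_bytes

# For the example params: 2 * 32 * 128 * 8 * 128 * 2 = 16,777,216 bytes (16 MiB).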

copy_as_dp_1()

copy_as_dp_1(replica_idx=0)

Creates a copy of the KVCacheParams with data parallelism disabled.

This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Parameters:

replica_idx (int)

Return type:

KVCacheParams
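
For example, a configuration spanning four GPUs with data_parallel_degree=2 describes two replicas of a two-way tensor-parallel cache, and copy_as_dp_1 recovers the per-replica view. A sketch with illustrative values:

# Hypothetical: 4 GPUs split into 2 data-parallel replicas,
# each replica holding a 2-way tensor-parallel cache.
dp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(i) for i in range(4)],
    data_parallel_degree=2,
)
replica = dp_params.copy_as_dp_1(replica_idx=0)
assert replica.data_parallel_degree == 1
assert replica.n_devices == 2  # 4 devices / DP degree of 2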

data_parallel_degree

data_parallel_degree: int = 1

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).

devices

devices: Sequence[DeviceRef]

Devices to use for the KV cache.

dtype

dtype: DType

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

Returns a shorthand textual representation of the data type.

Returns:

“bf16” for bfloat16 dtype, “f32” otherwise.

enable_prefix_caching

enable_prefix_caching: bool = False

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

get_symbolic_inputs()

get_symbolic_inputs(prefix='')

Computes the symbolic inputs for the KV cache.

This method returns a list of KVCacheInputs for each replica. This is used when constructing the model graph.

Returns:

The symbolic inputs for the KV cache.

Parameters:

prefix (str)

Return type:

KVCacheInputs[TensorType, BufferType]
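
A one-line usage sketch against the documented signature; the prefix value is purely illustrative:

# Fetch the symbolic KV cache inputs used when building the model graph.
kv_inputs = params.get_symbolic_inputs(prefix="kv_")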

head_dim

head_dim: int

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

Amount of host memory (in GB) to reserve for KV cache swapping. Required when the local or tiered connector is used.

is_fp8_kv_dtype

property is_fp8_kv_dtype: bool

Whether the KV cache stores FP8 data, for dispatch resolution.

Unlike quantized_kv_cache (which also requires a valid scale config), this property checks only the storage dtype, matching the compile-time detection in the MLA decode kernel.

TODO(SERVOPT-1094): Once SnapMLA uses a valid scale_dtype, this can be replaced by quantized_kv_cache.

is_mla

is_mla: bool = False

Whether the model uses Multi-Latent Attention (MLA) architecture.

kv_connector

kv_connector: KVConnectorType | None = None

Type of KV cache connector to use (null, local, tiered, lmcache).

kv_connector_config

kv_connector_config: Any = None

Connector-specific configuration (KVConnectorConfig from the pipelines layer).

kvcache_quant_config

kvcache_quant_config: KVCacheQuantizationConfig | None = None

KVCache quantization config. Currently only FP8 quantization is supported.

n_devices

property n_devices: int

Returns the number of devices.

Returns:

The number of devices.

n_kv_heads

n_kv_heads: int

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

Number of KV heads allocated to each device. Computed automatically in __post_init__.

num_eagle_speculative_tokens

num_eagle_speculative_tokens: int = 0

Number of draft tokens to generate for EAGLE speculative decoding.

num_layers

num_layers: int

Number of layers in the model.

num_q_heads

num_q_heads: int | None = None

Number of query attention heads. Required when is_mla is True so that the attention dispatch resolver can call the MLA-specific kernel.

num_q_heads_per_device

num_q_heads_per_device: int | None = None

Number of query heads per device. Computed automatically in __post_init__ from num_q_heads and the parallelism configuration.

page_size

page_size: int = 128

Number of tokens per page (block).

This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.

Current constraints: the page size must be a multiple of 128 and at least 128.

quantized_kv_cache

property quantized_kv_cache: bool

Returns whether FP8 KV cache quantization is enabled.

Returns:

True when the cache dtype is float8_e4m3fn or float8_e4m3fnuz and a valid quantization scale dtype is configured; False otherwise.
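Kvcache
A sketch of the documented condition, reimplemented for illustration; it assumes DType exposes the FP8 variants named above and approximates the "valid scale dtype" check with a config presence test:

# Not the library's actual property body; mirrors the documented
# condition, approximating "valid scale dtype" as a non-None config.
def is_quantized(p: KVCacheParams) -> bool:
    return (
        p.dtype in (DType.float8_e4m3fn, DType.float8_e4m3fnuz)
        and p.kvcache_quant_config is not None
    )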

shape_per_block

property shape_per_block: list[int]

Returns the shape of each cache block.

Returns:

The shape of the cache block.

shape_per_scale_block

property shape_per_scale_block: list[int]

Returns the shape of each scale block used for KVCache quantization.

Returns:

The shape of the KVCache quantization scales block.

tensor_parallel_degree

property tensor_parallel_degree: int

Returns the tensor parallel degree.

Returns:

The tensor parallel degree.
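
Given the copy_as_dp_1 contract above, the tensor parallel degree should follow from the device count and the data parallel degree; a hedged check, reusing the dp_params example from copy_as_dp_1:

# Assumed relationship implied by copy_as_dp_1: devices split evenly
# across data-parallel replicas, the remainder being tensor parallel.
expected_tp = dp_params.n_devices // dp_params.data_parallel_degree  # 4 // 2 == 2
assert dp_params.tensor_parallel_degree == expected_tp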