IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

KVCacheParams

class max.nn.kv_cache.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, kv_connector=None, kv_connector_config=None, host_kvcache_swap_space_gb=None, page_size=128, is_mla=False, num_q_heads=None, data_parallel_degree=1, n_kv_heads_per_device=0, num_q_heads_per_device=None, kvcache_quant_config=None, speculative_method=None, num_draft_tokens=0)

Bases: KVCacheParamInterface

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings and memory management.

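Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • num_layers (int)
  • devices (Sequence[DeviceRef])
  • enable_prefix_caching (bool)
  • kv_connector (KVConnectorType | None)
  • kv_connector_config (Any)
  • host_kvcache_swap_space_gb (float | None)
  • page_size (int)
  • is_mla (bool)
  • num_q_heads (int | None)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
  • num_q_heads_per_device (int | None)
  • kvcache_quant_config (KVCacheQuantizationConfig | None)
  • speculative_method (Literal['eagle', 'mtp', 'dflash'] | None)
  • num_draft_tokens (int)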

allocate_buffers()

allocate_buffers(total_num_pages)

Allocates the buffers for the KV cache.

Parameters:

total_num_pages (int)

Return type:

list[KVCacheBuffer]

bytes_per_block

property bytes_per_block: int

Returns the number of bytes per cache block.

When the tensor parallel degree is greater than 1, each block is sharded across the devices in the tensor parallel group, and this property returns the total memory needed to store one block across all of those devices. The total includes memory for quantization scales when KV cache quantization is enabled.

Returns:

The number of bytes per cache block.
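As a rough illustration of why tensor-parallel sharding does not change the total, the sketch below estimates the bytes of one block summed over the tensor parallel group. It assumes a hypothetical layout in which each device stores separate K and V tensors of shape [num_layers, page_size, n_kv_heads_per_device, head_dim], that KV heads divide evenly across the group, and that no quantization scales are stored. The real shape_per_block may differ, and `estimate_bytes_per_block` and `pages_that_fit` are hypothetical helpers, not part of the MAX API.

```python
def estimate_bytes_per_block(
    num_layers: int,
    page_size: int,
    n_kv_heads: int,
    head_dim: int,
    itemsize: int,
    tp_degree: int = 1,
) -> int:
    """Bytes for one block summed over the TP group (hypothetical layout)."""
    if n_kv_heads % tp_degree != 0:
        # This sketch does not model head replication across TP devices.
        raise ValueError("sketch assumes KV heads divide evenly across TP devices")
    n_kv_heads_per_device = n_kv_heads // tp_degree
    # Assumed per-device layout: K and V (factor 2) x layers x tokens x heads x head_dim.
    per_device = 2 * num_layers * page_size * n_kv_heads_per_device * head_dim * itemsize
    # Sum over every device in the tensor parallel group.
    return per_device * tp_degree


def pages_that_fit(budget_bytes: int, bytes_per_block: int) -> int:
    """Number of whole pages (blocks) that fit in a total memory budget."""
    return budget_bytes // bytes_per_block
```

With 2 layers, 128-token pages, 8 KV heads of dimension 64, and 2-byte elements, the estimate is 524,288 bytes whether tp_degree is 1 or 2, because sharding splits the heads rather than duplicating them. Under the same assumptions, a total budget of 10 GiB would hold 20,480 such pages.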

copy_as_dp_1()

copy_as_dp_1(replica_idx=0)

Creates a copy of the KVCacheParams with data parallelism disabled.

This method creates a new instance from the current configuration and adjusts it to a tensor-parallel-only setup (data_parallel_degree=1) by dividing the number of devices by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Parameters:

replica_idx (int)

Return type:

KVCacheParams
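The device arithmetic behind copy_as_dp_1 can be sketched in plain Python. This sketch assumes, hypothetically, that each data-parallel replica owns a contiguous slice of `devices` and that replica_idx selects which slice the copy keeps; the real method may assign devices differently. The helper name `devices_for_replica` is illustrative only.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def devices_for_replica(
    devices: Sequence[T], data_parallel_degree: int, replica_idx: int = 0
) -> List[T]:
    """Devices kept by one replica (hypothetical contiguous assignment)."""
    n_devices = len(devices)
    if n_devices % data_parallel_degree != 0:
        # Mirrors the documented ValueError for uneven division.
        raise ValueError(
            f"n_devices ({n_devices}) is not evenly divisible by "
            f"data_parallel_degree ({data_parallel_degree})"
        )
    if not 0 <= replica_idx < data_parallel_degree:
        raise ValueError(f"replica_idx {replica_idx} is out of range")
    # Each replica gets n_devices / data_parallel_degree devices.
    per_replica = n_devices // data_parallel_degree
    start = replica_idx * per_replica
    return list(devices[start:start + per_replica])
```

For example, with four devices and data_parallel_degree=4, replica 2 keeps a single device, which matches the current rule that data_parallel_degree is either 1 or n_devices.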

data_parallel_degree

data_parallel_degree: int = 1

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).
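The constraint above can be expressed as a small validation helper. This is a sketch, not the library's own check: it assumes the tensor parallel degree equals n_devices divided by data_parallel_degree, which is consistent with the two allowed settings (pure tensor parallelism when data_parallel_degree is 1, pure data parallelism when it equals n_devices). The function name `resolve_parallel_degrees` is hypothetical.

```python
from typing import Tuple


def resolve_parallel_degrees(n_devices: int, data_parallel_degree: int = 1) -> Tuple[int, int]:
    """Return (data_parallel_degree, tensor_parallel_degree) under the documented rule."""
    if n_devices < 1:
        raise ValueError("at least one device is required")
    if data_parallel_degree not in (1, n_devices):
        # Mixing DP and TP (1 < data_parallel_degree < n_devices) is not yet supported.
        raise ValueError(
            f"data_parallel_degree must be 1 or {n_devices}, got {data_parallel_degree}"
        )
    # Assumed relationship: the devices left over per replica form the TP group.
    return data_parallel_degree, n_devices // data_parallel_degree
```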

devices

devices: Sequence[DeviceRef]

Devices to use for the KV cache.

dtype

dtype: DType

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

Returns a shorthand textual representation of the data type.

Returns:

β€œbf16” for bfloat16 dtype, β€œf32” otherwise.

enable_prefix_caching

enable_prefix_caching: bool = False

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

get_symbolic_inputs()

get_symbolic_inputs(prefix='', *, draft_attention_group=None)

Computes the symbolic inputs for the KV cache.

Parameters:

  • prefix (str) – Prefix for dynamic dim names.
  • draft_attention_group (KVCacheParams | None) – When set, sizes draft_attention_dispatch_metadata according to the drafter’s is_mla value rather than this instance’s. Use it for unified speculative-decoding graphs in which the target and draft models use different attention types.

Returns:

The symbolic inputs for the KV cache.

Return type:

KVCacheInputs[TensorType, BufferType]

graph_capture_probe_cache_lengths()

graph_capture_probe_cache_lengths(max_cache_length, q_max_seq_len=1)

Returns the cache lengths to probe during decode graph capture.

Parameters:

  • max_cache_length (int) – Upper bound on the cache length to probe.
  • q_max_seq_len (int) – Per-step query width (affects MLA spec-decode probing).

Returns:

The cache lengths to probe, one per distinct dispatch mode to capture.

Return type:

list[int]

head_dim

head_dim: int

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

Amount of host memory (in GB) to reserve for KV cache swapping. Required when the local or tiered connector is used.

is_fp8_kv_dtype

property is_fp8_kv_dtype: bool

Whether the KV cache stores FP8 data; used for attention dispatch resolution.

Unlike quantized_kv_cache (which also requires a valid scale configuration), this property checks only the storage dtype, matching the compile-time detection in the MLA decode kernel.

TODO(SERVOPT-1094): Once SnapMLA uses a valid scale_dtype, this can be replaced by quantized_kv_cache.

is_mla

is_mla: bool = False

Whether the model uses Multi-Latent Attention (MLA) architecture.

kv_connector

kv_connector: KVConnectorType | None = None

Type of KV cache connector to use (null, local, tiered, dkv).

kv_connector_config

kv_connector_config: Any = None

Connector-specific configuration (KVConnectorConfig from the pipelines layer).

kvcache_quant_config

kvcache_quant_config: KVCacheQuantizationConfig | None = None

KV cache quantization configuration. Currently, only FP8 quantization is supported.

n_devices

property n_devices: int

Returns the number of devices.

Returns:

The number of devices.

n_kv_heads

n_kv_heads: int

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

Number of KV heads allocated to each device. Computed automatically in __post_init__.

num_draft_tokens

num_draft_tokens: int = 0

Total draft tokens generated per speculative iteration.

Zero when no speculative decoding is configured.

num_layers

num_layers: int

Number of layers in the model.

num_q_heads

num_q_heads: int | None = None

Number of query attention heads. Required when is_mla is True so that the attention dispatch resolver can call the MLA-specific kernel.

num_q_heads_per_device

num_q_heads_per_device: int | None = None

Number of query heads per device. Computed automatically in __post_init__ from num_q_heads and the parallelism configuration.

page_size

page_size: int = 128

Number of tokens per page (block).

This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.

Current constraints: the page size must be a multiple of 128 and at least 128.
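Because page_size is measured in tokens, the number of pages a sequence occupies is a ceiling division of its token count by page_size. The sketch below checks the documented constraint and performs that division. `validate_page_size` and `pages_needed` are hypothetical helpers, and the sketch assumes a sequence never shares a page with another sequence.

```python
def validate_page_size(page_size: int) -> None:
    # Documented constraint: a multiple of 128 and at least 128.
    if page_size < 128 or page_size % 128 != 0:
        raise ValueError(f"page_size must be a multiple of 128 and >= 128, got {page_size}")


def pages_needed(num_tokens: int, page_size: int = 128) -> int:
    validate_page_size(page_size)
    # Ceiling division: a partially filled last page still occupies a whole page.
    return -(-num_tokens // page_size)
```

For example, a 300-token sequence needs 3 pages at the default page_size of 128, but only 2 pages at a page_size of 256.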

quantized_kv_cache

property quantized_kv_cache: bool

Returns whether FP8 KV cache quantization is enabled.

Returns:

True when the cache dtype is float8_e4m3fn or float8_e4m3fnuz and a valid quantization scale dtype is configured; False otherwise.
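The condition above can be sketched with dtype names as plain strings. This is an illustration under assumptions: dtypes are represented by their names rather than DType values, and a "valid" scale dtype is approximated as any configured (non-None) value, whereas the library may apply stricter checks. The storage-only check mirrors what is_fp8_kv_dtype describes, assuming it considers the same two dtypes. All names here are hypothetical.

```python
from typing import Optional

# The two FP8 storage dtypes named in the documentation above.
FP8_KV_DTYPES = frozenset({"float8_e4m3fn", "float8_e4m3fnuz"})


def stores_fp8(dtype_name: str) -> bool:
    # Storage-only check: looks at the cache dtype and nothing else.
    return dtype_name in FP8_KV_DTYPES


def is_quantized_kv_cache(dtype_name: str, scale_dtype_name: Optional[str]) -> bool:
    # Requires both FP8 storage and a configured scale dtype.
    return stores_fp8(dtype_name) and scale_dtype_name is not None
```

The split shows why the two properties can disagree: an FP8 cache without a scale configuration stores FP8 data but does not count as a quantized KV cache.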

replicates_kv_across_tp

property replicates_kv_across_tp: bool

Whether every device holds identical KV state.

resolve_attn_key()

resolve_attn_key(batch_size, max_prompt_length, max_cache_valid_length)

Resolves the decode attention dispatch key for the given shape.

Parameters:

  • batch_size (int) – Number of requests in the decode batch.
  • max_prompt_length (int) – Per-step query width (1 for plain decode, 1 + num_spec_tokens for speculative verify).
  • max_cache_valid_length (int) – Maximum valid cache length in the batch.

Returns:

The resolved AttnKey (an MHAAttnKey or MLAAttnKey).

Return type:

AttnKey
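The three arguments describe the shape of one decode step. As a sketch, assuming each request in the batch is summarized by its current valid cache length and that every request verifies the same number of speculative tokens, they could be derived as follows. `decode_dispatch_args` is a hypothetical helper, not part of the MAX API.

```python
from typing import Sequence, Tuple


def decode_dispatch_args(
    cache_valid_lengths: Sequence[int], num_spec_tokens: int = 0
) -> Tuple[int, int, int]:
    """Return (batch_size, max_prompt_length, max_cache_valid_length)."""
    if not cache_valid_lengths:
        raise ValueError("decode batch must contain at least one request")
    batch_size = len(cache_valid_lengths)
    # Query width per step: 1 for plain decode, 1 + num_spec_tokens for speculative verify.
    max_prompt_length = 1 + num_spec_tokens
    max_cache_valid_length = max(cache_valid_lengths)
    return batch_size, max_prompt_length, max_cache_valid_length
```

The returned tuple lines up positionally with resolve_attn_key(batch_size, max_prompt_length, max_cache_valid_length).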

shape_per_block

property shape_per_block: list[int]

Returns the shape of each cache block.

Returns:

The shape of the cache block.

shape_per_scale_block

property shape_per_scale_block: list[int]

Returns the shape of each scale block used for KV cache quantization.

Returns:

The shape of the KV cache quantization scale block.

speculative_method

speculative_method: Literal['eagle', 'mtp', 'dflash'] | None = None

Speculative decoding method propagated from SpeculativeConfig.

tensor_parallel_degree

property tensor_parallel_degree: int

Returns the tensor parallel degree.

Returns:

The tensor parallel degree.