
Python module

cache_params

KVCacheParams

class max.nn.kv_cache.cache_params.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, enable_prefix_caching=False, enable_kvcache_swapping_to_host=False, host_kvcache_swap_space_gb=None, cache_strategy=KVCacheStrategy.PAGED, page_size=128, n_devices=1, is_mla=False, data_parallel_degree=1, n_kv_heads_per_device=0)

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.

Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • num_layers (int)
  • enable_prefix_caching (bool)
  • enable_kvcache_swapping_to_host (bool)
  • host_kvcache_swap_space_gb (float | None)
  • cache_strategy (KVCacheStrategy)
  • page_size (int)
  • n_devices (int)
  • is_mla (bool)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
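
A minimal construction sketch, assuming DType is importable from max.dtype and using illustrative Llama-style values for the required fields:

```python
from max.dtype import DType
from max.nn.kv_cache.cache_params import KVCacheParams, KVCacheStrategy

# Illustrative values: 8 KV heads of dimension 128 across 32 layers,
# stored as bfloat16 with the default paged strategy.
params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    cache_strategy=KVCacheStrategy.PAGED,
    page_size=128,
    n_devices=1,
)
```

Later examples on this page reuse this params instance.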

bytes_per_block

property bytes_per_block: int

Returns the number of bytes per cache block.

When the tensor parallel degree is greater than 1, each block is sharded across the devices in the tensor parallel group. This property returns the total memory needed to store one block across those devices.

Returns:

The number of bytes per cache block.

cache_strategy

cache_strategy: KVCacheStrategy = 'paged'

Strategy to use for managing the KV cache.

compute_max_seq_len_fitting_in_cache()

compute_max_seq_len_fitting_in_cache(available_cache_memory)

Computes the maximum sequence length that can fit in the available cache memory.

Parameters:

available_cache_memory (int) – The amount of cache memory available across all devices.

Returns:

The maximum sequence length that can fit in the available cache memory.

Return type:

int
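
Illustrative usage, reusing the params sketch above with an arbitrary memory budget:

```python
available = 8 * 1024**3  # 8 GiB of cache memory across all devices (illustrative)
max_len = params.compute_max_seq_len_fitting_in_cache(available)
print(f"longest sequence that fits: {max_len} tokens")
```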

compute_num_device_blocks()

compute_num_device_blocks(available_cache_memory, max_batch_size, max_seq_len)

Computes the number of blocks that can be allocated based on the available cache memory.

The number of blocks returned is for a single replica. Each replica will have the same number of blocks.

Parameters:

  • available_cache_memory (int) – The amount of cache memory available across all devices.
  • max_batch_size (int | None) – The maximum batch size, or None.
  • max_seq_len (int | None) – The maximum sequence length, or None.

Returns:

The number of blocks that can be allocated for a single replica.

Return type:

int
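
A hedged usage sketch with illustrative limits (either limit may be None, per the signature above):

```python
n_blocks = params.compute_num_device_blocks(
    available_cache_memory=8 * 1024**3,  # illustrative budget (assumed bytes)
    max_batch_size=32,
    max_seq_len=4096,
)
print(f"blocks per replica: {n_blocks}")
```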

compute_num_host_blocks()

compute_num_host_blocks()

Computes the number of blocks that can be allocated to the host.

Returns:

The number of blocks that can be allocated to the host.

Return type:

int
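
A sketch pairing the host-swapping configuration with the host block count; the field values are illustrative:

```python
swap_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    enable_kvcache_swapping_to_host=True,
    host_kvcache_swap_space_gb=16.0,  # required when swapping is enabled
)
n_host_blocks = swap_params.compute_num_host_blocks()
```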

copy_as_dp_1()

copy_as_dp_1()

Creates a copy of the KVCacheParams with data parallelism disabled.

This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Return type:

KVCacheParams
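
A sketch of the data-parallel collapse described above, using an illustrative fully data-parallel configuration:

```python
dp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    n_devices=4,
    data_parallel_degree=4,
)

# Per the description above: n_devices is divided by the data parallel
# degree (4 // 4 = 1) and data_parallel_degree becomes 1.
tp_only = dp_params.copy_as_dp_1()
```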

data_parallel_degree

data_parallel_degree: int = 1

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).

dtype

dtype: DType

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

Returns a shorthand textual representation of the data type.

Returns:

“bf16” if the dtype is bfloat16, “f32” otherwise.

enable_kvcache_swapping_to_host

enable_kvcache_swapping_to_host: bool = False

Whether to enable swapping of KV cache blocks to host memory when device memory is full.

enable_prefix_caching

enable_prefix_caching: bool = False

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

estimated_memory_size()

estimated_memory_size(available_cache_memory, max_batch_size, max_seq_len)

Computes the estimated memory size of the KV cache used by all replicas.

Parameters:

  • available_cache_memory (int) – The amount of cache memory available across all devices.
  • max_batch_size (int) – The maximum batch size.
  • max_seq_len (int) – The maximum sequence length.

Returns:

The estimated memory usage of the KV cache in bytes.

Return type:

int
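
Illustrative usage, reusing the params sketch from the top of the page:

```python
est_bytes = params.estimated_memory_size(
    available_cache_memory=8 * 1024**3,
    max_batch_size=32,
    max_seq_len=4096,
)
print(f"estimated KV cache footprint: {est_bytes / 1024**3:.2f} GiB")
```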

head_dim

head_dim: int

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.

is_mla

is_mla: bool = False

Whether the model uses Multi-Latent Attention (MLA) architecture.

n_devices

n_devices: int = 1

Total number of devices (GPUs/accelerators) available for inference.

n_kv_heads

n_kv_heads: int

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

Number of KV heads allocated to each device. Computed automatically in __post_init__.

num_layers

num_layers: int

Number of layers in the model.

page_size

page_size: int = 128

Number of tokens per page (block) when using the paged cache strategy.

This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.

Current constraints: the page size must be a multiple of 128 and at least 128. Required when cache_strategy is KVCacheStrategy.PAGED.
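
A quick stand-alone check of the constraint above; is_valid_page_size is a hypothetical helper, not part of the API:

```python
def is_valid_page_size(page_size: int) -> bool:
    # Hypothetical helper: the page size must be a multiple of 128
    # and at least 128, per the current constraints.
    return page_size >= 128 and page_size % 128 == 0

assert is_valid_page_size(128) and is_valid_page_size(256)
assert not is_valid_page_size(64)
```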

shape_per_block

property shape_per_block: list[int]

Returns the shape of each cache block.

Returns:

The shape of the cache block.
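
A sketch relating the two block-geometry properties, reusing the params instance from the construction example (the exact shape layout is implementation-defined and not documented here):

```python
shape = params.shape_per_block   # per-block tensor shape, as a list of ints
nbytes = params.bytes_per_block  # total bytes for one block across the TP group
print(shape, nbytes)
```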

tensor_parallel_degree

property tensor_parallel_degree: int

Returns the tensor parallel degree.

Returns:

The tensor parallel degree.
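
The copy_as_dp_1() description above implies the relation tensor_parallel_degree == n_devices // data_parallel_degree; a hedged sketch of that assumption:

```python
tp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    n_devices=8,
    data_parallel_degree=1,  # all 8 devices form one tensor-parallel group
)
# Assumed behavior, inferred from copy_as_dp_1(): 8 // 1 == 8.
assert tp_params.tensor_parallel_degree == 8
```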

KVCacheStrategy

class max.nn.kv_cache.cache_params.KVCacheStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Enumeration of supported KV cache strategies for attention mechanisms.

This enum defines the different strategies for managing key-value caches in transformer models during inference.

MODEL_DEFAULT

MODEL_DEFAULT = 'model_default'

Use the model’s default caching strategy.

PAGED

PAGED = 'paged'

Use paged attention for efficient memory management.

kernel_substring()

kernel_substring()

Returns the common substring included in the kernel name for this caching strategy.

Returns:

The string representation of the cache strategy value.

Return type:

str

uses_opaque()

uses_opaque()

Determines if this cache strategy uses opaque cache implementations.

Returns:

True if the strategy uses opaque caching, False otherwise.

Return type:

bool
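
A short sketch exercising the enum and its helpers; the printed values follow from the string values documented above:

```python
strategy = KVCacheStrategy.PAGED
print(strategy.kernel_substring())  # "paged": the strategy's string value
print(strategy.uses_opaque())       # whether this strategy uses opaque caching
```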
