Python module

cache_params

KVCacheParamInterface

class max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface(*args, **kwargs)

Interface for KV cache parameters.

bytes_per_block

property bytes_per_block: int

Number of bytes per cache block.

cache_strategy

cache_strategy: KVCacheStrategy

data_parallel_degree

data_parallel_degree: int

get_symbolic_inputs()

get_symbolic_inputs()

Returns the symbolic inputs for the KV cache.

Return type:

InputSymbolInterface

n_devices

n_devices: int

page_size

page_size: int

KVCacheParams

class max.nn.legacy.kv_cache.cache_params.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, enable_kvcache_swapping_to_host=False, host_kvcache_swap_space_gb=None, cache_strategy=KVCacheStrategy.PAGED, page_size=128, is_mla=False, data_parallel_degree=1, n_kv_heads_per_device=0, kvcache_quant_config=None)

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.
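
For example, a minimal single-GPU paged-cache configuration might look like the following sketch. The DType and DeviceRef import paths are assumptions about the broader MAX API; only the KVCacheParams arguments are fixed by the signature above.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.kv_cache.cache_params import KVCacheParams, KVCacheStrategy

# A paged KV cache for a hypothetical 32-layer model with 8 KV heads of
# dimension 128, stored in bfloat16 on a single GPU.
params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0)],
    cache_strategy=KVCacheStrategy.PAGED,
    page_size=128,
)
print(params.bytes_per_block)  # bytes needed for one 128-token page
```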

Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • num_layers (int)
  • devices (Sequence[DeviceRef])
  • enable_prefix_caching (bool)
  • enable_kvcache_swapping_to_host (bool)
  • host_kvcache_swap_space_gb (float | None)
  • cache_strategy (KVCacheStrategy)
  • page_size (int)
  • is_mla (bool)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
  • kvcache_quant_config (KVCacheQuantizationConfig | None)

bytes_per_block

property bytes_per_block: int

Returns the number of bytes per cache block.

When the tensor parallel degree is greater than 1, each block is sharded across the devices in the tensor-parallel group. This property returns the total memory needed to store a block across those devices, including the memory needed for quantization scales when quantization is enabled.

Returns:

The number of bytes per cache block.
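
As a rough illustration only (the arithmetic below is an assumption about a typical paged-KV layout, not the library's exact formula, which also accounts for sharding and quantization scales):

```python
# Keys and values (factor 2) for every layer, one 128-token page.
num_layers, page_size, n_kv_heads, head_dim = 32, 128, 8, 128
dtype_bytes = 2  # bfloat16
approx_bytes = 2 * num_layers * page_size * n_kv_heads * head_dim * dtype_bytes
print(approx_bytes)  # 16_777_216 bytes, i.e. 16 MiB for this hypothetical model
```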

cache_strategy

cache_strategy: KVCacheStrategy = 'paged'

Strategy to use for managing the KV cache.

compute_num_host_blocks()

compute_num_host_blocks()

Computes the number of blocks that can be allocated to the host.

Returns:

The number of blocks that can be allocated to the host.

Return type:

int

copy_as_dp_1()

copy_as_dp_1()

Creates a copy of the KVCacheParams with data parallelism disabled.

This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Return type:

KVCacheParams
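
A sketch of converting a data-parallel configuration to its per-replica view (imports and model geometry as in the constructor example above; the two-device setup is hypothetical):

```python
# Two devices, each hosting one data-parallel replica
# (the docs above note DP must be 1 or equal to n_devices).
dp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    data_parallel_degree=2,
)
tp_params = dp_params.copy_as_dp_1()
assert tp_params.data_parallel_degree == 1
assert tp_params.n_devices == 1  # 2 devices divided by DP degree 2
```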

data_parallel_degree

data_parallel_degree: int = 1

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).

devices

devices: Sequence[DeviceRef]

Devices to use for the KV cache.

dtype

dtype: DType

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

Returns a shorthand textual representation of the data type.

Returns:

“bf16” for bfloat16 dtype, “f32” otherwise.

enable_kvcache_swapping_to_host

enable_kvcache_swapping_to_host: bool = False

Whether to enable swapping of KV cache blocks to host memory when device memory is full.

enable_prefix_caching

enable_prefix_caching: bool = False

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

get_symbolic_inputs()

get_symbolic_inputs()

Computes the symbolic inputs for the KV cache.

This method returns a list of PagedCacheInputSymbols, one per replica. It is used when constructing the model graph.

Returns:

The symbolic inputs for the KV cache.

Return type:

PagedCacheInputSymbolsByReplica

head_dim

head_dim: int

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.

is_mla

is_mla: bool = False

Whether the model uses Multi-Latent Attention (MLA) architecture.

kvcache_quant_config

kvcache_quant_config: KVCacheQuantizationConfig | None = None

KVCache quantization config. Currently only FP8 quantization is supported.

n_devices

property n_devices: int

Returns the number of devices.

Returns:

The number of devices.

n_kv_heads

n_kv_heads: int

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

Number of KV heads allocated to each device. Computed automatically in __post_init__.

num_layers

num_layers: int

Number of layers in the model.

page_size

page_size: int = 128

Number of tokens per page (block) when using the paged cache strategy.

This value is expressed in tokens, not bytes. The byte footprint of a page is derived from the pipeline configuration.

Current constraints: the page size must be a multiple of 128 and at least 128. Required when cache_strategy is KVCacheStrategy.PAGED.
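
A quick validity check of that constraint, as a sketch:

```python
def is_valid_page_size(page_size: int) -> bool:
    # At least 128 and a multiple of 128, per the constraint above.
    return page_size >= 128 and page_size % 128 == 0

assert is_valid_page_size(128) and is_valid_page_size(256)
assert not is_valid_page_size(64) and not is_valid_page_size(200)
```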

quantized_kv_cache

property quantized_kv_cache: bool

Whether KV cache quantization is enabled.

shape_per_block

property shape_per_block: list[int]

Returns the shape of each cache block.

Returns:

The shape of the cache block.

shape_per_scale_block

property shape_per_scale_block: list[int]

Returns the shape of each scale block used for KVCache quantization.

Returns:

The shape of the KVCache quantization scales block.

tensor_parallel_degree

property tensor_parallel_degree: int

Returns the tensor parallel degree.

Returns:

The tensor parallel degree.

KVCacheQuantizationConfig

class max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig(scale_dtype=float32, quantization_granularity=128)

Configuration for KVCache quantization.

Currently only FP8 quantization is supported.

Parameters:

  • scale_dtype (DType)
  • quantization_granularity (int)

quantization_granularity

quantization_granularity: int = 128

Block size used for KVCache quantization along the head dimension (e.g. 128).

scale_dtype

scale_dtype: DType = float32

Data type of the quantization scales, if quantization is enabled.
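
A sketch of enabling KV cache quantization (other arguments and the DType/DeviceRef imports as in the constructor example above; how the cache dtype interacts with the FP8 path is not specified here, so treat that choice as a placeholder):

```python
from max.nn.legacy.kv_cache.cache_params import (
    KVCacheParams,
    KVCacheQuantizationConfig,
)

# Defaults: float32 scales, one scale per 128 elements of the head dimension.
quant_config = KVCacheQuantizationConfig()

params = KVCacheParams(
    dtype=DType.bfloat16,  # placeholder; see caveat above
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0)],
    kvcache_quant_config=quant_config,
)
print(params.quantized_kv_cache)  # presumably True once a config is attached
```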

KVCacheStrategy

class max.nn.legacy.kv_cache.cache_params.KVCacheStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Enumeration of supported KV cache strategies for attention mechanisms.

This enum defines the different strategies for managing key-value caches in transformer models during inference.

MODEL_DEFAULT

MODEL_DEFAULT = 'model_default'

Use the model’s default caching strategy.

PAGED

PAGED = 'paged'

Use paged attention for efficient memory management.

kernel_substring()

kernel_substring()

Returns the common substring included in the kernel name for this caching strategy.

Returns:

The string representation of the cache strategy value.

Return type:

str

uses_opaque()

uses_opaque()

Determines if this cache strategy uses opaque cache implementations.

Returns:

True if the strategy uses opaque caching, False otherwise.

Return type:

bool
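
A small usage sketch:

```python
from max.nn.legacy.kv_cache.cache_params import KVCacheStrategy

strategy = KVCacheStrategy.PAGED
print(strategy.value)               # 'paged'
print(strategy.kernel_substring())  # substring embedded in kernel names
print(strategy.uses_opaque())       # whether opaque cache handles are used
```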

MultiKVCacheParams

class max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams(params, cache_strategy, page_size, data_parallel_degree, n_devices)

Aggregates multiple KV cache parameter sets.

This class implements KVCacheParamInterface by aggregating multiple KVCacheParamInterface instances. Useful for models with multiple distinct KV caches (e.g., different cache configurations for different layers).

Parameters:

  • params (Sequence[KVCacheParamInterface])
  • cache_strategy (KVCacheStrategy)
  • page_size (int)
  • data_parallel_degree (int)
  • n_devices (int)

bytes_per_block

property bytes_per_block: int

Total bytes per block across all KV caches.

Since all caches allocate memory for the same sequence, the total memory cost per block is the sum across all param sets.

cache_strategy

cache_strategy: KVCacheStrategy

data_parallel_degree

data_parallel_degree: int

from_params()

classmethod from_params(*params)

Constructs a MultiKVCacheParams aggregating the given parameter sets.

Parameters:

params (KVCacheParamInterface)

Return type:

MultiKVCacheParams
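
A sketch aggregating two hypothetical parameter sets, say for full-attention and sliding-window layer groups (DType and DeviceRef imports as in the earlier examples):

```python
from max.nn.legacy.kv_cache.cache_params import KVCacheParams, MultiKVCacheParams

def make_params(num_layers: int) -> KVCacheParams:
    # Hypothetical helper: same geometry, different layer counts.
    return KVCacheParams(
        dtype=DType.bfloat16,
        n_kv_heads=8,
        head_dim=128,
        num_layers=num_layers,
        devices=[DeviceRef.GPU(0)],
    )

full_params = make_params(num_layers=24)  # e.g. full-attention layers
swa_params = make_params(num_layers=8)    # e.g. sliding-window layers
multi = MultiKVCacheParams.from_params(full_params, swa_params)

# Per the bytes_per_block docs above, block cost sums across param sets.
assert multi.bytes_per_block == full_params.bytes_per_block + swa_params.bytes_per_block
```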

get_symbolic_inputs()

get_symbolic_inputs()

Returns the symbolic inputs for the KV cache.

Return type:

MultiKVCacheInputSymbols

n_devices

n_devices: int

page_size

page_size: int

params

params: Sequence[KVCacheParamInterface]

List of KV cache parameter sets to aggregate.

compute_max_seq_len_fitting_in_cache()

max.nn.legacy.kv_cache.cache_params.compute_max_seq_len_fitting_in_cache(params, available_cache_memory)

Computes the maximum sequence length that can fit in the available memory.

Parameters:

  • available_cache_memory (int) – The amount of cache memory available across all devices.
  • params (KVCacheParamInterface)

Returns:

The maximum sequence length that can fit in the available cache memory.

Return type:

int
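
A sketch of estimating the longest sequence that fits in a given budget (the configuration mirrors the constructor example; the 8 GiB figure is hypothetical):

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.legacy.kv_cache.cache_params import (
    KVCacheParams,
    compute_max_seq_len_fitting_in_cache,
)

params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0)],
)
budget_bytes = 8 * 1024**3  # 8 GiB available across all devices
print(compute_max_seq_len_fitting_in_cache(params, budget_bytes))
```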

compute_num_device_blocks()

max.nn.legacy.kv_cache.cache_params.compute_num_device_blocks(params, available_cache_memory, max_batch_size, max_seq_len)

Computes the number of blocks that can be allocated based on the available cache memory.

The number of blocks returned is for a single replica. Each replica will have the same number of blocks.

Parameters:

  • available_cache_memory (int) – The amount of cache memory available across all devices.
  • max_batch_size (int | None) – The maximum batch size, or None.
  • max_seq_len (int | None) – The maximum sequence length, or None.
  • params (KVCacheParamInterface)

Returns:

The number of blocks that can be allocated for a single replica.

Return type:

int
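
Continuing the sketch above, the per-replica block budget could be queried like this (the batch and length caps are hypothetical):

```python
from max.nn.legacy.kv_cache.cache_params import compute_num_device_blocks

# params and budget_bytes as defined in the previous sketch.
num_blocks = compute_num_device_blocks(
    params,
    available_cache_memory=budget_bytes,
    max_batch_size=16,
    max_seq_len=4096,
)
print(num_blocks)  # blocks allocated for a single replica
```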

estimated_memory_size()

max.nn.legacy.kv_cache.cache_params.estimated_memory_size(params, available_cache_memory, max_batch_size, max_seq_len)

Computes the estimated memory size of the KV cache used by all replicas.

Parameters:

  • available_cache_memory (int) – The amount of cache memory available across all devices.
  • max_batch_size (int) – The maximum batch size.
  • max_seq_len (int) – The maximum sequence length.
  • params (KVCacheParamInterface)

Returns:

The estimated memory usage of the KV cache in bytes.

Return type:

int
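
And the corresponding overall estimate, using the same hypothetical params and caps:

```python
from max.nn.legacy.kv_cache.cache_params import estimated_memory_size

estimated = estimated_memory_size(
    params,
    available_cache_memory=budget_bytes,
    max_batch_size=16,
    max_seq_len=4096,
)
print(f"{estimated / 1024**3:.1f} GiB")  # estimated KV cache footprint
```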
