Python module
cache_params
KVCacheParamInterface
class max.nn.legacy.kv_cache.cache_params.KVCacheParamInterface(*args, **kwargs)
Interface for KV cache parameters.
bytes_per_block
property bytes_per_block: int
Number of bytes per cache block.
cache_strategy
cache_strategy: KVCacheStrategy
data_parallel_degree
data_parallel_degree: int
get_symbolic_inputs()
get_symbolic_inputs()
Returns the symbolic inputs for the KV cache.
Return type:
InputSymbolInterface
n_devices
n_devices: int
page_size
page_size: int
KVCacheParams
class max.nn.legacy.kv_cache.cache_params.KVCacheParams(dtype, n_kv_heads, head_dim, num_layers, devices, enable_prefix_caching=False, enable_kvcache_swapping_to_host=False, host_kvcache_swap_space_gb=None, cache_strategy=KVCacheStrategy.PAGED, page_size=128, is_mla=False, data_parallel_degree=1, n_kv_heads_per_device=0, kvcache_quant_config=None)
Configuration parameters for key-value cache management in transformer models.
This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.
Parameters:
- dtype (DType)
- n_kv_heads (int)
- head_dim (int)
- num_layers (int)
- devices (Sequence[DeviceRef])
- enable_prefix_caching (bool)
- enable_kvcache_swapping_to_host (bool)
- host_kvcache_swap_space_gb (float | None)
- cache_strategy (KVCacheStrategy)
- page_size (int)
- is_mla (bool)
- data_parallel_degree (int)
- n_kv_heads_per_device (int)
- kvcache_quant_config (KVCacheQuantizationConfig | None)
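For orientation, a minimal construction sketch. The DType and DeviceRef import paths are assumptions based on the broader MAX API and may differ in your installation:

from max.dtype import DType      # assumed import path
from max.graph import DeviceRef  # assumed import path
from max.nn.legacy.kv_cache.cache_params import KVCacheParams, KVCacheStrategy

params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,                  # total KV heads across all devices
    head_dim=128,
    num_layers=32,
    devices=[DeviceRef.GPU(0)],
    cache_strategy=KVCacheStrategy.PAGED,
    page_size=128,                 # tokens per block; must be a multiple of 128
)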
bytes_per_block
property bytes_per_block: int
Returns the number of bytes per cache block.
When TP>1, each block is sharded across the devices in the tensor parallel group. This method returns the total memory needed to store a block across these devices. Includes memory needed for scales if quantization is enabled.
Returns:
The number of bytes per cache block.
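For intuition, a back-of-the-envelope sketch of the per-block footprint of an unquantized cache. This mirrors standard KV cache arithmetic and is not the exact internal formula:

def approx_bytes_per_block(page_size, n_kv_heads, head_dim, num_layers, dtype_bytes):
    # Two tensors (key and value) per layer, page_size tokens per block,
    # n_kv_heads heads of head_dim elements each, dtype_bytes bytes per element.
    return 2 * num_layers * page_size * n_kv_heads * head_dim * dtype_bytes

approx_bytes_per_block(128, 8, 128, 32, 2)  # bfloat16: 16,777,216 bytes (16 MiB)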
cache_strategy
cache_strategy: KVCacheStrategy = 'paged'
Strategy to use for managing the KV cache.
compute_num_host_blocks()
compute_num_host_blocks()
Computes the number of blocks that can be allocated to the host.
Returns:
The number of blocks that can be allocated to the host.
Return type:
int
copy_as_dp_1()
copy_as_dp_1()
Creates a copy of the KVCacheParams with data parallelism disabled.
This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.
Returns:
A new KVCacheParams instance with data_parallel_degree set to 1.
Raises:
ValueError – If n_devices is not evenly divisible by data_parallel_degree.
Return type:
KVCacheParams
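A usage sketch under the same assumed imports as the construction example above, with four GPUs and two data-parallel replicas:

dp_params = KVCacheParams(
    dtype=DType.bfloat16, n_kv_heads=8, head_dim=128, num_layers=32,
    devices=[DeviceRef.GPU(i) for i in range(4)],
    data_parallel_degree=2,        # two replicas, each a 2-way TP group
)
tp_only = dp_params.copy_as_dp_1()
assert tp_only.data_parallel_degree == 1
assert tp_only.n_devices == 2      # 4 devices / DP degree of 2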
data_parallel_degree
data_parallel_degree: int = 1
Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).
devices
devices: Sequence[DeviceRef]
Devices to use for the KV cache.
dtype
dtype: DType
Data type for storing key and value tensors in the cache.
dtype_shorthand
property dtype_shorthand: str
Returns a shorthand textual representation of the data type.
Returns:
“bf16” for bfloat16 dtype, “f32” otherwise.
enable_kvcache_swapping_to_host
enable_kvcache_swapping_to_host: bool = False
Whether to enable swapping of KV cache blocks to host memory when device memory is full.
enable_prefix_caching
enable_prefix_caching: bool = False
Whether to enable prefix caching for efficient reuse of common prompt prefixes.
get_symbolic_inputs()
get_symbolic_inputs()
Computes the symbolic inputs for the KV cache.
This method returns a list of PagedCacheInputSymbols, one for each replica. It is used when constructing the model graph.
Returns:
The symbolic inputs for the KV cache.
Return type:
PagedCacheInputSymbolsByReplica
head_dim
head_dim: int
Dimensionality of each attention head.
host_kvcache_swap_space_gb
host_kvcache_swap_space_gb: float | None = None
Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.
is_mla
is_mla: bool = False
Whether the model uses Multi-Latent Attention (MLA) architecture.
kvcache_quant_config
kvcache_quant_config: KVCacheQuantizationConfig | None = None
KVCache quantization config. Currently, only FP8 quantization is supported.
n_devices
property n_devices: int
Returns the number of devices.
Returns:
The number of devices.
n_kv_heads
n_kv_heads: int
Total number of key-value attention heads across all devices.
n_kv_heads_per_device
n_kv_heads_per_device: int = 0
Number of KV heads allocated to each device. Computed automatically in __post_init__.
num_layers
num_layers: int
Number of layers in the model.
page_size
page_size: int = 128
Number of tokens per page (block) when using the paged cache strategy.
This value is expressed in tokens, not bytes. The byte footprint of a page is derived from pipeline configuration.
Current constraints: the page size must be a multiple of 128 and at least 128.
Required when cache_strategy is KVCacheStrategy.PAGED.
quantized_kv_cache
property quantized_kv_cache: bool
Returns whether KV cache quantization is enabled.
shape_per_block
Returns the shape of each cache block.
Returns:
The shape of the cache block.
shape_per_scale_block
Returns the shape of each scale block used for KVCache quantization.
Returns:
The shape of the KVCache quantization scales block.
tensor_parallel_degree
property tensor_parallel_degree: int
Returns the tensor parallel degree.
Returns:
The tensor parallel degree.
KVCacheQuantizationConfig
class max.nn.legacy.kv_cache.cache_params.KVCacheQuantizationConfig(scale_dtype=float32, quantization_granularity=128)
Configuration for KVCache quantization.
Currently, only FP8 quantization is supported.
quantization_granularity
quantization_granularity: int = 128
Block-size used for KVCache quantization along head-dimension (e.g. 128).
scale_dtype
scale_dtype: DType = float32
Data type of quantization scales, if quantization is enabled.
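A minimal sketch of building this config; the DType import path is an assumption:

from max.dtype import DType  # assumed import path
from max.nn.legacy.kv_cache.cache_params import KVCacheQuantizationConfig

quant = KVCacheQuantizationConfig(
    scale_dtype=DType.float32,     # dtype of the quantization scales
    quantization_granularity=128,  # scale-block size along the head dimension
)
# Pass to KVCacheParams via kvcache_quant_config=quant to enable FP8 KV caching.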
KVCacheStrategy
class max.nn.legacy.kv_cache.cache_params.KVCacheStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of supported KV cache strategies for attention mechanisms.
This enum defines the different strategies for managing key-value caches in transformer models during inference.
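A small usage sketch:

from max.nn.legacy.kv_cache.cache_params import KVCacheStrategy

strategy = KVCacheStrategy.PAGED
strategy.value               # 'paged'
strategy.kernel_substring()  # the strategy's substring in kernel names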
MODEL_DEFAULT
MODEL_DEFAULT = 'model_default'
Use the model’s default caching strategy.
PAGED
PAGED = 'paged'
Use paged attention for efficient memory management.
kernel_substring()
kernel_substring()
Returns the common substring included in the kernel name for this caching strategy.
Returns:
The string representation of the cache strategy value.
Return type:
str
uses_opaque()
uses_opaque()
Determines if this cache strategy uses opaque cache implementations.
Returns:
True if the strategy uses opaque caching, False otherwise.
Return type:
bool
MultiKVCacheParams
class max.nn.legacy.kv_cache.cache_params.MultiKVCacheParams(params, cache_strategy, page_size, data_parallel_degree, n_devices)
Aggregates multiple KV cache parameter sets.
This class implements KVCacheParamInterface by aggregating multiple KVCacheParamInterface instances. Useful for models with multiple distinct KV caches (e.g., different cache configurations for different layers).
Parameters:
- params (Sequence[KVCacheParamInterface])
- cache_strategy (KVCacheStrategy)
- page_size (int)
- data_parallel_degree (int)
- n_devices (int)
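A sketch of aggregation; params_a and params_b are hypothetical KVCacheParams instances built as in the earlier construction example:

from max.nn.legacy.kv_cache.cache_params import MultiKVCacheParams

multi = MultiKVCacheParams.from_params(params_a, params_b)
multi.bytes_per_block  # sum of bytes_per_block across both parameter sets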
bytes_per_block
property bytes_per_block: int
Total bytes per block across all KV caches.
Since all caches allocate memory for the same sequence, the total memory cost per block is the sum across all param sets.
cache_strategy
cache_strategy: KVCacheStrategy
data_parallel_degree
data_parallel_degree: int
from_params()
classmethod from_params(*params)
Creates a MultiKVCacheParams aggregating the given parameter sets.
Parameters:
params (KVCacheParamInterface)
Return type:
MultiKVCacheParams
get_symbolic_inputs()
get_symbolic_inputs()
Returns the symbolic inputs for the KV cache.
Return type:
MultiKVCacheInputSymbols
n_devices
n_devices: int
page_size
page_size: int
params
params: Sequence[KVCacheParamInterface]
List of KV cache parameter sets to aggregate.
compute_max_seq_len_fitting_in_cache()
max.nn.legacy.kv_cache.cache_params.compute_max_seq_len_fitting_in_cache(params, available_cache_memory)
Computes the maximum sequence length that can fit in the available memory.
Parameters:
- available_cache_memory (int) – The amount of cache memory available across all devices.
- params (KVCacheParamInterface)
Returns:
The maximum sequence length that can fit in the available cache memory.
Return type:
int
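A call sketch, with params as in the earlier hypothetical construction and an 8 GiB cache budget:

from max.nn.legacy.kv_cache.cache_params import compute_max_seq_len_fitting_in_cache

budget = 8 * 1024**3  # 8 GiB across all devices
max_len = compute_max_seq_len_fitting_in_cache(params, budget)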
compute_num_device_blocks()
max.nn.legacy.kv_cache.cache_params.compute_num_device_blocks(params, available_cache_memory, max_batch_size, max_seq_len)
Computes the number of blocks that can be allocated based on the available cache memory.
The number of blocks returned is for a single replica. Each replica will have the same number of blocks.
Parameters:
- available_cache_memory (int) – The amount of cache memory available across all devices.
- max_batch_size (int | None) – The maximum batch size, or None.
- max_seq_len (int | None) – The maximum sequence length, or None.
- params (KVCacheParamInterface)
Returns:
The number of blocks that can be allocated for a single replica.
Return type:
int
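A call sketch under the same assumptions; passing None for a limit leaves it unconstrained:

from max.nn.legacy.kv_cache.cache_params import compute_num_device_blocks

num_blocks = compute_num_device_blocks(
    params, 8 * 1024**3, max_batch_size=32, max_seq_len=4096,
)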
estimated_memory_size()
max.nn.legacy.kv_cache.cache_params.estimated_memory_size(params, available_cache_memory, max_batch_size, max_seq_len)
Computes the estimated memory size of the KV cache used by all replicas.
Parameters:
- available_cache_memory (int) – The amount of cache memory available across all devices.
- max_batch_size (int) – The maximum batch size.
- max_seq_len (int) – The maximum sequence length.
- params (KVCacheParamInterface)
Returns:
The estimated memory usage of the KV cache in bytes.
Return type:
int
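A call sketch under the same assumptions:

from max.nn.legacy.kv_cache.cache_params import estimated_memory_size

est_bytes = estimated_memory_size(
    params, 8 * 1024**3, max_batch_size=32, max_seq_len=4096,
)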