Python module
hf
ContinuousHFStaticCache
class max.pipelines.kv_cache.hf.ContinuousHFStaticCache(config: PretrainedConfig, max_batch_size: int, max_seq_len: int, device: device, dtype: dtype = torch.float32, layer_device_map: dict[int, Union[str, torch.device, int]] | None = None)
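A minimal construction sketch, assuming a Hugging Face PretrainedConfig is available; the model name, batch size, sequence length, device, and dtype below are illustrative choices, not requirements:

```python
import torch
from transformers import AutoConfig
from max.pipelines.kv_cache.hf import ContinuousHFStaticCache

# Any PretrainedConfig with standard attention geometry should work;
# "gpt2" is used here only as an example.
config = AutoConfig.from_pretrained("gpt2")

cache = ContinuousHFStaticCache(
    config=config,
    max_batch_size=4,             # up to 4 sequences resident in the cache
    max_seq_len=2048,             # per-sequence cache capacity
    device=torch.device("cuda"),
    dtype=torch.float16,
)
```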
external_claim()
get_attention_mask()
release()
reset()
reset() → None
Resets the cache values while preserving the allocated cache objects.
set_active_slots()
set_cache_position()
set_cache_position(cache_position: Tensor)
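Continuing the sketch above, the write positions for a 7-token prompt might be set like this (the length is illustrative; note the positions are passed as a tensor, per the note on update() below):

```python
import torch

# Positions 0..6 will receive the next key/value writes.
cache.set_cache_position(torch.arange(7, device=torch.device("cuda")))
```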
update()
update(key_states: Tensor, value_states: Tensor, layer_idx: int, cache_kwargs: dict[str, Any] | None = None) → tuple[torch.Tensor, torch.Tensor]
Updates the cache with the new key_states and value_states for layer layer_idx. It is very important to index using a tensor; otherwise you introduce a copy to the device.
Parameters:
- key_states (torch.Tensor) – The new key states to cache.
- value_states (torch.Tensor) – The new value states to cache.
- layer_idx (int) – The index of the layer to cache the states for.
- cache_kwargs (Dict[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache_position input to know where to write in the cache.
Returns:
A tuple containing the updated key and value states.
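Continuing the sketch, writing one layer's projected key/value states could look like this; the shapes and layer index are illustrative, and cache_position is passed as a tensor per the note above:

```python
import torch

# Illustrative shapes: (batch, num_kv_heads, seq_len, head_dim),
# matching the gpt2 config used in the construction sketch.
batch, num_kv_heads, seq_len, head_dim = 1, 12, 7, 64
key_states = torch.randn(batch, num_kv_heads, seq_len, head_dim,
                         dtype=torch.float16, device="cuda")
value_states = torch.randn_like(key_states)

# Index with a tensor, not a Python int or slice, to avoid a device copy.
cache_position = torch.arange(seq_len, device=key_states.device)

key_cache, value_cache = cache.update(
    key_states,
    value_states,
    layer_idx=0,
    cache_kwargs={"cache_position": cache_position},
)
# key_cache and value_cache now hold the full cached states for layer 0.
```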
update_attention_pattern()
update_attention_pattern(seq_id: int, attention_mask: Tensor) → None
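A hedged sketch of marking the first seven positions of a sequence as attendable. This assumes the usual Hugging Face mask convention (1 = attend, 0 = masked), which the signature above does not spell out, and that the sequence ID has already been claimed (e.g. via external_claim(), whose signature is not documented here):

```python
import torch

# Assumption: seq_id 0 was previously claimed for this cache slot.
mask = torch.ones(7, dtype=torch.long, device="cuda")
cache.update_attention_pattern(seq_id=0, attention_mask=mask)
```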