Python module

hf

ContinuousHFStaticCache

class max.pipelines.kv_cache.hf.ContinuousHFStaticCache(config: PretrainedConfig, max_batch_size: int, max_seq_len: int, device: device, dtype: dtype = torch.float32, layer_device_map: dict[int, Union[str, torch.device, int]] | None = None)
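The constructor sizes the cache from max_batch_size and max_seq_len, so its memory cost can be estimated before you build it. This estimate assumes the class preallocates one key tensor and one value tensor per layer, each shaped (max_batch_size, num_kv_heads, max_seq_len, head_dim), the way Hugging Face's StaticCache does. The helper static_kv_cache_bytes is hypothetical and not part of this module. Taking num_kv_heads and head_dim from the PretrainedConfig (for example num_key_value_heads, and hidden_size // num_attention_heads) is also an assumption.

```python
def static_kv_cache_bytes(
    num_layers: int,
    max_batch_size: int,
    num_kv_heads: int,
    max_seq_len: int,
    head_dim: int,
    dtype_bytes: int,
) -> int:
    """Hypothetical estimate of a preallocated static KV cache's size in bytes.

    Assumes one key and one value buffer per layer, each of shape
    (max_batch_size, num_kv_heads, max_seq_len, head_dim).
    """
    per_buffer = max_batch_size * num_kv_heads * max_seq_len * head_dim * dtype_bytes
    # Factor of 2: one key buffer and one value buffer per layer.
    return 2 * num_layers * per_buffer


# Example: 32 layers, batch of 4, 8 KV heads, 4096 positions, head_dim 128,
# with the default dtype torch.float32 (4 bytes per element).
total = static_kv_cache_bytes(32, 4, 8, 4096, 128, 4)
print(total / 1024**3, "GiB")
```

Because the buffers are sized for the worst case up front, halving max_seq_len or max_batch_size halves this figure. Switching dtype to a 2-byte type such as torch.bfloat16 has the same effect.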

external_claim()

external_claim(seq_ids: list[int]) → None

get_attention_mask()

get_attention_mask(seq_ids: list[int]) → Tensor

release()

release(seq_id: int) → None

reset()

reset() → None

Resets the cache values in place while preserving the underlying tensor objects (the buffers are not reallocated).
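An in-place reset keeps each tensor's identity and storage address, so any other code holding a reference to the buffers still sees the same memory. This is presumably why reset() preserves the objects, since a compiled or CUDA-graph-captured model can depend on stable buffer addresses. The sketch assumes the reset zeroes each per-layer buffer with Tensor.zero_(). The buffer shapes and the reset_in_place helper are hypothetical and do not come from the module.

```python
import torch

# Hypothetical per-layer buffers standing in for a static KV cache:
# shape (max_batch_size, num_kv_heads, max_seq_len, head_dim), 3 layers.
key_cache = [torch.randn(2, 4, 16, 8) for _ in range(3)]
value_cache = [torch.randn(2, 4, 16, 8) for _ in range(3)]


def reset_in_place(key_cache: list, value_cache: list) -> None:
    """Zero every buffer in place; no new tensors are allocated."""
    for k, v in zip(key_cache, value_cache):
        k.zero_()
        v.zero_()


# Record identity and storage addresses before the reset.
ids_before = [id(t) for t in key_cache + value_cache]
ptrs_before = [t.data_ptr() for t in key_cache + value_cache]

reset_in_place(key_cache, value_cache)
```

Rebinding instead, for example with key_cache[i] = torch.zeros_like(key_cache[i]), would also clear the values. However, it creates new tensors at new addresses, and any previously captured references would go stale.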

set_active_slots()

set_active_slots(seq_ids: list[int]) → None
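external_claim, set_active_slots, and release carry no descriptions on this page. Their names suggest a continuous-batching scheme: each sequence ID is bound to one of max_batch_size fixed cache slots, the slots used in the current step are selected, and a slot is freed when its sequence finishes. The toy model below illustrates that assumed lifecycle in plain Python. ToySlotTable and its methods are hypothetical and do not describe the actual implementation or its error behavior.

```python
class ToySlotTable:
    """Hypothetical mapping of sequence IDs to fixed cache slots."""

    def __init__(self, max_batch_size: int) -> None:
        self.free = list(range(max_batch_size))  # unclaimed slot indices
        self.slot_of: dict[int, int] = {}        # seq_id -> slot index
        self.active: list[int] = []              # slots used this step

    def claim(self, seq_ids: list[int]) -> None:
        # Validate everything first so a failed claim allocates nothing.
        if len(set(seq_ids)) != len(seq_ids):
            raise ValueError("duplicate seq_ids in one claim")
        if any(sid in self.slot_of for sid in seq_ids):
            raise ValueError("a seq_id already holds a slot")
        if len(seq_ids) > len(self.free):
            raise RuntimeError("not enough free cache slots")
        for sid in seq_ids:
            self.slot_of[sid] = self.free.pop(0)

    def set_active(self, seq_ids: list[int]) -> None:
        # Translate the batch's sequence IDs into the slots to read and write.
        self.active = [self.slot_of[sid] for sid in seq_ids]

    def release(self, seq_id: int) -> None:
        # Return the slot to the free pool so a new sequence can reuse it.
        self.free.append(self.slot_of.pop(seq_id))


table = ToySlotTable(max_batch_size=2)
table.claim([10, 11])   # two sequences join the batch
table.set_active([11])  # only sequence 11 runs this step
table.release(10)       # sequence 10 finishes; its slot is freed
table.claim([12])       # a new sequence reuses the freed slot
```

Because the slots are fixed rows of preallocated buffers, sequences can join and leave between steps without resizing the cache. That is the usual motivation for pairing a static cache with continuous batching.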

set_cache_position()

set_cache_position(cache_position: Tensor)

update()

update(key_states: Tensor, value_states: Tensor, layer_idx: int, cache_kwargs: dict[str, Any] | None = None) → tuple[torch.Tensor, torch.Tensor]

Updates the cache with the new key_states and value_states for layer layer_idx. Index the cache with a tensor (such as cache_position), not a Python integer or list; otherwise, the index is copied to the device on every call.

  • Parameters:

    • key_states (torch.Tensor) – The new key states to cache.
    • value_states (torch.Tensor) – The new value states to cache.
    • layer_idx (int) – The index of the layer to cache the states for.
    • cache_kwargs (Dict[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache_position input to know where to write in the cache.
  • Returns:

    A tuple containing the updated key and value states.
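A static cache writes new states into preallocated buffers at the positions given by cache_position, which is why the index should already be a tensor on the device. The sketch below mirrors the write pattern used by Hugging Face's StaticCache, which calls Tensor.index_copy_ along the sequence dimension. Whether ContinuousHFStaticCache.update does exactly this is an assumption, and static_cache_update is a hypothetical helper.

```python
import torch


def static_cache_update(
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cache_position: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Write new states into preallocated buffers in place (hypothetical helper).

    k_out, v_out: (batch, num_kv_heads, max_seq_len, head_dim)
    key_states, value_states: (batch, num_kv_heads, new_tokens, head_dim)
    cache_position: 1-D int64 tensor of length new_tokens, on the buffers' device
    """
    # index_copy_ along dim 2 (sequence) writes in place. The index is a tensor
    # that already lives on the device, so no host-to-device copy is needed.
    k_out.index_copy_(2, cache_position, key_states)
    v_out.index_copy_(2, cache_position, value_states)
    return k_out, v_out


# Preallocated buffers: batch 1, 2 KV heads, 8 positions, head_dim 4.
k_cache = torch.zeros(1, 2, 8, 4)
v_cache = torch.zeros(1, 2, 8, 4)

# Prefill: 3 prompt tokens at positions 0, 1, 2.
prefill_k, prefill_v = torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)
static_cache_update(k_cache, v_cache, prefill_k, prefill_v, torch.arange(3))

# Decode: 1 new token at position 3.
decode_k, decode_v = torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)
static_cache_update(k_cache, v_cache, decode_k, decode_v, torch.tensor([3]))
```

The returned tensors are the full buffers, including positions not yet written, so attention over them needs a mask that hides unused positions.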

update_attention_pattern()

update_attention_pattern(seq_id: int, attention_mask: Tensor) → None