
Python module

hf

ContinuousHFStaticCache

class max.nn.kv_cache.hf.ContinuousHFStaticCache(config, max_batch_size, max_seq_len, device, dtype=torch.float32, layer_device_map=None)

Parameters:

  • config (PretrainedConfig)
  • max_batch_size (int)
  • max_seq_len (int)
  • device (torch.device)
  • dtype (torch.dtype) – Defaults to torch.float32.
  • layer_device_map (dict[int, str | device | int] | None)
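If ContinuousHFStaticCache allocates its key and value storage up front, as Hugging Face's StaticCache does, then the constructor arguments fix the memory footprint. The sketch below estimates that footprint. It assumes one key buffer and one value buffer per layer, each shaped (max_batch_size, num_key_value_heads, max_seq_len, head_dim). In practice num_layers, num_key_value_heads and head_dim would come from config. The helper static_kv_cache_bytes and the model dimensions in the example are hypothetical.

```python
def static_kv_cache_bytes(num_layers, max_batch_size, num_kv_heads,
                          max_seq_len, head_dim, bytes_per_element):
    """Estimate the bytes held by preallocated key/value buffers.

    Assumption: one key and one value buffer per layer, each shaped
    (max_batch_size, num_kv_heads, max_seq_len, head_dim), as in
    Hugging Face's StaticCache. Not taken from the MAX source.
    """
    elements_per_buffer = max_batch_size * num_kv_heads * max_seq_len * head_dim
    # Factor of 2: each layer holds a key buffer and a value buffer.
    return 2 * num_layers * elements_per_buffer * bytes_per_element


# Hypothetical model: 32 layers, 8 KV heads, head_dim 128, float32 (4 bytes).
total = static_kv_cache_bytes(num_layers=32, max_batch_size=4, num_kv_heads=8,
                              max_seq_len=2048, head_dim=128, bytes_per_element=4)
print(total / 2**30, "GiB")  # → 2.0 GiB
```

Because the buffers are static, the footprint scales linearly with both max_batch_size and max_seq_len, whether or not the slots are in use.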

external_claim()

external_claim(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

None
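The signature suggests that the caller chooses the sequence IDs and that each ID occupies one of the max_batch_size batch slots. The toy pool below illustrates only that reading. The class ToySlotPool, its free set, and its all-or-nothing validation are hypothetical and are not the MAX implementation.

```python
class ToySlotPool:
    """Hypothetical stand-in for batch-slot bookkeeping; not the MAX implementation."""

    def __init__(self, max_batch_size):
        # Every batch slot starts out free.
        self.free = set(range(max_batch_size))

    def external_claim(self, seq_ids):
        # The caller picks the IDs; reject duplicates and IDs that are taken or out of range.
        if len(set(seq_ids)) != len(seq_ids):
            raise ValueError("duplicate seq_ids")
        unavailable = [s for s in seq_ids if s not in self.free]
        if unavailable:
            raise ValueError(f"slots not available: {unavailable}")
        # Validation passed, so claim all IDs at once.
        self.free.difference_update(seq_ids)


pool = ToySlotPool(max_batch_size=4)
pool.external_claim([0, 2])
print(sorted(pool.free))  # → [1, 3]
```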

get_attention_mask()

get_attention_mask(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

Tensor
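One plausible reading is that get_attention_mask assembles the stored per-sequence attention patterns into a single batch mask, with one row per entry of seq_ids in the same order. The sketch models masks as lists of 0/1 values over max_seq_len positions. The function and the patterns dictionary are hypothetical.

```python
def get_attention_mask(patterns, seq_ids):
    """Stack stored per-sequence masks into a batch mask (hypothetical toy)."""
    # One row per requested sequence, in the order given; rows are copied.
    return [list(patterns[seq_id]) for seq_id in seq_ids]


patterns = {0: [1, 1, 0, 0], 2: [1, 1, 1, 0]}  # max_seq_len = 4
mask = get_attention_mask(patterns, [2, 0])
print(mask)  # → [[1, 1, 1, 0], [1, 1, 0, 0]]
```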

release()

release(seq_id)

Parameters:

seq_id (int)

Return type:

None
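release presumably returns a sequence's slot to the pool so that a later claim can reuse it. The toy below also assumes that releasing clears the slot's stored values in place. ToySlotCache, claim, and this clearing behavior are hypothetical stand-ins, not documented MAX behavior.

```python
class ToySlotCache:
    """Hypothetical slot bookkeeping with per-slot storage; not the MAX implementation."""

    def __init__(self, max_batch_size, max_seq_len):
        self.free = set(range(max_batch_size))
        # One row of cached values per slot, standing in for per-slot K/V storage.
        self.rows = [[0] * max_seq_len for _ in range(max_batch_size)]

    def claim(self, seq_id):
        if seq_id not in self.free:
            raise ValueError(f"slot {seq_id} is not available")
        self.free.remove(seq_id)

    def release(self, seq_id):
        if not 0 <= seq_id < len(self.rows) or seq_id in self.free:
            raise ValueError(f"slot {seq_id} is not claimed")
        # Assumption: clear the row in place so stale values never reach the next sequence.
        row = self.rows[seq_id]
        row[:] = [0] * len(row)
        self.free.add(seq_id)


cache = ToySlotCache(max_batch_size=2, max_seq_len=3)
cache.claim(1)
cache.rows[1][0] = 7      # pretend a token was cached
cache.release(1)
print(cache.rows[1], sorted(cache.free))  # → [0, 0, 0] [0, 1]
```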

reset()

reset()

Resets the cache values in place while preserving the cache objects themselves.

Return type:

None
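"Preserving the objects" means the existing buffers are cleared in place rather than replaced with new ones. Hugging Face's StaticCache does this by calling tensor.zero_() on each layer's key and value tensors, which keeps the tensors' memory addresses stable, a property that compiled graphs can rely on. Whether MAX does exactly the same is an assumption. The pure-Python sketch below shows the same idea with lists: the values are reset and object identity is kept. The function name reset_in_place is hypothetical.

```python
def reset_in_place(buffers):
    """Zero every buffer in place, keeping the buffer objects themselves."""
    for buf in buffers:
        buf[:] = [0.0] * len(buf)   # slice assignment mutates; it does not rebind


key_cache = [[0.5, 1.5], [2.0, 3.0]]   # toy stand-in for per-layer key tensors
layer0 = key_cache[0]
reset_in_place(key_cache)
print(key_cache, key_cache[0] is layer0)  # → [[0.0, 0.0], [0.0, 0.0]] True
```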

set_active_slots()

set_active_slots(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

None

set_cache_position()

set_cache_position(cache_position)

Parameters:

cache_position (Tensor)

Return type:

None
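In Hugging Face transformers, cache_position holds the absolute cache positions that the current chunk of tokens will be written to. It is typically torch.arange(past_len, past_len + num_new_tokens). Assuming set_cache_position expects the same convention, the helper below computes the positions for a prefill step and a following decode step. The name cache_positions is hypothetical, and it uses plain lists instead of tensors.

```python
def cache_positions(past_len, num_new_tokens, max_seq_len):
    """Absolute cache positions for the next chunk (hypothetical helper, HF-style convention)."""
    end = past_len + num_new_tokens
    if end > max_seq_len:
        raise ValueError("chunk would overflow the static cache")
    # Equivalent to torch.arange(past_len, end) in tensor code.
    return list(range(past_len, end))


prefill = cache_positions(past_len=0, num_new_tokens=5, max_seq_len=16)
decode = cache_positions(past_len=5, num_new_tokens=1, max_seq_len=16)
print(prefill, decode)  # → [0, 1, 2, 3, 4] [5]
```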

update()

update(key_states, value_states, layer_idx, cache_kwargs=None)

Updates the cache with the new key_states and value_states for layer layer_idx. Indexing into the cache must use a tensor (such as cache_position); indexing with a Python list instead introduces a copy to the device.

Parameters:

  • key_states (torch.Tensor) – The new key states to cache.
  • value_states (torch.Tensor) – The new value states to cache.
  • layer_idx (int) – The index of the layer to cache the states for.
  • cache_kwargs (Dict[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache_position input to know where to write in the cache.

Returns:

A tuple containing the updated key and value states.

Return type:

tuple[Tensor, Tensor]
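With a static cache, an update writes the new states into the preallocated buffers at each active slot and cache position, then returns the buffers. The toy below makes several assumptions. The slots come from set_active_slots and the positions from set_cache_position. Each position stores one scalar in place of a head_dim vector. Like Hugging Face's StaticCache.update, it returns the whole buffers. The function toy_update is hypothetical. With real tensors, the nested loop becomes a single indexed assignment, and that is where indexing with a device tensor rather than a Python list avoids the extra copy.

```python
def toy_update(key_cache, value_cache, key_states, value_states,
               active_slots, cache_position):
    """Write new keys/values for each active slot at the given positions (hypothetical toy).

    key_cache[slot][pos] holds one scalar standing in for a key vector.
    key_states[i][j] is the new key for active_slots[i] at cache_position[j].
    """
    for i, slot in enumerate(active_slots):
        for j, pos in enumerate(cache_position):
            key_cache[slot][pos] = key_states[i][j]
            value_cache[slot][pos] = value_states[i][j]
    # Like HF StaticCache.update, hand back the whole preallocated buffers.
    return key_cache, value_cache


K = [[0] * 4 for _ in range(3)]   # 3 slots x max_seq_len 4
V = [[0] * 4 for _ in range(3)]
K, V = toy_update(K, V, key_states=[[1, 2], [3, 4]], value_states=[[5, 6], [7, 8]],
                  active_slots=[2, 0], cache_position=[1, 2])
print(K)  # → [[0, 3, 4, 0], [0, 0, 0, 0], [0, 1, 2, 0]]
```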

update_attention_pattern()

update_attention_pattern(seq_id, attention_mask)

Parameters:

  • seq_id (int)
  • attention_mask (Tensor)

Return type:

None
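update_attention_pattern appears to store the attention mask for one sequence so that it can later be assembled into a batch mask. The sketch assumes masks are padded with zeros to a fixed width of max_seq_len. The function, the patterns dictionary, and the padding rule are hypothetical.

```python
def update_attention_pattern(patterns, seq_id, attention_mask, max_seq_len):
    """Store one sequence's 0/1 mask, zero-padded to max_seq_len (hypothetical toy)."""
    if len(attention_mask) > max_seq_len:
        raise ValueError("attention_mask is longer than max_seq_len")
    # Overwrite any earlier pattern for this sequence.
    patterns[seq_id] = list(attention_mask) + [0] * (max_seq_len - len(attention_mask))


patterns = {}
update_attention_pattern(patterns, seq_id=1, attention_mask=[1, 1, 1], max_seq_len=6)
print(patterns)  # → {1: [1, 1, 1, 0, 0, 0]}
```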