Python class

KVCacheInputsPerDevice

class max.nn.kv_cache.KVCacheInputsPerDevice(kv_blocks, cache_lengths, lookup_table, max_lengths, kv_scales=None, attention_dispatch_metadata=None, draft_attention_dispatch_metadata=None)

Bases: Generic[_Tensor, _Buffer]

Symbolic graph input types for a single device’s paged KV cache.

Parameters:

  • kv_blocks (_Buffer)
  • cache_lengths (_Tensor)
  • lookup_table (_Tensor)
  • max_lengths (_Tensor)
  • kv_scales (_Buffer | None)
  • attention_dispatch_metadata (_Tensor | None)
  • draft_attention_dispatch_metadata (_Tensor | None)
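The parameter list makes four fields required and three optional (defaulting to None). The sketch below is a hypothetical stand-in, `PerDeviceInputsSketch`, written as a plain generic dataclass with the documented field names and defaults. It is not the MAX class, and the placeholder strings only stand in for real tensor and buffer objects.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

# Hypothetical type parameters mirroring Generic[_Tensor, _Buffer].
_Tensor = TypeVar("_Tensor")
_Buffer = TypeVar("_Buffer")


@dataclass
class PerDeviceInputsSketch(Generic[_Tensor, _Buffer]):
    """Hypothetical stand-in with the documented fields; not the MAX class."""

    kv_blocks: _Buffer
    cache_lengths: _Tensor
    lookup_table: _Tensor
    max_lengths: _Tensor
    kv_scales: Optional[_Buffer] = None
    attention_dispatch_metadata: Optional[_Tensor] = None
    draft_attention_dispatch_metadata: Optional[_Tensor] = None


# Only the four required fields are passed; the optional ones default to None.
inputs = PerDeviceInputsSketch[str, str](
    kv_blocks="blocks",
    cache_lengths="lengths",
    lookup_table="table",
    max_lengths="max",
)
print(inputs.kv_scales)  # None
```

Because the class is generic over both the tensor and the buffer type, the same field layout can describe either symbolic input types or concrete values.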

attention_dispatch_metadata

attention_dispatch_metadata: _Tensor | None = None

cache_lengths

cache_lengths: _Tensor

draft_attention_dispatch_metadata

draft_attention_dispatch_metadata: _Tensor | None = None

flatten()

Return type:

list[_Tensor | _Buffer]
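flatten() returns the per-device inputs as a single list. The sketch below shows one way such a method could work on a hypothetical stand-in class. It assumes fields are emitted in declaration order and that optional fields left as None are skipped; the real ordering and None handling are not documented on this page.

```python
from dataclasses import dataclass, fields
from typing import Generic, Optional, TypeVar, Union

_Tensor = TypeVar("_Tensor")
_Buffer = TypeVar("_Buffer")


@dataclass
class PerDeviceInputsSketch(Generic[_Tensor, _Buffer]):
    """Hypothetical stand-in with the documented fields; not the MAX class."""

    kv_blocks: _Buffer
    cache_lengths: _Tensor
    lookup_table: _Tensor
    max_lengths: _Tensor
    kv_scales: Optional[_Buffer] = None
    attention_dispatch_metadata: Optional[_Tensor] = None
    draft_attention_dispatch_metadata: Optional[_Tensor] = None

    def flatten(self) -> list[Union[_Tensor, _Buffer]]:
        # ASSUMPTION: declaration order, and optional fields that are None
        # are left out of the list.
        return [
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        ]


inputs = PerDeviceInputsSketch(
    kv_blocks="blocks",
    cache_lengths="lengths",
    lookup_table="table",
    max_lengths="max",
    kv_scales="scales",
)
flat = inputs.flatten()
print(flat)  # ['blocks', 'lengths', 'table', 'max', 'scales']
```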

flatten_without_attention_dispatch_metadata()

Return type:

list[_Tensor | _Buffer]
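Judging by its name, flatten_without_attention_dispatch_metadata() produces the same kind of list as flatten() but leaves out the attention dispatch metadata. The hypothetical sketch below contrasts the two on a stand-in class. It assumes declaration order, skips None fields, and drops both metadata fields; whether the real method also drops draft_attention_dispatch_metadata is not stated on this page.

```python
from dataclasses import dataclass, fields
from typing import Generic, Optional, TypeVar, Union

_Tensor = TypeVar("_Tensor")
_Buffer = TypeVar("_Buffer")

# ASSUMPTION: both dispatch-metadata fields are excluded.
_METADATA_FIELDS = {"attention_dispatch_metadata", "draft_attention_dispatch_metadata"}


@dataclass
class PerDeviceInputsSketch(Generic[_Tensor, _Buffer]):
    """Hypothetical stand-in with the documented fields; not the MAX class."""

    kv_blocks: _Buffer
    cache_lengths: _Tensor
    lookup_table: _Tensor
    max_lengths: _Tensor
    kv_scales: Optional[_Buffer] = None
    attention_dispatch_metadata: Optional[_Tensor] = None
    draft_attention_dispatch_metadata: Optional[_Tensor] = None

    def flatten(self) -> list[Union[_Tensor, _Buffer]]:
        # Every non-None field, in declaration order (assumed).
        return [getattr(self, f.name) for f in fields(self)
                if getattr(self, f.name) is not None]

    def flatten_without_attention_dispatch_metadata(self) -> list[Union[_Tensor, _Buffer]]:
        # Same as flatten(), minus the metadata fields named above.
        return [getattr(self, f.name) for f in fields(self)
                if f.name not in _METADATA_FIELDS and getattr(self, f.name) is not None]


inputs = PerDeviceInputsSketch(
    kv_blocks="blocks",
    cache_lengths="lengths",
    lookup_table="table",
    max_lengths="max",
    attention_dispatch_metadata="meta",
    draft_attention_dispatch_metadata="draft_meta",
)
full = inputs.flatten()
trimmed = inputs.flatten_without_attention_dispatch_metadata()
print(full)     # ['blocks', 'lengths', 'table', 'max', 'meta', 'draft_meta']
print(trimmed)  # ['blocks', 'lengths', 'table', 'max']
```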

kv_blocks

kv_blocks: _Buffer

kv_scales

kv_scales: _Buffer | None = None

lookup_table

lookup_table: _Tensor

max_lengths

max_lengths: _Tensor

unflatten()

unflatten(it)

Parameters:

it (Iterator[Any])

Return type:

KVCacheInputsPerDevice[TensorValue, BufferValue]
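The return type KVCacheInputsPerDevice[TensorValue, BufferValue] suggests that unflatten() rebuilds the container from graph values drawn from the iterator it. The hypothetical sketch below assumes unflatten() is an instance method: self acts as the symbolic spec, and one value is consumed for each field present (non-None) on the spec, in flatten() order. Neither assumption is confirmed by this page, and the string placeholders stand in for real types and values.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Generic, Iterator, Optional, TypeVar, Union

_Tensor = TypeVar("_Tensor")
_Buffer = TypeVar("_Buffer")


@dataclass
class PerDeviceInputsSketch(Generic[_Tensor, _Buffer]):
    """Hypothetical stand-in with the documented fields; not the MAX class."""

    kv_blocks: _Buffer
    cache_lengths: _Tensor
    lookup_table: _Tensor
    max_lengths: _Tensor
    kv_scales: Optional[_Buffer] = None
    attention_dispatch_metadata: Optional[_Tensor] = None
    draft_attention_dispatch_metadata: Optional[_Tensor] = None

    def flatten(self) -> list[Union[_Tensor, _Buffer]]:
        # Non-None fields in declaration order (assumed).
        return [getattr(self, f.name) for f in fields(self)
                if getattr(self, f.name) is not None]

    def unflatten(self, it: Iterator[Any]) -> "PerDeviceInputsSketch[Any, Any]":
        # ASSUMPTION: consume one value per field present on the spec,
        # in the same order flatten() would emit it.
        values = {f.name: next(it) for f in fields(self)
                  if getattr(self, f.name) is not None}
        # Fields absent on the spec keep their None default.
        return replace(self, **values)


# The spec describes a cache without scales or dispatch metadata.
spec = PerDeviceInputsSketch(
    kv_blocks="BufferType",
    cache_lengths="TensorType",
    lookup_table="TensorType",
    max_lengths="TensorType",
)
graph_inputs = iter(["v0", "v1", "v2", "v3", "extra"])
values = spec.unflatten(graph_inputs)
leftover = list(graph_inputs)
print(values.flatten())  # ['v0', 'v1', 'v2', 'v3']
print(leftover)          # ['extra']
```

Taking an iterator rather than a list would let several per-device caches be rebuilt one after another from a single stream of graph inputs, each call consuming only its own values; this is a plausible reading of the signature, not documented behavior.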