Python function

unflatten_ragged_attention_inputs

unflatten_ragged_attention_inputs()

max.nn.kv_cache.unflatten_ragged_attention_inputs(kv_inputs_flat, *, n_devices)

Unmarshals a flattened sequence of KV cache graph inputs into typed cache values (PagedCacheValues).

Parameters:

  • kv_inputs_flat (Sequence[Any]) – Flattened graph values for all KV inputs. Elements may be Value instances or Tensor-like objects with a _graph_value attribute.
  • n_devices (int) – Number of devices represented in kv_inputs_flat.

Return type:

list[PagedCacheValues]
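The following is a minimal pure-Python sketch of the unflattening idea, not MAX's implementation. It mirrors two details stated above: elements may carry a _graph_value attribute that holds the underlying graph value, and the flat sequence covers n_devices devices. Everything else is an assumption for illustration: HypotheticalCacheValues stands in for PagedCacheValues, and the sketch assumes each device contributes an equal-sized, contiguous run of inputs. The real function's per-device layout and field types may differ.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class HypotheticalCacheValues:
    # Hypothetical stand-in for PagedCacheValues; the field layout is assumed.
    fields: tuple[Any, ...]


def _unwrap(x: Any) -> Any:
    # Tensor-like objects expose their graph value via _graph_value;
    # plain values are passed through unchanged.
    return getattr(x, "_graph_value", x)


def unflatten_sketch(
    kv_inputs_flat: Sequence[Any], *, n_devices: int
) -> list[HypotheticalCacheValues]:
    # Hypothetical helper: regroups a flat input list into one record per device,
    # ASSUMING each device owns an equal, contiguous slice of the inputs.
    if n_devices <= 0:
        raise ValueError("n_devices must be positive")
    values = [_unwrap(x) for x in kv_inputs_flat]
    if len(values) % n_devices != 0:
        raise ValueError("input count is not divisible by n_devices")
    per_device = len(values) // n_devices
    return [
        HypotheticalCacheValues(tuple(values[i * per_device:(i + 1) * per_device]))
        for i in range(n_devices)
    ]


class _Wrapped:
    # Tiny Tensor-like object carrying a _graph_value, for the usage example.
    def __init__(self, value: Any) -> None:
        self._graph_value = value


# Usage: four flat inputs split across two devices (one input is wrapped).
result = unflatten_sketch([1, _Wrapped(2), 3, 4], n_devices=2)
```

Validating that the input count divides evenly by n_devices catches a mismatched flattening early rather than silently producing misaligned per-device values.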