Python class

KVTransferEngine

class max.kv_cache.KVTransferEngine(name, tensors, *, total_num_pages, replicate_kv_across_tp=False)

Bases: object

KVCache Transfer Engine with support for Data Parallelism (DP) and Tensor Parallelism (TP).

The engine accepts a 2D list of tensors: list[list[Buffer]] where the outer list represents DP replicas and the inner list represents TP shards within each replica.
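
As a minimal sketch of this layout, the grid below uses plain `bytearray` objects as hypothetical stand-ins for `Buffer` (a real engine takes device buffers, and the sizes here are arbitrary). The outer index selects the DP replica and the inner index selects the TP shard within it.

```python
# Hypothetical stand-ins: bytearray replaces Buffer purely to show the shape.
NUM_REPLICAS = 2  # DP replicas (outer list)
NUM_SHARDS = 4    # TP shards per replica (inner list)

tensors = [
    [bytearray(16) for _ in range(NUM_SHARDS)]
    for _ in range(NUM_REPLICAS)
]

dp = len(tensors)     # number of DP replicas
tp = len(tensors[0])  # number of TP shards per replica

# This sketch assumes every replica holds the same number of TP shards.
assert all(len(replica) == tp for replica in tensors)

# Shard 3 of replica 1 is addressed as tensors[1][3].
shard = tensors[1][3]
```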

A KVTransferEngine communicates with other KVTransferEngines in other threads or processes. However, an individual engine is not thread-safe; it is intended to be used by MAX’s single-threaded scheduler.

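Parameters:

  • name (str) – Name of the transfer engine / NIXL agent.
  • tensors (list[list[Buffer]]) – 2D list of buffers, indexed as [replica][tp_shard].
  • total_num_pages (int) – Total number of pages in each tensor (same across all replicas).
  • replicate_kv_across_tp (bool) – Whether KV is replicated across TP ranks (MLA). Defaults to False.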

bytes_per_page

bytes_per_page: int

Bytes per page for each tensor.

cleanup()

cleanup()

Release all resources associated with the transfer engine.

Should be called before the transfer engine is garbage collected. Moving this logic into the __del__ destructor causes a UCX error for unknown reasons.

Return type:

None
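
Because cleanup() has to be called explicitly rather than from __del__, one way to make the call hard to forget is a small context manager. The sketch below is only an illustration: `StubEngine` and `managed_engine` are hypothetical names, and the stub exposes nothing but cleanup().

```python
from contextlib import contextmanager


class StubEngine:
    """Hypothetical stand-in exposing only cleanup(); not the real engine."""

    def __init__(self):
        self.cleaned_up = False

    def cleanup(self):
        self.cleaned_up = True


@contextmanager
def managed_engine(engine):
    """Yield the engine and always call cleanup() on exit, even on error."""
    try:
        yield engine
    finally:
        engine.cleanup()


engine = StubEngine()
with managed_engine(engine):
    pass  # schedule and complete transfers here
```

The try/finally form runs cleanup() whether the body returns normally or raises, so the release happens at a known point instead of during garbage collection.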

cleanup_transfer()

cleanup_transfer(transfer_req)

Clean up a transfer. This should be called after a transfer is complete.

Parameters:

transfer_req (TransferReqData) – The transfer request to clean up.

Return type:

None

completed_recv_transfers

completed_recv_transfers: dict[str, dict[str, int]]

Map of agent names to completed recv transfers.

connect()

connect(remote)

Connect to a remote engine (all replicas).

Parameters:

remote (KVTransferEngineMetadata) – Metadata for the remote engine (all replicas).

Return type:

None

disconnect()

disconnect(name)

Tear down a single remote connection.

Releases inflight transfer handles referencing this remote, invalidates NIXL metadata, and removes bookkeeping entries. After disconnect, connect() will accept the same name again.

Parameters:

name (str) – The name of the remote engine to disconnect.

Raises:

ValueError – If the named remote is not currently connected.

Return type:

None
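
The connect/disconnect contract can be modeled with a plain dictionary. This is a hedged sketch, not the engine's code: `RemoteRegistry` is a hypothetical class, its connect() takes a name and a metadata value instead of a KVTransferEngineMetadata object, and rejecting a duplicate connect is an assumption inferred from the note that the same name is accepted again after disconnect.

```python
class RemoteRegistry:
    """Hypothetical bookkeeping model; not the real engine."""

    def __init__(self):
        self.remote_connections = {}  # remote engine name -> metadata

    def connect(self, name, metadata):
        # Assumption: connecting an already-connected name is rejected.
        if name in self.remote_connections:
            raise ValueError(f"{name!r} is already connected")
        self.remote_connections[name] = metadata

    def disconnect(self, name):
        # Mirrors the documented ValueError for a remote that is not connected.
        if name not in self.remote_connections:
            raise ValueError(f"{name!r} is not currently connected")
        del self.remote_connections[name]


registry = RemoteRegistry()
registry.connect("decode-0", {"replicas": 2})
registry.disconnect("decode-0")
registry.connect("decode-0", {"replicas": 2})  # accepted again after disconnect
```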

dp

dp: int

Number of DP replicas.

from_paged_kv_cache()

classmethod from_paged_kv_cache(name, kv_cache)

Construct an engine wired to a PagedKVCacheManager.

Pulls the per-replica device buffers, sets total_num_pages, and derives replicate_kv_across_tp from is_mla on the primary cache params. This is equivalent to constructing the engine manually, but it consolidates the boilerplate that the prefill and decode schedulers share.

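Parameters:

  • name (str) – Name of the transfer engine.
  • kv_cache (PagedKVCacheManager) – The paged KV cache manager to wire the engine to.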

Return type:

KVTransferEngine

inflight_send_transfers

inflight_send_transfers: dict[str, TransferReqData]

Map of transfer names to send transfer request data.

initiate_read_transfer()

initiate_read_transfer(remote_metadata, src_idxs, dst_idxs, src_replica_idx, dst_replica_idx, tp_shard_limit=None)

Initiate a READ transfer from the remote engine to the current engine.

The current engine pulls data from the remote. Used by DKVConnector to read KV blocks from BlockStore DRAM into GPU VRAM.

Parameters:

  • remote_metadata (KVTransferEngineMetadata) – Metadata for the remote engine (source).
  • src_idxs (list[int]) – Page indices in the remote engine (source).
  • dst_idxs (list[int]) – Page indices in the current engine (destination).
  • src_replica_idx (int) – Replica index in the remote engine.
  • dst_replica_idx (int) – Replica index in the current engine.
  • tp_shard_limit (int | None) – If set, only the first N TP shards transfer.

Return type:

TransferReqData

initiate_send_transfer()

initiate_send_transfer(remote_metadata, src_idxs, dst_idxs, src_replica_idx, dst_replica_idx, tp_shard_limit=None)

Initiate a send transfer from the current engine to the remote engine.

The same page indices are broadcast to all TP shards within the source and destination replicas.

Parameters:

  • remote_metadata (KVTransferEngineMetadata) – Metadata for the remote engine.
  • src_idxs (list[int]) – List of indices of the source pages in the current engine.
  • dst_idxs (list[int]) – List of indices of the destination pages in the remote engine.
  • src_replica_idx (int) – Index of the source replica to transfer from.
  • dst_replica_idx (int) – Index of the destination replica to transfer to.
  • tp_shard_limit (int | None) – Maximum number of TP shards to transfer. When set, only the first tp_shard_limit shards participate in the transfer. Useful for MLA models where KV data is identical across shards.

Return type:

TransferReqData
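
To illustrate how the page indices and tp_shard_limit interact, the sketch below enumerates which (shard, source page, destination page) combinations would take part in one transfer. `plan_shard_transfers` is a hypothetical helper, not part of the engine, and pairing src_idxs with dst_idxs element by element is an assumption of this sketch.

```python
def plan_shard_transfers(src_idxs, dst_idxs, num_tp_shards, tp_shard_limit=None):
    """Return (shard, src_page, dst_page) triples for one replica pair.

    Hypothetical helper: it only models which shards and pages take part.
    """
    # Assumption: source and destination pages are paired element by element.
    if len(src_idxs) != len(dst_idxs):
        raise ValueError("src_idxs and dst_idxs must have the same length")

    shards = range(num_tp_shards)
    if tp_shard_limit is not None:
        # Only the first tp_shard_limit shards participate.
        shards = range(min(tp_shard_limit, num_tp_shards))

    # The same page pairs are repeated for every participating shard.
    return [
        (shard, src, dst)
        for shard in shards
        for src, dst in zip(src_idxs, dst_idxs)
    ]


full = plan_shard_transfers([0, 1], [5, 6], num_tp_shards=4)
limited = plan_shard_transfers([0, 1], [5, 6], num_tp_shards=4, tp_shard_limit=1)
```

With four TP shards and two page pairs, the unrestricted plan covers eight combinations, while tp_shard_limit=1 keeps only the two for shard 0. This matches the MLA case described above, where the shards hold identical KV data.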

is_complete()

is_complete(transfer_req)

Checks if a given send, recv, or read transfer is completed.

Parameters:

transfer_req (TransferReqData) – The transfer request.

Returns:

True if all transfers have completed; False otherwise.

Return type:

bool

memory_type

memory_type: MemoryType

Type of memory being managed (e.g. DRAM).

metadata

property metadata: KVTransferEngineMetadata

Get metadata for all replicas.

Returns:

Metadata for the entire engine (all replicas).

name

name: str

Name of the transfer engine / NIXL agent.

register_tensor_group()

register_tensor_group(name, tensors, total_num_pages)

Register an additional tensor group on all agents.

The new buffers are registered as extra memory regions on the existing NIXL agents. Future initiate_send_transfer calls will automatically include descriptors for this group alongside the primary tensor, bundling both into a single NIXL transfer.

Parameters:

  • name (str) – Group name (e.g. "draft").
  • tensors (Sequence[Sequence[Buffer]]) – 2D buffer grid [replica][tp_shard] matching the primary tensor layout.
  • total_num_pages (int) – Number of pages in each buffer (same page count as the primary tensor; page size may differ).

Return type:

None

remote_agent_to_engine

remote_agent_to_engine: dict[str, str]

Map of remote agent names to their engine names.

remote_connections

remote_connections: dict[str, KVTransferEngineMetadata]

Map of remote engine names to their metadata.

replicate_kv_across_tp

replicate_kv_across_tp: bool

Whether KV is replicated across TP ranks (MLA).

sync_and_release()

sync_and_release(transfer_req, timeout_s=30.0)

Waits for a transfer to complete and releases it.

Parameters:

  • transfer_req (TransferReqData) – The transfer request to wait on.
  • timeout_s (float) – Maximum seconds to wait before raising TimeoutError.

Raises:

TimeoutError – If the transfer does not complete within timeout_s.

Return type:

None
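
The wait-then-release behavior can be approximated by combining is_complete() and cleanup_transfer() in a polling loop. The sketch below is an assumption about one possible shape, not the implementation of sync_and_release: `wait_and_release`, `poll_interval_s`, and `StubEngine` are hypothetical, and the stub simply reports completion after a fixed number of polls.

```python
import time


def wait_and_release(engine, transfer_req, timeout_s=30.0, poll_interval_s=0.001):
    """Poll is_complete() until done, then call cleanup_transfer()."""
    deadline = time.monotonic() + timeout_s
    while not engine.is_complete(transfer_req):
        if time.monotonic() >= deadline:
            raise TimeoutError(f"transfer did not complete within {timeout_s}s")
        time.sleep(poll_interval_s)
    engine.cleanup_transfer(transfer_req)


class StubEngine:
    """Hypothetical stand-in: reports completion after a fixed number of polls."""

    def __init__(self, polls_until_done):
        self.polls_left = polls_until_done
        self.released = []

    def is_complete(self, transfer_req):
        self.polls_left -= 1
        return self.polls_left <= 0

    def cleanup_transfer(self, transfer_req):
        self.released.append(transfer_req)


engine = StubEngine(polls_until_done=3)
wait_and_release(engine, "req-0", timeout_s=1.0)
```

Using time.monotonic() for the deadline keeps the timeout unaffected by wall-clock adjustments.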

tensor_agents

tensor_agents: list[list[TensorAgent]]

2D list of TensorAgent objects, indexed as [replica][tp_shard].

total_num_pages

total_num_pages: int

Total number of pages in each tensor (same across all replicas).

tp

tp: int

Number of TP shards per replica.