Python module

transfer_engine

KVCache Transfer Engine

KVTransferEngine

class max.kv_cache.paged_cache.transfer_engine.KVTransferEngine(name, tensors, *, total_num_pages)

KVCache Transfer Engine with support for Data Parallelism (DP) and Tensor Parallelism (TP).

The engine accepts a 2D list of tensors: list[list[Tensor]] where the outer list represents DP replicas and the inner list represents TP shards within each replica.

A TransferEngine communicates with other TransferEngines in other threads or processes. However, an individual TransferEngine is not thread-safe. It is intended to be used by MAX’s single-threaded scheduler.

Parameters:

  • name (str)
  • tensors (Sequence[Sequence[Tensor]])
  • total_num_pages (int)
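
As a rough illustration of the nested [replica][tp_shard] layout described above, the sketch below uses plain strings as stand-ins for tensors. The helper `infer_dp_tp` is hypothetical (it is not part of this module), and it assumes every replica holds the same number of TP shards.

```python
from collections.abc import Sequence


def infer_dp_tp(tensors: Sequence[Sequence[object]]) -> tuple[int, int]:
    """Hypothetical helper: return (dp, tp) for a [replica][tp_shard] layout."""
    if not tensors or not tensors[0]:
        raise ValueError("expected at least one replica with at least one shard")
    dp = len(tensors)      # outer list: one entry per DP replica
    tp = len(tensors[0])   # inner list: one entry per TP shard
    # Assumption: all replicas have the same number of TP shards.
    if any(len(replica) != tp for replica in tensors):
        raise ValueError("all replicas must have the same number of TP shards")
    return dp, tp


# Placeholder strings stand in for real Tensor objects.
layout = [
    ["replica0_shard0", "replica0_shard1"],
    ["replica1_shard0", "replica1_shard1"],
    ["replica2_shard0", "replica2_shard1"],
]
dp, tp = infer_dp_tp(layout)
print(dp, tp)
```

Here the outer length corresponds to the dp attribute and the inner length to the tp attribute documented below.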

bytes_per_page

bytes_per_page: int

Bytes per page for each tensor.

cleanup()

cleanup()

Release all resources associated with the transfer engine.

Call this before the transfer engine is garbage collected. Moving this logic into the __del__ destructor causes a UCX error for unknown reasons.

Return type:

None
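
Because cleanup is not tied to the destructor, callers have to invoke it explicitly. One way to make that hard to forget is a small context manager. The sketch below is only a pattern: `FakeEngine` and `engine_scope` are hypothetical names, and `FakeEngine` stands in for any object that exposes a `cleanup()` method.

```python
from contextlib import contextmanager


class FakeEngine:
    """Hypothetical stand-in for an object that exposes cleanup()."""

    def __init__(self) -> None:
        self.cleaned_up = False

    def cleanup(self) -> None:
        self.cleaned_up = True


@contextmanager
def engine_scope(engine):
    """Yield the engine and call cleanup() on exit, even if the body raises."""
    try:
        yield engine
    finally:
        engine.cleanup()


engine = FakeEngine()
with engine_scope(engine):
    pass  # connect to remotes and run transfers here
print(engine.cleaned_up)
```

A plain try/finally block around the engine's lifetime achieves the same effect without the helper.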

cleanup_transfer()

cleanup_transfer(transfer_req)

Clean up a transfer. This should be called after a transfer is complete.

Parameters:

transfer_req (TransferReqData) – The transfer request to clean up.

Return type:

None

completed_recv_transfers

completed_recv_transfers: dict[str, dict[str, int]]

Map of agent names to completed recv transfers.

connect()

connect(remote)

Connect to a remote engine (all replicas).

Parameters:

remote (KVTransferEngineMetadata) – Metadata for the remote engine (all replicas).

Return type:

None

dp

dp: int

Number of DP replicas.

inflight_send_transfers

inflight_send_transfers: dict[str, TransferReqData]

Map of transfer names to send transfer request data.

initiate_send_transfer()

initiate_send_transfer(remote_metadata, src_idxs, dst_idxs, src_replica_idx, dst_replica_idx)

Initiate a transfer from the current engine to a remote engine.

The same page indices are broadcast to all TP shards within the source and destination replicas.

Parameters:

  • remote_metadata (KVTransferEngineMetadata) – Metadata for the remote engine.
  • src_idxs (list[int]) – List of indices of the source pages in the current engine.
  • dst_idxs (list[int]) – List of indices of the destination pages in the remote engine.
  • src_replica_idx (int) – Index of the source replica to transfer from.
  • dst_replica_idx (int) – Index of the destination replica to transfer to.

Return type:

TransferReqData
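
To make the broadcast of page indices concrete, the sketch below plans one set of address ranges per TP shard from the same index lists. It is a guess at the arithmetic, not the engine's implementation: it assumes pages are laid out contiguously from each shard's base address, that source and destination replicas have the same number of shards, and that the two index lists pair up one to one. `page_ranges` and `plan_transfer` are hypothetical names.

```python
def page_ranges(base_addr: int, bytes_per_page: int, idxs: list[int]) -> list[tuple[int, int]]:
    """Return (address, length) pairs for the given page indices.

    Assumption: page i starts at base_addr + i * bytes_per_page.
    """
    return [(base_addr + i * bytes_per_page, bytes_per_page) for i in idxs]


def plan_transfer(
    src_bases: list[int],
    dst_bases: list[int],
    bytes_per_page: int,
    src_idxs: list[int],
    dst_idxs: list[int],
) -> list[tuple[list[tuple[int, int]], list[tuple[int, int]]]]:
    """Hypothetical planner: one (src_ranges, dst_ranges) entry per TP shard."""
    if len(src_idxs) != len(dst_idxs):
        raise ValueError("source and destination index lists must pair up")
    if len(src_bases) != len(dst_bases):
        raise ValueError("source and destination must have the same TP degree")
    # The same index lists are reused for every shard; only the base address changes.
    return [
        (page_ranges(s, bytes_per_page, src_idxs), page_ranges(d, bytes_per_page, dst_idxs))
        for s, d in zip(src_bases, dst_bases)
    ]


plan = plan_transfer(
    src_bases=[0x1000, 0x9000],      # one base address per source TP shard
    dst_bases=[0x20000, 0x30000],    # one base address per destination TP shard
    bytes_per_page=256,
    src_idxs=[0, 2],
    dst_idxs=[5, 7],
)
for shard, (src, dst) in enumerate(plan):
    print(shard, src, dst)
```

The plan has one entry per shard, which lines up with the one-transfer-ID-per-TP-shard note on TransferReqData.transfer_ids.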

is_complete()

is_complete(transfer_req)

Checks if a given send or recv transfer is completed.

Parameters:

transfer_req (TransferReqData) – The transfer request.

Returns:

True if all transfers in the request have completed; False otherwise.

Return type:

bool

memory_type

memory_type: MemoryType

Type of memory being managed (e.g. DRAM).

metadata

property metadata: KVTransferEngineMetadata

Get metadata for all replicas.

Returns:

Metadata for the entire engine (all replicas).

name

name: str

Name of the transfer engine / NIXL agent.

remote_agent_to_engine

remote_agent_to_engine: dict[str, str]

Map of remote agent names to their engine names.

remote_connections

remote_connections: dict[str, KVTransferEngineMetadata]

Map of remote engine names to their metadata.

sync_and_release()

sync_and_release(transfer_req)

Wait for a transfer to complete, then release it.

Parameters:

transfer_req (TransferReqData)

Return type:

None
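
The wait-then-release behaviour can be pictured as polling is_complete() and then calling cleanup_transfer(). The sketch below shows that pattern against a hypothetical `FakeEngine`; it is not the library's implementation, and the timeout and poll interval are additions for illustration.

```python
import time


class FakeEngine:
    """Hypothetical stand-in exposing is_complete() and cleanup_transfer()."""

    def __init__(self, polls_until_done: int) -> None:
        self._remaining = polls_until_done
        self.released: list[object] = []

    def is_complete(self, transfer_req: object) -> bool:
        # Pretend the transfer finishes after a fixed number of polls.
        self._remaining -= 1
        return self._remaining <= 0

    def cleanup_transfer(self, transfer_req: object) -> None:
        self.released.append(transfer_req)


def wait_and_release(
    engine: FakeEngine,
    transfer_req: object,
    timeout_s: float = 1.0,
    poll_interval_s: float = 0.001,
) -> None:
    """Poll until the transfer completes, then release it; raise on timeout."""
    deadline = time.monotonic() + timeout_s
    while not engine.is_complete(transfer_req):
        if time.monotonic() >= deadline:
            raise TimeoutError("transfer did not complete in time")
        time.sleep(poll_interval_s)
    engine.cleanup_transfer(transfer_req)


engine = FakeEngine(polls_until_done=3)
wait_and_release(engine, "req-0")
print(engine.released)
```

A scheduler that must not block could instead call is_complete() once per iteration and release the transfer only when it returns True.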

tensor_agents

tensor_agents: list[list[TensorAgent]]

2D list of TensorAgent objects, indexed as [replica][tp_shard].

total_num_pages

total_num_pages: int

Total number of pages in each tensor (same across all replicas).

tp

tp: int

Number of TP shards per replica.

KVTransferEngineMetadata

class max.kv_cache.paged_cache.transfer_engine.KVTransferEngineMetadata(*, name, total_num_pages, bytes_per_page, memory_type, hostname, agents_meta)

Metadata associated with a transfer engine.

This is safe to send between threads/processes.

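Parameters:

  • name (str)
  • total_num_pages (int)
  • bytes_per_page (int)
  • memory_type (MemoryType)
  • hostname (str)
  • agents_meta (list[list[TensorAgentMetadata]])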

agents_meta

agents_meta: list[list[TensorAgentMetadata]]

Metadata for each replica’s agents, indexed as [replica][tp_shard].

bytes_per_page

bytes_per_page: int

Bytes per page for each tensor.

hostname

hostname: str

Hostname of the machine that the transfer engine is running on.

memory_type

memory_type: MemoryType

Memory type of the transfer engine.

name

name: str

Base name of the transfer engine.

total_num_pages

total_num_pages: int

Total number of pages in each tensor.

TensorAgent

class max.kv_cache.paged_cache.transfer_engine.TensorAgent(agent, agent_name, tensor, base_addr, ucx_backend, device_id, agent_metadata, reg_dlist)

Manages a single tensor and its associated NIXL agent for transfers.

This class holds both the runtime state (live objects) and can generate the serializable metadata for communication between engines.

Parameters:

  • agent (Agent)
  • agent_name (str)
  • tensor (Tensor)
  • base_addr (int)
  • ucx_backend (int)
  • device_id (int)
  • agent_metadata (bytes)
  • reg_dlist (RegistrationDescriptorList)

agent

agent: Agent

NIXL agent for this tensor.

agent_metadata

agent_metadata: bytes

Metadata for this agent.

agent_name

agent_name: str

Name of this agent.

base_addr

base_addr: int

Base memory address for this tensor.

create_agent()

classmethod create_agent(agent_name, listen_port, tensor, total_num_pages, elts_per_page, memory_type)

Parameters:

  • agent_name (str)
  • listen_port (int)
  • tensor (Tensor)
  • total_num_pages (int)
  • elts_per_page (int)
  • memory_type (MemoryType)

Return type:

TensorAgent

device_id

device_id: int

Device ID for this tensor.

reg_dlist

reg_dlist: RegistrationDescriptorList

Registration descriptor list for this tensor.

tensor

tensor: Tensor

Tensor for this agent.

to_metadata()

to_metadata()

Convert to serializable metadata for communication.

Return type:

TensorAgentMetadata

ucx_backend

ucx_backend: int

UCX backend for this tensor.

TensorAgentMetadata

class max.kv_cache.paged_cache.transfer_engine.TensorAgentMetadata(*, agent_name, metadata, base_addr, device_id)

Metadata for a single tensor/agent in the transfer engine.

This is used for serialization and communication between engines.

Parameters:

  • agent_name (str)
  • metadata (bytes)
  • base_addr (int)
  • device_id (int)

agent_name

agent_name: str

Name of this agent.

base_addr

base_addr: int

Base memory address for this tensor.

device_id

device_id: int

Device ID for this tensor.

metadata

metadata: bytes

Metadata for this agent.

TransferReqData

class max.kv_cache.paged_cache.transfer_engine.TransferReqData(*, dst_name, src_name, transfer_name, transfer_ids, src_idxs, dst_idxs, src_replica_idx, dst_replica_idx)

Metadata associated with a transfer request.

This is safe to send between threads/processes.

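Parameters:

  • dst_name (str)
  • src_name (str)
  • transfer_name (str)
  • transfer_ids (list[int])
  • src_idxs (list[int])
  • dst_idxs (list[int])
  • src_replica_idx (int)
  • dst_replica_idx (int)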

dst_idxs

dst_idxs: list[int]

Indices of the destination pages. The length can differ from len(transfer_ids).

dst_name

dst_name: str

Base name of destination engine.

dst_replica_idx

dst_replica_idx: int

Index of the destination replica this transfer is to.

src_idxs

src_idxs: list[int]

Indices of the source pages. The length can differ from len(transfer_ids).

src_name

src_name: str

Base name of source engine.

src_replica_idx

src_replica_idx: int

Index of the source replica this transfer is from.

transfer_ids

transfer_ids: list[int]

Transfer IDs (one per TP shard in the replica).

transfer_name

transfer_name: str

Transfer name.
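
Since the fields above are only strings, integers and lists of integers, a request can travel over most channels between threads or processes. The sketch below mirrors the fields in a hypothetical dataclass, `FakeTransferReq`, and round-trips it through JSON. How the real class is serialized is not stated on this page, so treat the encoding as an assumption.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class FakeTransferReq:
    """Hypothetical mirror of the TransferReqData fields."""

    dst_name: str
    src_name: str
    transfer_name: str
    transfer_ids: list[int]   # one per TP shard in the replica
    src_idxs: list[int]
    dst_idxs: list[int]
    src_replica_idx: int
    dst_replica_idx: int


req = FakeTransferReq(
    dst_name="decode_engine",
    src_name="prefill_engine",
    transfer_name="transfer_0",
    transfer_ids=[11, 12],
    src_idxs=[0, 2, 4],
    dst_idxs=[5, 7, 9],
    src_replica_idx=0,
    dst_replica_idx=1,
)

# Encode on the sender side, decode on the receiver side.
wire = json.dumps(asdict(req))
restored = FakeTransferReq(**json.loads(wire))
print(restored == req)
```

Note that in this example transfer_ids has two entries while the index lists have three, matching the remark that the lengths can differ.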

available_port()

max.kv_cache.paged_cache.transfer_engine.available_port(start_port=8000, end_port=9000, max_attempts=100)

Find an available TCP port in the given range.

Parameters:

  • start_port (int) – The lower bound of the port range (inclusive).
  • end_port (int) – The upper bound of the port range (inclusive).
  • max_attempts (int) – Maximum number of attempts to find a free port.

Returns:

An available port number.

Return type:

int

Raises:

RuntimeError – If no available port is found after max_attempts.
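
For readers curious how such a helper can work, the sketch below probes random ports in the range by trying to bind a TCP socket. It is one plausible approach under stated assumptions, not this module's implementation: `find_free_port` is a hypothetical name, the probe binds to 127.0.0.1 only, and random selection is a choice made for the example.

```python
import random
import socket


def find_free_port(start_port: int = 8000, end_port: int = 9000, max_attempts: int = 100) -> int:
    """Hypothetical helper: return a port in [start_port, end_port] that binds successfully."""
    for _ in range(max_attempts):
        port = random.randint(start_port, end_port)  # inclusive on both ends
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
            except OSError:
                continue  # port is in use or not permitted; try another
            return port
    raise RuntimeError(f"no available port found after {max_attempts} attempts")


port = find_free_port()
print(port)
```

The probe socket is closed before the function returns, so another process could claim the port in the meantime. Callers that need a guarantee should bind the port themselves and handle the failure.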
