Python class

LoRAManager

class max.pipelines.lib.LoRAManager(config, base_model_path, base_dtype, n_heads, n_kv_heads, head_dim, zmq_endpoint_base)

Bases: object

Manages multiple LoRA models and buffers for the forward pass.

Applies multiple LoRA models to a set of base weights and manages the underlying buffers required for the forward pass.

Initializes the LoRAManager with a given base weight structure and maximum number of LoRA models.

Parameters:

  • config (LoRAConfig) – The LoRA config.
  • base_model_path (str) – The name/path of the base model.
  • base_dtype (DType) – The base model dtype.
  • n_heads (int) – The number of attention heads in the base model.
  • n_kv_heads (int) – The number of key-value heads in the base model.
  • head_dim (int) – The dimension of each attention head.
  • zmq_endpoint_base (str) – The ZMQ endpoint base used to construct the ZMQ LoRA request and response endpoints.

activate_adapter()

activate_adapter(name)

Moves the specified LoRA adapter to GPU and marks it as active.

Useful for enabling a specific adapter for model inference.

manager.activate_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to activate.

Returns:

None

Raises:

KeyError – If the specified adapter does not exist in the registry.

Return type:

None

get_lora_graph_inputs()

get_lora_graph_inputs(context_batch, input_row_offsets, device)

Returns the LoRA graph inputs for the batch.

Parameters:

  • context_batch (Sequence[TextGenerationContextType]) – The generation contexts for the batch.
  • input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – The offsets for each sequence in the batch.
  • device (Device) – The device.

Return type:

tuple[Buffer, …]
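The input_row_offsets parameter follows the usual ragged-batch convention: a cumulative sum of per-sequence token counts with a leading zero, so that sequence i occupies rows offsets[i]:offsets[i + 1] of the flattened batch. The sketch below illustrates that convention only; it is an assumption for illustration and does not call the MAX API.

```python
def row_offsets(seq_lens: list[int]) -> list[int]:
    """Compute cumulative row offsets for a batch of sequence lengths."""
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)
    return offsets

# Three sequences of 5, 3, and 7 tokens flatten into 15 rows.
offsets = row_offsets([5, 3, 7])
```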

get_lora_manager()

static get_lora_manager(pipeline)

Returns the LoRAManager from the pipeline if LoRA is enabled.

Parameters:

pipeline (Pipeline[PipelineInputsType, PipelineOutputType])

Return type:

LoRAManager | None

get_symbolic_inputs()

get_symbolic_inputs(device_ref)

Returns the input symbols needed for the graph inputs.

Parameters:

device_ref (DeviceRef) – Symbolic device to be used for the symbols.

Returns:

The graph input symbols.

Return type:

list[TensorType]

init_weights()

init_weights(model, state_dict)

Recursively collects the leaf SupportsLoRA modules, initializes their weights with the loaded LoRAs, and adds those weights to the state_dict.

Acquires the alias-able buffers for dynamic LoRA swapping.

Must be called to initialize the base model properly.

Parameters:

  • model (Module) – The top-level Module.
  • state_dict (dict[str, WeightData]) – Model state_dict to be loaded into model.
  • device – The device the base model resides on.

Return type:

None

is_active_lora()

is_active_lora(name)

Returns whether the given name is an active LoRA adapter.

Parameters:

name (str)

Return type:

bool

is_lora()

is_lora(name)

Returns whether the given name is a loaded LoRA adapter.

Parameters:

name (str)

Return type:

bool

load_adapter()

load_adapter(path)

Loads a single LoRA adapter from the given path and registers it under a unique name.

The path can include an explicit name using the format name=path. If no name is provided, the path itself is used as the name.

lora_id = manager.load_adapter("my_adapter=/path/to/lora")
lora_id = manager.load_adapter("/path/to/another_lora")

Parameters:

path (str) – A string in the form name=path or just a file path. The adapter is expected to reside at that path.

Returns:

LoRAStatus indicating the result of the load operation.

Return type:

LoRAStatus
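The name=path convention described above can be illustrated with a small stand-alone parser. This is a sketch of the convention only, not LoRAManager's actual parsing code: when no explicit name is given, the path doubles as the adapter name.

```python
def split_adapter_spec(spec: str) -> tuple[str, str]:
    """Split an adapter spec of the form 'name=path' into (name, path).

    If no '=' separator is present, the path itself is used as the name.
    """
    name, sep, path = spec.partition("=")
    if sep:
        return name, path
    return spec, spec  # no explicit name: the path doubles as the name

named = split_adapter_spec("my_adapter=/path/to/lora")
unnamed = split_adapter_spec("/path/to/another_lora")
```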

loras

property loras: list[str]

Returns the list of loaded LoRA adapter names.

process_lora_requests()

process_lora_requests()

Checks for new LoRA requests and processes them.

Return type:

None

set_graph_info()

set_graph_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)

Sets the LoRA batch info required for the forward pass.

Parameters:

  • lora_ids (TensorValue) – IDs of the LoRAs used in the batch.
  • lora_ranks (TensorValue) – Ranks of the LoRAs used in the batch.
  • lora_grouped_offsets (TensorValue) – Cumulative offsets for each LoRA group.
  • num_active_loras (TensorValue) – Number of active LoRA adapters in the batch.
  • lora_end_idx (TensorValue) – End index of LoRA token portion.
  • batch_seq_len (TensorValue) – Total sequence length in the batch.
  • lora_ids_kv (TensorValue) – LoRA IDs for KV cache (includes K and V portions).
  • lora_grouped_offsets_kv (TensorValue) – Cumulative offsets for KV LoRA groups.

Return type:

None
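Cumulative per-group offsets in the spirit of lora_grouped_offsets can be derived from a batch that is already sorted so that sequences sharing an adapter are contiguous. The sketch below is a hypothetical illustration of that derivation; the names and layout are assumptions, not the MAX implementation.

```python
def grouped_offsets(sorted_lora_ids: list[int],
                    seq_lens: list[int]) -> tuple[list[int], list[int]]:
    """Return (unique adapter ids, cumulative token offsets per group).

    Assumes sorted_lora_ids is grouped: equal ids are adjacent.
    """
    ids: list[int] = []
    offsets = [0]
    for lora_id, n in zip(sorted_lora_ids, seq_lens):
        if ids and ids[-1] == lora_id:
            offsets[-1] += n          # extend the current group
        else:
            ids.append(lora_id)       # start a new group
            offsets.append(offsets[-1] + n)
    return ids, offsets

# Two sequences use adapter 0 (4 + 2 tokens), one uses adapter 2 (3 tokens).
ids, offs = grouped_offsets([0, 0, 2], [4, 2, 3])
```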

sort_lora_batch()

sort_lora_batch(context_batch)

Sorts the LoRA batch by LRU cache id.

Parameters:

context_batch (list[TextGenerationContextType]) – The context batch to sort.

Return type:

list[TextGenerationContextType]
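Sorting a batch by a per-request LoRA cache id groups sequences that use the same adapter, which lets grouped kernels process each adapter's rows contiguously. The sketch below illustrates that idea with a stand-in Context class; it is not the MAX TextGenerationContextType or the actual sort implementation.

```python
from dataclasses import dataclass


@dataclass
class Context:
    """Stand-in for a generation context carrying a LoRA cache id."""
    request: str
    lora_cache_id: int


batch = [Context("a", 2), Context("b", 0), Context("c", 2), Context("d", 1)]
# A stable sort by cache id keeps same-adapter requests contiguous.
sorted_batch = sorted(batch, key=lambda ctx: ctx.lora_cache_id)
order = [ctx.request for ctx in sorted_batch]
```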

unload_adapter()

unload_adapter(name)

Unloads the specified LoRA adapter from the internal registry and frees its slot.

This function is used to release GPU or CPU memory by removing a LoRA model.

manager.unload_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to unload.

Returns:

LoRAStatus indicating the result of the unload operation.

Return type:

LoRAStatus