Python class
LoRAManager
class max.pipelines.lib.LoRAManager(config, base_model_path, base_dtype, n_heads, n_kv_heads, head_dim, zmq_endpoint_base)
Bases: object
Manages multiple LoRA models and buffers for the forward pass.
Applies multiple LoRA models to a set of base weights and manages the underlying buffers required for the forward pass.
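As background, the technique this class manages is low-rank adaptation: the frozen base weight W is combined at forward time with a low-rank update B @ A scaled by alpha / r. The following is a minimal NumPy sketch of that idea only; it is illustrative and is not MAX code, and the dimensions and names are made up.

```python
import numpy as np

# Minimal sketch of the LoRA idea (illustrative, not MAX code):
# the base weight W stays frozen; a low-rank update B @ A
# (rank r << d) is added at forward time, scaled by alpha / r.
rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 8, 8, 2, 4.0

W = rng.standard_normal((d_out, d_in))   # frozen base weight
A = rng.standard_normal((r, d_in))       # LoRA down-projection
B = np.zeros((d_out, r))                 # LoRA up-projection (zero-init)

x = rng.standard_normal((1, d_in))
y = x @ W.T + (alpha / r) * (x @ A.T) @ B.T  # base output + LoRA delta

# With B initialized to zero the delta vanishes, so y equals the base output.
assert np.allclose(y, x @ W.T)
```

Because the delta is only r * (d_in + d_out) parameters per adapter, many adapters can share one set of base weights, which is what motivates managing their buffers centrally.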
Initializes the LoRAManager with a given base weight structure and maximum number of LoRA models.
Parameters:
- config (LoRAConfig) – The LoRA config.
- base_model_path (str) – The name/path of the base model.
- base_dtype (DType) – The base model dtype.
- n_heads (int) – The number of attention heads in the base model.
- n_kv_heads (int) – The number of key-value heads in the base model.
- head_dim (int) – The dimension of each attention head.
- zmq_endpoint_base (str) – The ZMQ endpoint base used to construct the ZMQ LoRA request and response endpoints.
activate_adapter()
activate_adapter(name)
Moves the specified LoRA adapter to GPU and marks it as active.
Useful for enabling a specific adapter for use in model inference.
    manager.activate_adapter("my_adapter")

get_lora_graph_inputs()
get_lora_graph_inputs(context_batch, input_row_offsets, device)
Returns the LoRA graph inputs for the batch.
get_lora_manager()
static get_lora_manager(pipeline)
Returns the LoRAManager from the pipeline if LoRA is enabled.
Parameters:
- pipeline (Pipeline[PipelineInputsType, PipelineOutputType])

Return type:
- LoRAManager | None
get_symbolic_inputs()
get_symbolic_inputs(device_ref)
Returns the input symbols needed for the graph inputs.
Parameters:
- device_ref (DeviceRef) – Symbolic device to be used for the symbols.

Returns:
- The graph input symbols.

Return type:
init_weights()
init_weights(model, state_dict)
Recursively collects leaf SupportsLoRA modules, initializes their weights with the loaded LoRAs, and adds them to the state_dict.
Acquires the alias-able buffers for dynamic LoRA swapping.
Must be called to initialize the base model properly.
Parameters:
- model (Module) – The top-level Module.
- state_dict (dict[str, WeightData]) – Model state_dict to be loaded into the model.
- device – The device the base model resides on.

Return type:
- None
is_active_lora()
is_active_lora(name)
Returns whether the given name is an active LoRA adapter.
is_lora()
is_lora(name)
Returns whether the given name is a loaded LoRA adapter.
load_adapter()
load_adapter(path)
Loads a single LoRA adapter from the given path and registers it under a unique name.
The path can include an explicit name using the format name=path. If no name is provided, the path itself is used as the name.
    lora_id = manager.load_adapter("my_adapter=/path/to/lora")
    lora_id = manager.load_adapter("/path/to/another_lora")
Parameters:
- path (str) – A string in the form name=path or just a file path. The adapter is expected to reside at that path.
Returns:
- LoRAStatus indicating the result of the load operation.

Return type:
- LoRAStatus

loras
Returns the list of loaded LoRA adapter names.
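The name=path convention accepted by load_adapter can be sketched as follows. parse_adapter_spec is a hypothetical helper written here for illustration, not part of the MAX API; it only shows one plausible reading of the documented format.

```python
def parse_adapter_spec(spec: str) -> tuple[str, str]:
    """Split an adapter spec into (name, path).

    "my_adapter=/path/to/lora" -> explicit name before "=".
    "/path/to/another_lora"    -> no "=", so the path itself is the name.
    """
    name, sep, path = spec.partition("=")
    if sep:  # explicit "name=path" form
        return name, path
    return spec, spec  # bare path: the path doubles as the name
```

Splitting on the first "=" keeps any later "=" characters inside the path intact.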
process_lora_requests()
process_lora_requests()
Checks for new LoRA requests and processes them.
Return type:
- None
set_graph_info()
set_graph_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)
Sets the LoRA batch info required for the forward pass.

Parameters:
- lora_ids (TensorValue) – IDs of the LoRAs used in the batch.
- lora_ranks (TensorValue) – Ranks of the LoRAs used in the batch.
- lora_grouped_offsets (TensorValue) – Cumulative offsets for each LoRA group.
- num_active_loras (TensorValue) – Number of active LoRA adapters in the batch.
- lora_end_idx (TensorValue) – End index of the LoRA token portion.
- batch_seq_len (TensorValue) – Total sequence length in the batch.
- lora_ids_kv (TensorValue) – LoRA IDs for the KV cache (includes K and V portions).
- lora_grouped_offsets_kv (TensorValue) – Cumulative offsets for KV LoRA groups.

Return type:
- None
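To give intuition for "cumulative offsets for each LoRA group": when requests using the same adapter are adjacent in the batch, the offsets mark where each adapter's token span begins and ends. The sketch below is a hypothetical plain-Python illustration of that bookkeeping; grouped_offsets and its inputs are invented names, not the MAX API, which works on TensorValues.

```python
# Hypothetical sketch: cumulative token offsets per LoRA group for a
# batch whose requests are already grouped by adapter id (not MAX code).
def grouped_offsets(lora_ids: list[int], seq_lens: list[int]) -> list[int]:
    offsets = [0]
    prev_id: int | None = None
    for lid, n in zip(lora_ids, seq_lens):
        if lid == prev_id:
            offsets[-1] += n                 # extend the current group's span
        else:
            offsets.append(offsets[-1] + n)  # start a new group boundary
            prev_id = lid
    return offsets

# Two requests on adapter 1 (3 + 2 tokens), one on adapter 2 (4 tokens):
# token span [0, 5) belongs to adapter 1 and [5, 9) to adapter 2.
print(grouped_offsets([1, 1, 2], [3, 2, 4]))  # [0, 5, 9]
```

Boundaries like these let a grouped kernel apply each adapter's low-rank matrices to exactly its slice of the batch.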
sort_lora_batch()
sort_lora_batch(context_batch)
Sorts the LoRA batch by LRU cache id.
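Sorting by an LRU cache id, as described above, makes requests that share an adapter contiguous in the batch. The sketch below illustrates only that ordering step; the Request type and lru_id field are hypothetical stand-ins for the real context objects, which this page does not specify.

```python
from dataclasses import dataclass

# Hypothetical stand-in for a batch entry (not the MAX context type).
@dataclass
class Request:
    prompt: str
    lru_id: int  # LRU cache slot id of the adapter this request uses

def sort_batch(batch: list[Request]) -> list[Request]:
    # Stable sort by cache id: requests sharing an adapter end up adjacent.
    return sorted(batch, key=lambda r: r.lru_id)
```

A stable sort preserves the original order of requests within each adapter group.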
unload_adapter()
unload_adapter(name)
Unloads the specified LoRA adapter from the internal registry and frees its slot.
This function is used to release GPU or CPU memory by removing a LoRA model.
    manager.unload_adapter("my_adapter")

Parameters:
- name (str) – The name of the LoRA adapter to unload.

Returns:
- LoRAStatus indicating the result of the unload operation.

Return type:
- LoRAStatus