Python class
LoRAManager
class max.pipelines.lib.LoRAManager(config, base_model_path, base_dtype, n_heads, n_kv_heads, head_dim, zmq_endpoint_base)
Bases: object
Manages multiple LoRA models and buffers for the forward pass.
Applies multiple LoRA models to a set of base weights and manages the underlying buffers required for the forward pass.
Initializes the LoRAManager with a given base weight structure and maximum number of LoRA models.
Parameters:
- config (LoRAConfig) – The LoRA config.
- base_model_path (str) – The name or path of the base model.
- base_dtype (DType) – The base model dtype.
- n_heads (int) – The number of attention heads in the base model.
- n_kv_heads (int) – The number of key-value heads in the base model.
- head_dim (int) – The dimension of each attention head.
- zmq_endpoint_base (str) – The ZMQ endpoint base used to construct the ZMQ LoRA request and response endpoints.
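A minimal construction sketch; the dimension values, paths, and endpoint shown are illustrative assumptions, and lora_config stands for an existing LoRAConfig instance:

from max.dtype import DType
from max.pipelines.lib import LoRAManager

manager = LoRAManager(
    config=lora_config,  # an existing LoRAConfig (assumed)
    base_model_path="path/to/base-model",  # illustrative path
    base_dtype=DType.bfloat16,
    n_heads=32,  # illustrative dimensions, not defaults
    n_kv_heads=8,
    head_dim=128,
    zmq_endpoint_base="ipc:///tmp/lora",  # illustrative endpoint base
)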
activate_adapter()
activate_adapter(name)
Moves the specified LoRA adapter to GPU and marks it as active.
Useful for enabling a specific adapter for use in model inference.
manager.activate_adapter("my_adapter")

get_lora_graph_inputs()
get_lora_graph_inputs(context_batch, input_row_offsets, device)
Returns the LoRA graph inputs for the batch.
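A hedged usage sketch; context_batch, input_row_offsets, and device are assumed to be supplied by the surrounding serving code:

lora_inputs = manager.get_lora_graph_inputs(context_batch, input_row_offsets, device)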
get_lora_manager()
static get_lora_manager(pipeline)
Returns the LoRAManager from the pipeline if LoRA is enabled.
Parameters:
pipeline (Pipeline[PipelineInputsType, PipelineOutputType])
Return type:
LoRAManager | None
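A short sketch, assuming pipeline is an already-constructed Pipeline; because the method returns None when LoRA is disabled, guard before use:

manager = LoRAManager.get_lora_manager(pipeline)
if manager is not None:
    status = manager.load_adapter("my_adapter=/path/to/lora")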
get_symbolic_inputs()
get_symbolic_inputs(device_ref)
Returns the input symbols needed for the graph inputs.
Parameters:
device_ref (DeviceRef) – Symbolic device to be used for the symbols.
Returns:
The graph input symbols.
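A sketch of fetching the symbols during graph construction, assuming a GPU device reference (DeviceRef comes from max.graph):

from max.graph import DeviceRef

lora_input_symbols = manager.get_symbolic_inputs(DeviceRef.GPU())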
init_weights()
init_weights(model, state_dict)
Recursively collects leaf SupportsLoRA modules, initializes their weights with the loaded LoRAs, and adds them to the state_dict.
Acquires the alias-able buffers for dynamic LoRA swapping.
Must be called to initialize the base model properly.
Parameters:
- model (Module) – The top-level Module.
- state_dict (dict[str, WeightData]) – Model state_dict to be loaded into model.
- device – The device the base model resides on.
Return type:
None
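A sketch of the expected call order, assuming model is the top-level Module and state_dict holds the weights about to be loaded:

manager.init_weights(model, state_dict)  # adds LoRA weights to state_dict in place
# ...then load state_dict into the model as usual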
is_active_lora()
is_active_lora(name)
Returns whether the given name is an active LoRA adapter.
is_lora()
is_lora(name)
Returns whether the given name is a loaded LoRA adapter.
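A sketch combining both checks with activate_adapter, assuming an adapter was previously registered under the name my_adapter:

if manager.is_lora("my_adapter") and not manager.is_active_lora("my_adapter"):
    manager.activate_adapter("my_adapter")  # loaded but inactive, so move it to GPU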
load_adapter()
load_adapter(path)
Loads a single LoRA adapter from the given path and registers it under a unique name.
The path can include an explicit name using the format name=path. If no name is provided, the path itself is used as the name.
lora_id = manager.load_adapter("my_adapter=/path/to/lora")
lora_id = manager.load_adapter("/path/to/another_lora")-
Parameters:
-
path (str) β A string in the form name=path or just a file path. The adapter is expected to reside at that path.
-
Returns:
-
LoRAStatus indicating the result of the load operation.
-
Return type:
loras
Returns the list of loaded LoRA adapter names.
process_lora_requests()
process_lora_requests()
Checks for new LoRA requests and processes them.
Return type:
None
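A hypothetical polling sketch; treating this as a per-step call in the serving loop is an assumption here, and both serving and run_next_batch are placeholders:

while serving:  # serving is a placeholder loop condition
    manager.process_lora_requests()  # drain pending load/unload requests from ZMQ
    run_next_batch()  # placeholder for the surrounding scheduling step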
set_graph_info()
set_graph_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)
Sets the LoRA batch info required for the forward pass.
Parameters:
- lora_ids (TensorValue) – IDs of the LoRAs used in the batch.
- lora_ranks (TensorValue) – Ranks of the LoRAs used in the batch.
- lora_grouped_offsets (TensorValue) – Cumulative offsets for each LoRA group.
- num_active_loras (TensorValue) – Number of active LoRA adapters in the batch.
- lora_end_idx (TensorValue) – End index of the LoRA token portion.
- batch_seq_len (TensorValue) – Total sequence length in the batch.
- lora_ids_kv (TensorValue) – LoRA IDs for the KV cache (includes K and V portions).
- lora_grouped_offsets_kv (TensorValue) – Cumulative offsets for KV LoRA groups.
Return type:
None
sort_lora_batch()
sort_lora_batch(context_batch)
Sorts the LoRA batch by LRU cache id.
unload_adapter()
unload_adapter(name)
Unloads the specified LoRA adapter from the internal registry and frees its slot.
This function is used to release GPU or CPU memory by removing a LoRA model.
manager.unload_adapter("my_adapter")
Parameters:
name (str) – The name of the LoRA adapter to unload.
Returns:
LoRAStatus indicating the result of the unload operation.
Return type:
LoRAStatus