
Python class

LoRAManager​

class max.pipelines.lib.LoRAManager(config, base_model_path, base_dtype, n_heads, n_kv_heads, head_dim)

Bases: object

Applies multiple LoRA adapters to a set of base weights and manages the underlying buffers required for the forward pass.

Initializes the LoRAManager with a given base weight structure and maximum number of LoRA models.

Parameters:

  • config (LoRAConfig) – The LoRA config.
  • base_model_path (str) – The name/path of the base model.
  • base_dtype (DType) – The base model dtype.
  • n_heads (int) – The number of attention heads in the base model.
  • n_kv_heads (int) – The number of key-value heads in the base model.
  • head_dim (int) – The dimension of each attention head.

activate_adapter()​

activate_adapter(name)

Moves the specified LoRA adapter to GPU and marks it as active.

Use this to enable a specific adapter for model inference.

manager.activate_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to activate.

Returns:

None

Raises:

KeyError – If the specified adapter does not exist in the registry.

Return type:

None

bind_graph_inputs()​

bind_graph_inputs(graph_inputs)

Wires the LoRA graph inputs into the model and returns the rest.

The LoRA inputs immediately follow the model’s head inputs, so this peels them off the front of graph_inputs, wires them into the LoRA layers, and returns the remaining (non-LoRA) inputs.

Parameters:

graph_inputs (Sequence[Value[Any]]) – The model’s graph inputs with its head inputs already removed.

Returns:

graph_inputs with the LoRA inputs removed.

Return type:

list[Value[Any]]
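The "peel off the front" behaviour described above can be pictured as a simple slice. This is only an illustrative sketch: peel_front and num_lora_inputs are hypothetical names, plain strings stand in for graph Value objects, and the real method also wires the peeled inputs into the LoRA layers, which is not shown here.

```python
from typing import Any, Sequence


def peel_front(
    graph_inputs: Sequence[Any], num_lora_inputs: int
) -> tuple[list[Any], list[Any]]:
    """Split off the first `num_lora_inputs` items and return (lora, rest)."""
    if num_lora_inputs > len(graph_inputs):
        raise ValueError("fewer graph inputs than LoRA inputs")
    lora_inputs = list(graph_inputs[:num_lora_inputs])
    remaining = list(graph_inputs[num_lora_inputs:])
    return lora_inputs, remaining


# Strings stand in for graph Value objects; the names are made up.
inputs = ["lora_a", "lora_b", "kv_cache", "return_n_logits"]
lora, rest = peel_front(inputs, num_lora_inputs=2)
print(lora)  # ['lora_a', 'lora_b']
print(rest)  # ['kv_cache', 'return_n_logits']
```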

get_lora_graph_inputs()​

get_lora_graph_inputs(context_batch, input_row_offsets, device)

Returns the LoRA graph inputs for the batch.

Parameters:

  • context_batch (Sequence[TextGenerationContextType]) – The generation contexts for the batch.
  • input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – The offsets for each sequence in the batch.
  • device (Device) – The device.

Return type:

tuple[Buffer, …]
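To make input_row_offsets concrete, the sketch below builds offsets from per-sequence token counts. It assumes the common ragged-batch convention in which offsets are cumulative token counts starting at 0, with one trailing entry for the total. This page does not spell out the layout, so treat that as an assumption. row_offsets is a hypothetical helper, and a plain list is used here even though the real parameter is a NumPy integer array.

```python
from itertools import accumulate


def row_offsets(seq_lengths: list[int]) -> list[int]:
    """Return cumulative start offsets, with a final entry for the total."""
    return [0, *accumulate(seq_lengths)]


# Three sequences of 3, 5, and 2 tokens packed back to back.
offsets = row_offsets([3, 5, 2])
print(offsets)  # [0, 3, 8, 10]
```

Under this convention, sequence i occupies rows offsets[i] up to, but not including, offsets[i + 1].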

get_lora_manager()​

static get_lora_manager(pipeline)

Returns the LoRAManager from the pipeline if LoRA is enabled.

Parameters:

pipeline (Pipeline[PipelineInputsType, PipelineOutputType])

Return type:

LoRAManager | None

get_symbolic_inputs()​

get_symbolic_inputs(device_ref)

Returns the input symbols needed for the graph inputs.

Parameters:

device_ref (DeviceRef) – Symbolic device to be used for the symbols.

Returns:

The graph input symbols, ordered to match LoRAInputs.

Return type:

list[TensorType]

init_weights()​

init_weights(model, state_dict)

Recursively collects the leaf SupportsLoRA modules in the model and initializes their weights.

The weights are initialized from the loaded LoRA adapters and added to state_dict. This call also acquires the alias-able buffers used for dynamic LoRA swapping.

This method must be called for the base model to be initialized properly.

Parameters:

  • model (Module) – The top-level Module.
  • state_dict (dict[str, WeightData]) – Model state_dict to be loaded into model.
  • device – The device the base model resides in.

Return type:

None

is_active_lora()​

is_active_lora(name)

Returns whether the given name is an active LoRA adapter.

Parameters:

name (str)

Return type:

bool

is_lora()​

is_lora(name)

Returns whether the given name is a loaded LoRA adapter.

Parameters:

name (str)

Return type:

bool

load_adapter()​

load_adapter(path)

Loads a single LoRA adapter from the given path and registers it under a unique name.

The path can include an explicit name using the format name=path. If no name is provided, the path itself is used as the name.

status = manager.load_adapter("my_adapter=/path/to/lora")
status = manager.load_adapter("/path/to/another_lora")

Parameters:

path (str) – A string in the form name=path or just a file path. The adapter is expected to reside at that path.

Returns:

LoRAStatus indicating the result of the load operation.

Return type:

LoRAStatus
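The name=path convention accepted by load_adapter can be illustrated with a small parser. This is a sketch of the documented format only, not the library's implementation: split_adapter_spec is a hypothetical helper, and splitting on the first "=" is an assumption.

```python
def split_adapter_spec(spec: str) -> tuple[str, str]:
    """Split 'name=path' into (name, path); a bare path names itself."""
    name, sep, path = spec.partition("=")
    if not sep:
        # No '=' present: the path doubles as the adapter name.
        return spec, spec
    return name, path


print(split_adapter_spec("my_adapter=/path/to/lora"))
# ('my_adapter', '/path/to/lora')
print(split_adapter_spec("/path/to/another_lora"))
# ('/path/to/another_lora', '/path/to/another_lora')
```

The second case shows why giving an explicit name is convenient: without one, later calls such as activate_adapter would need the full path as the adapter name.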

loras​

property loras: list[str]

Returns the list of loaded LoRA adapter names.

set_graph_info()​

set_graph_info(lora_inputs)

Wires the LoRA batch info into the LoRA layers for the forward pass.

Parameters:

lora_inputs (Sequence[TensorValue]) – The LoRA graph-input tensors in LoRAInputs order.

Return type:

None

sort_lora_batch()​

sort_lora_batch(context_batch)

Sorts the LoRA batch by LRU cache id.

Parameters:

context_batch (list[TextGenerationContextType]) – The context batch to sort.

Return type:

list[TextGenerationContextType]
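As a rough picture of sorting a batch by cache id, the sketch below orders contexts by a per-adapter slot number. Everything in it is hypothetical: FakeContext, its adapter_name field, and the cache_ids mapping are stand-ins, and how the real manager derives cache ids or treats requests without an adapter is not covered by this page.

```python
from dataclasses import dataclass


@dataclass
class FakeContext:
    """Hypothetical stand-in for a text generation context."""

    request_id: str
    adapter_name: str


def sort_by_cache_id(batch, cache_ids):
    """Return a new list ordered by each context's adapter cache id.

    sorted() is stable, so contexts sharing an adapter keep their order.
    """
    return sorted(batch, key=lambda ctx: cache_ids[ctx.adapter_name])


cache_ids = {"adapter_a": 0, "adapter_b": 1}  # made-up slot assignments
batch = [
    FakeContext("r1", "adapter_b"),
    FakeContext("r2", "adapter_a"),
    FakeContext("r3", "adapter_b"),
]
print([ctx.request_id for ctx in sort_by_cache_id(batch, cache_ids)])
# ['r2', 'r1', 'r3']
```

Sorting this way keeps requests that share an adapter contiguous in the batch.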

unload_adapter()​

unload_adapter(name)

Unloads the specified LoRA adapter from the internal registry and frees its slot.

Use this to release the GPU or CPU memory held by a LoRA adapter.

manager.unload_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to unload.

Returns:

LoRAStatus indicating the result of the unload operation.

Return type:

LoRAStatus