Python class

LoRAManager

class max.pipelines.lora.LoRAManager(config, base_model_path, base_dtype, n_heads, n_kv_heads, head_dim)

source

Bases: object

Manages multiple LoRA models and buffers for the forward pass.

Applies multiple LoRA models to a set of base weights and manages the underlying buffers required for the forward pass.

Initializes the LoRAManager with a given base weight structure and maximum number of LoRA models.

Parameters:

  • config (LoRAConfig) – The LoRA config.
  • base_model_path (str) – The name/path of the base model.
  • base_dtype (DType) – The base model dtype.
  • n_heads (int) – The number of attention heads in the base model.
  • n_kv_heads (int) – The number of key-value heads in the base model.
  • head_dim (int) – The dimension of each attention head.

activate_adapter()​

activate_adapter(name)

source

Moves the specified LoRA adapter to GPU and marks it as active.

Use this to make a specific adapter available for model inference.

manager.activate_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to activate.

Returns:

None

Raises:

KeyError – If the specified adapter does not exist in the registry.

Return type:

None
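Because `activate_adapter()` raises `KeyError` for an unknown adapter, callers may want to turn that exception into a simple success flag. The following is a minimal sketch under that assumption. `SupportsActivate` and `try_activate` are hypothetical names introduced for illustration and are not part of MAX; only the `activate_adapter(name)` call comes from the reference above.

```python
from typing import Protocol


class SupportsActivate(Protocol):
    """Hypothetical protocol: any object with the documented activate_adapter(name)."""

    def activate_adapter(self, name: str) -> None: ...


def try_activate(manager: SupportsActivate, name: str) -> bool:
    """Activate an adapter; return False instead of raising if it is unknown."""
    try:
        manager.activate_adapter(name)
    except KeyError:
        # Documented failure mode: the adapter is not in the registry.
        return False
    return True
```

A caller could then fall back to the base model, or load the adapter first, when `try_activate(manager, "my_adapter")` returns `False`.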

get_lora_graph_inputs()​

get_lora_graph_inputs(context_batch, input_row_offsets, device)

source

Returns the LoRA graph inputs for the batch.

Parameters:

  • context_batch (Sequence[TextGenerationContextType]) – The generation contexts for the batch.
  • input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – The offsets for each sequence in the batch.
  • device (Device) – The device.

Return type:

tuple[Buffer, …]

get_lora_manager()​

static get_lora_manager(pipeline)

source

Returns the LoRAManager from the pipeline if LoRA is enabled.

Parameters:

pipeline (Pipeline[PipelineInputsType, PipelineOutputType])

Return type:

LoRAManager | None
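Since the return type is `LoRAManager | None`, code that consumes the result has to handle the case where LoRA is disabled. A minimal sketch of that pattern follows. `HasLoras` and `loaded_adapter_names` are hypothetical helpers introduced here; the only documented member they rely on is the `loras` property.

```python
from typing import Optional, Protocol


class HasLoras(Protocol):
    """Hypothetical protocol: any object exposing the documented `loras` property."""

    @property
    def loras(self) -> list[str]: ...


def loaded_adapter_names(manager: Optional[HasLoras]) -> list[str]:
    """Return the loaded adapter names, or an empty list when LoRA is disabled."""
    if manager is None:
        # get_lora_manager() returned None: the pipeline has no LoRA support enabled.
        return []
    # Copy the list so callers cannot mutate the manager's own state by accident.
    return list(manager.loras)
```

With MAX installed, the call would presumably look like `loaded_adapter_names(LoRAManager.get_lora_manager(pipeline))`.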

get_symbolic_inputs()​

get_symbolic_inputs(device_ref)

source

Returns the input symbols needed for the graph inputs.

Parameters:

device_ref (DeviceRef) – Symbolic device to be used for the symbols.

Returns:

The graph input symbols.

Return type:

list[TensorType]

init_weights()​

init_weights(model, state_dict)

source

Recursively collects the leaf SupportsLoRA modules in the model, initializes their weights from the loaded LoRA adapters, and adds those weights to the state_dict.

Acquires the alias-able buffers for dynamic LoRA swapping.

Must be called to initialize the base model properly.

Parameters:

  • model (Module) – The top-level Module.
  • state_dict (dict[str, WeightData]) – Model state_dict to be loaded into model.
  • device – The device the base model resides in.

Return type:

None

is_active_lora()​

is_active_lora(name)

source

Returns whether the given name is an active LoRA adapter.

Parameters:

name (str)

Return type:

bool

is_lora()​

is_lora(name)

source

Returns whether the given name is a loaded LoRA adapter.

Parameters:

name (str)

Return type:

bool
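The two predicates distinguish an adapter that is merely loaded (`is_lora`) from one that is also active (`is_active_lora`). A small sketch of how they might be combined is shown below. `ensure_active` is a hypothetical helper, not a MAX function, and it assumes only the three documented methods it calls.

```python
def ensure_active(manager, name: str) -> bool:
    """Activate `name` if it is loaded but not yet active.

    Returns True if the adapter is active afterwards, False if it is not loaded.
    """
    if not manager.is_lora(name):
        # Not in the registry: activating would raise KeyError, so report failure.
        return False
    if not manager.is_active_lora(name):
        manager.activate_adapter(name)
    return True
```

Checking `is_lora()` first avoids the `KeyError` that `activate_adapter()` raises for an unknown name.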

load_adapter()​

load_adapter(path)

source

Loads a single LoRA adapter from the given path and registers it under a unique name.

The path can include an explicit name using the format name=path. If no name is provided, the path itself is used as the name.

status = manager.load_adapter("my_adapter=/path/to/lora")
status = manager.load_adapter("/path/to/another_lora")

Parameters:

path (str) – A string in the form name=path or just a file path. The adapter is expected to reside at that path.

Returns:

LoRAStatus indicating the result of the load operation.

Return type:

LoRAStatus
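The `name=path` convention described above can be illustrated with a short parsing sketch. This is not the library's implementation: `split_adapter_spec` is a hypothetical helper, and splitting on the first `=` is an assumption about how a path containing `=` would be treated.

```python
def split_adapter_spec(spec: str) -> tuple[str, str]:
    """Split "name=path" into (name, path); fall back to the path as the name."""
    # Assumption: only the first "=" separates the name from the path.
    name, sep, path = spec.partition("=")
    if not sep:
        # No explicit name given: the path doubles as the adapter name.
        return spec, spec
    return name, path
```

For example, `split_adapter_spec("my_adapter=/path/to/lora")` yields the name `my_adapter` and the path `/path/to/lora`, while a bare path is returned as both name and path.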

loras​

property loras: list[str]

source

Returns the list of loaded LoRA adapter names.

set_graph_info()​

set_graph_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)

source

Sets the LoRA batch info required for the forward pass.

Parameters:

  • lora_ids (TensorValue) – IDs of the LoRAs used in the batch.
  • lora_ranks (TensorValue) – Ranks of the LoRAs used in the batch.
  • lora_grouped_offsets (TensorValue) – Cumulative offsets for each LoRA group.
  • num_active_loras (TensorValue) – Number of active LoRA adapters in the batch.
  • lora_end_idx (TensorValue) – End index of LoRA token portion.
  • batch_seq_len (TensorValue) – Total sequence length in the batch.
  • lora_ids_kv (TensorValue) – LoRA IDs for KV cache (includes K and V portions).
  • lora_grouped_offsets_kv (TensorValue) – Cumulative offsets for KV LoRA groups.

Return type:

None
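To make "cumulative offsets for each LoRA group" concrete, the sketch below computes such offsets in plain Python for a batch that is already sorted by adapter. It only illustrates the idea of grouped, cumulative token offsets. The actual tensor layout that `set_graph_info()` expects is not specified here, and `grouped_offsets`, its arguments, and the per-sequence length input are assumptions made for illustration.

```python
from itertools import groupby


def grouped_offsets(
    lora_ids: list[int], seq_lens: list[int]
) -> tuple[list[int], list[int]]:
    """Return (adapter ids in batch order, cumulative token offsets per group).

    Assumes sequences using the same adapter are adjacent in the batch.
    """
    ids: list[int] = []
    offsets = [0]
    # groupby merges consecutive sequences that share an adapter id.
    for lora_id, group in groupby(zip(lora_ids, seq_lens), key=lambda p: p[0]):
        ids.append(lora_id)
        offsets.append(offsets[-1] + sum(n for _, n in group))
    return ids, offsets
```

In this toy layout, three sequences with adapter ids `[0, 0, 1]` and lengths `[3, 2, 4]` form two groups, with tokens 0 to 4 belonging to adapter 0 and tokens 5 to 8 to adapter 1.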

sort_lora_batch()​

sort_lora_batch(context_batch)

source

Sorts the LoRA batch by LRU cache ID.

Parameters:

context_batch (list[TextGenerationContextType]) – The context batch to sort.

Return type:

list[TextGenerationContextType]
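Sorting a batch by cache ID places requests that share an adapter next to each other. The sketch below shows that idea with a stable sort. `FakeContext`, its `lora_slot` field, and the choice to place base-model requests last are all assumptions for illustration; the real context type and ordering rules are not described in this reference.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeContext:
    """Hypothetical stand-in for a generation context."""

    request_id: str
    lora_slot: Optional[int]  # hypothetical cache slot; None means no adapter


def sort_by_lora_slot(batch: list[FakeContext]) -> list[FakeContext]:
    """Return a new list ordered by slot, with adapter-free requests last."""
    # sorted() is stable, so requests sharing a slot keep their relative order.
    return sorted(batch, key=lambda c: (c.lora_slot is None, c.lora_slot or 0))
```

Because the sort is stable, two requests that use the same adapter stay in their original submission order.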

unload_adapter()​

unload_adapter(name)

source

Unloads the specified LoRA adapter from the internal registry and frees its slot.

Use this to release the GPU or CPU memory held by a LoRA adapter that is no longer needed.

manager.unload_adapter("my_adapter")

Parameters:

name (str) – The name of the LoRA adapter to unload.

Returns:

LoRAStatus indicating the result of the unload operation.

Return type:

LoRAStatus