
LinearLoRA

class max.nn.LinearLoRA(in_dim, out_dim, max_num_loras, max_lora_rank, dtype, device, has_lora_bias=False, name=None, quantization_encoding=None)


Bases: Module, SupportsLoRA

Applies a linear transformation and a LoRA update to the input:

y_l = (x @ A^T) @ B^T

y = (x @ W^T + b) + y_l
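The two formulas above can be checked with a plain NumPy sketch. Shapes are illustrative only; `A` is the rank-r down-projection and `B` the up-projection, both hypothetical stand-ins for the layer's LoRA weights:

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim, rank, batch = 256, 128, 16, 4

x = rng.standard_normal((batch, in_dim))
W = rng.standard_normal((out_dim, in_dim))   # frozen base weight
b = rng.standard_normal(out_dim)             # base bias
A = rng.standard_normal((rank, in_dim))      # LoRA down-projection
B = rng.standard_normal((out_dim, rank))     # LoRA up-projection

y_l = (x @ A.T) @ B.T          # low-rank LoRA path
y = (x @ W.T + b) + y_l        # base linear path plus LoRA update

# The sum is equivalent to a single linear layer whose weight is W + B @ A,
# since x @ A.T @ B.T == x @ (B @ A).T:
y_merged = x @ (W + B @ A).T + b
assert np.allclose(y, y_merged)
```

This is why a LoRA adapter can be merged into the base weight for inference: the update never leaves the rank-r subspace spanned by `B @ A`.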

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import LinearLoRA

linear_layer = LinearLoRA(
    in_dim=256,
    out_dim=128,
    max_lora_rank=16,
    max_num_loras=100,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    has_bias=True,
    has_lora_bias=True,
    name="lora_linear",
)
```

```python
lora_ids: TensorValue  # shape: [max_num_loras]
lora_ranks: TensorValue  # shape: [max_num_loras]
input_row_offsets: TensorValue
linear_layer.set_lora_batch_info(lora_ids, lora_ranks, input_row_offsets)

input_tensor: TensorValue
output = linear_layer(input_tensor)
```

Parameters:

- in_dim (int) – Input dimension of the linear layer.
- out_dim (int) – Output dimension of the linear layer.
- max_num_loras (int) – Maximum number of LoRA adapters that can be active in a batch.
- max_lora_rank (int) – Maximum rank of any LoRA adapter.
- dtype (DType) – Data type of the layer's weights.
- device (DeviceRef) – Device on which the weights are placed.
- has_lora_bias (bool) – Whether the LoRA path adds a bias term. Defaults to False.
- name (str | None) – Optional name for the layer.
- quantization_encoding – Optional quantization encoding for the weights.
set_lora_batch_info()

set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)


Parameters:

- lora_ids (TensorValue) – IDs of the LoRA adapters active in the current batch.
- lora_ranks (TensorValue) – Rank of each active LoRA adapter.
- lora_grouped_offsets (TensorValue) – Row offsets grouping the packed input rows by adapter.
- num_active_loras – Number of LoRA adapters active in the batch.
- lora_end_idx – Index of the last input row that uses a LoRA adapter.
- batch_seq_len – Total sequence length of the packed batch.
- lora_ids_kv – Adapter IDs for the key/value projections.
- lora_grouped_offsets_kv – Grouped row offsets for the key/value projections.

Return type:

None
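As a hedged illustration of the grouped-offsets convention (the names and semantics here are assumed from the tensor shapes shown in the usage example above, not taken from the MAX source), the offsets can be read as row boundaries that assign each packed request to an adapter:

```python
import numpy as np

# Hypothetical packed batch: 3 requests with sequence lengths 4, 2, 3,
# flattened into one ragged tensor of 9 rows.
lora_ids = np.array([2, 0, 2])                 # assumed: adapter id per request
lora_ranks = np.array([8, 16, 8])              # assumed: active rank per request
lora_grouped_offsets = np.array([0, 4, 6, 9])  # assumed: row boundary per request

# Rows for request i span lora_grouped_offsets[i] : lora_grouped_offsets[i + 1].
mapping = [
    (int(lora_grouped_offsets[i]), int(lora_grouped_offsets[i + 1]), int(lid))
    for i, lid in enumerate(lora_ids)
]
print(mapping)  # [(0, 4, 2), (4, 6, 0), (6, 9, 2)]
```

Under this reading, every row in a half-open offset interval is routed through the same LoRA adapter, which is what lets a single batched matmul serve requests with different adapters.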