Python module

lora

AttentionWithRopeAndLoRA

class max.nn.lora.AttentionWithRopeAndLoRA(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, max_lora_rank, max_num_loras, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None)

Initializes the LoRA-enabled attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (list[DeviceRef] | None) – Device to place the weights and run the computation. If multiple are provided, the first device is used. Use TensorParallelAttentionWithRope to use all devices during attention computation.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the outputs dense layer.
  • stacked_qkv (bool) – Whether the weights are stacked together.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped to the range [-clip_qkv, clip_qkv].
  • max_lora_rank (int)
  • max_num_loras (int)
  • float8_config (Float8Config | None)
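
A minimal construction sketch (illustrative only): `rope` and `kv_params` are assumed to be a RotaryEmbedding and a KVCacheParams built elsewhere in the model, and the head counts, sizes, and LoRA limits are hypothetical values, not defaults.

from max.dtype import DType
from max.graph import DeviceRef
from max.nn.lora import AttentionWithRopeAndLoRA

attn = AttentionWithRopeAndLoRA(
    rope=rope,                  # RotaryEmbedding shared with the rest of the model
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,        # KVCacheParams describing the KV cache layout
    max_lora_rank=16,
    max_num_loras=100,
    devices=[DeviceRef.GPU()],
    dtype=DType.bfloat16,
)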

rope

rope: RotaryEmbedding

LinearLoRA

class max.nn.lora.LinearLoRA(in_dim, out_dim, max_num_loras, max_lora_rank, dtype, device, has_lora_bias=False, name=None, quantization_encoding=None)

Applies a linear transformation and LoRA to input:

y_l = (x @ A^T) @ B^T
y = (x @ W^T + b) + y_l

Example:

from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.lora import LinearLoRA

linear_layer = LinearLoRA(
    in_dim=256,
    out_dim=128,
    max_lora_rank=16,
    max_num_loras=100,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    has_bias=True,
    has_lora_bias=True,
    name="lora_linear",
)

# Per-batch LoRA metadata (placeholder declarations):
lora_ids: TensorValue       # shape: [max_num_loras,]
lora_ranks: TensorValue     # shape: [max_num_loras,]
input_row_offsets: TensorValue
linear_layer.set_lora_batch_info(lora_ids, lora_ranks, input_row_offsets)

# Apply the base linear projection plus the active LoRA adapters.
input_tensor: TensorValue
output = linear_layer(input_tensor)
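
For intuition, the same computation written out in NumPy (an illustrative sketch of the formula above, not the MAX implementation; the dense array shapes are assumptions for a single adapter):

import numpy as np

# x: [tokens, in_dim], W: [out_dim, in_dim], b: [out_dim]
# A: [rank, in_dim],   B: [out_dim, rank]   (one adapter's LoRA weights)
def lora_linear(x, W, b, A, B):
    y_l = (x @ A.T) @ B.T         # low-rank LoRA update
    return (x @ W.T + b) + y_l    # base projection plus the LoRA delta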

Parameters:

  • in_dim (int)
  • out_dim (int)
  • max_num_loras (int)
  • max_lora_rank (int)
  • dtype (DType)
  • device (DeviceRef)
  • has_lora_bias (bool)
  • name (str | None)
  • quantization_encoding (QuantizationEncoding | None)

set_lora_batch_info()

set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)

Parameters:

Return type:

None

SupportsLoRA

class max.nn.lora.SupportsLoRA(*args, **kwargs)

Base class for supporting LoRA functionality in Modules.

set_lora_batch_info()

set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets, num_active_loras, lora_end_idx, batch_seq_len, lora_ids_kv, lora_grouped_offsets_kv)

Parameters:

Return type:

None
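
A hedged sketch of how a module might opt into this interface by subclassing SupportsLoRA. The `lora_layers` attribute and the forwarding pattern are assumptions for illustration; only the set_lora_batch_info signature comes from this page.

from max.nn.lora import SupportsLoRA

class MyLoRADecoderLayer(SupportsLoRA):
    def set_lora_batch_info(
        self,
        lora_ids,
        lora_ranks,
        lora_grouped_offsets,
        num_active_loras,
        lora_end_idx,
        batch_seq_len,
        lora_ids_kv,
        lora_grouped_offsets_kv,
    ) -> None:
        # Forward the per-batch LoRA metadata to each LoRA-enabled sublayer
        # (`self.lora_layers` is a hypothetical attribute on this module).
        for layer in self.lora_layers:
            layer.set_lora_batch_info(
                lora_ids,
                lora_ranks,
                lora_grouped_offsets,
                num_active_loras,
                lora_end_idx,
                batch_seq_len,
                lora_ids_kv,
                lora_grouped_offsets_kv,
            )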
