
Write a weight adapter

Nearly all model architectures require a weight adapter. Write a weight adapter to convert checkpoint weights (most often published on Hugging Face) into the format expected by your MAX Module.

What a weight adapter does

A weight adapter is a function you define in weight_adapters.py that converts checkpoint weights into a state dictionary that you can load into a MAX Module.

Depending on how the checkpoint differs from your implementation, a weight adapter can perform one or more of the following transformations:

  • Rename keys to match the parameter names expected by the Module. In many architectures, renaming keys is the only transformation required.
  • Cast dtypes to the types expected by the model.
  • Remove weights that exist only in the checkpoint and not in the Module.
  • Transform tensor layouts by reshaping, reordering, or repacking tensors into the format expected by the model.

After you implement a weight adapter, register it in the architecture's SupportedArchitecture definition in arch.py, keyed by the corresponding WeightsFormat:

weight_adapters={
    WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
    WeightsFormat.gguf: weight_adapters.convert_gguf_state_dict,
},

During serving, MAX identifies the checkpoint format, selects the corresponding adapter, and provides it to the pipeline model. The pipeline model applies the adapter to the loaded checkpoint and passes the resulting state dictionary to load_state_dict(), which loads the tensors into the model's Module hierarchy. See Model pipelines to learn more about implementing the pipeline model class.
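
The selection step works like a lookup keyed by checkpoint format. The following is a minimal plain-Python sketch, not MAX's implementation. WeightsFormatStub, the two toy adapters, WEIGHT_ADAPTERS, and select_and_apply() are hypothetical stand-ins, and lists take the place of tensors.

```python
from enum import Enum
from typing import Any, Callable


class WeightsFormatStub(Enum):
    """Hypothetical stand-in for MAX's WeightsFormat enum."""
    safetensors = "safetensors"
    gguf = "gguf"


Adapter = Callable[..., dict[str, Any]]


def toy_safetensors_adapter(state_dict: dict[str, Any], **unused_kwargs) -> dict[str, Any]:
    # Hypothetical adapter: strip a leading "model." from every key.
    return {name.removeprefix("model."): value for name, value in state_dict.items()}


def toy_gguf_adapter(state_dict: dict[str, Any], **unused_kwargs) -> dict[str, Any]:
    # Hypothetical adapter: keep names unchanged.
    return dict(state_dict)


# Mirrors the shape of the weight_adapters= mapping registered in arch.py.
WEIGHT_ADAPTERS: dict[WeightsFormatStub, Adapter] = {
    WeightsFormatStub.safetensors: toy_safetensors_adapter,
    WeightsFormatStub.gguf: toy_gguf_adapter,
}


def select_and_apply(fmt: WeightsFormatStub, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
    """Look up the adapter registered for fmt and apply it (hypothetical)."""
    try:
        adapter = WEIGHT_ADAPTERS[fmt]
    except KeyError:
        raise ValueError(f"no weight adapter registered for {fmt}") from None
    return adapter(state_dict, **kwargs)


checkpoint = {"model.embed_tokens.weight": [0.0, 1.0]}
adapted = select_and_apply(WeightsFormatStub.safetensors, checkpoint)
print(adapted)  # {'embed_tokens.weight': [0.0, 1.0]}
```

Keeping one adapter per format means adding GGUF support doesn't change the safetensors path. This sketch raises ValueError for an unregistered format.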

Checkpoint and module weight formats

Before you write a weight adapter, inspect both the checkpoint and your Module to understand what transformations you need to implement to make the checkpoint compatible.

To inspect the checkpoint, pass it to load_weights() and iterate over the entries to print each tensor's name, shape, and dtype:

from pathlib import Path
from max.graph.weights import load_weights

weights = load_weights([Path("/path/to/model.safetensors")])  # or .gguf
for name, weight in weights.items():
    data = weight.data()
    shape = [int(d) for d in data.shape]
    print(f"{name}: {shape} {data.dtype}")

The output may look similar to the following:

model.embed_tokens.weight: [vocab_size, hidden_size] DType.bfloat16
model.layers.0.input_layernorm.weight: [hidden_size] DType.bfloat16
model.layers.0.self_attn.q_proj.weight: [hidden_size, hidden_size] DType.bfloat16
...

To inspect the format your Module expects, call raw_state_dict() on your Module instance (often named model). It returns a dictionary that maps each parameter name to its Weight, which records the weight's shape, dtype, and quantization encoding:

for name, weight in model.raw_state_dict().items():
    print(f"{name}: {weight.shape} {weight.dtype} {weight.quantization_encoding}")

The output may look similar to the following:

layers.0.fc1.weight: [ffn_size, hidden_size] DType.bfloat16 None
layers.0.fc2.weight: [hidden_size, ffn_size] DType.bfloat16 None
layers.0.attention.q_proj.weight: [hidden_size, hidden_size] DType.bfloat16 None
...
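
With both listings available, a short script can suggest which transformation each tensor needs. This plain-Python sketch does not call MAX. It assumes three things:
  • You have copied each listing into a dict that maps a name to a (shape, dtype name) tuple.
  • The checkpoint names have already been mapped to Module names.
  • plan_transformations() and its labels are hypothetical.

```python
import math

# (shape, dtype name) for one tensor, copied from the listings above.
Spec = tuple[tuple[int, ...], str]


def plan_transformations(
    checkpoint: dict[str, Spec], module: dict[str, Spec]
) -> dict[str, list[str]]:
    """Suggest transformations for each tensor name (hypothetical helper)."""
    plan: dict[str, list[str]] = {}
    for name in sorted(checkpoint.keys() | module.keys()):
        if name not in module:
            plan[name] = ["remove (not in Module)"]
            continue
        if name not in checkpoint:
            plan[name] = ["missing from checkpoint"]
            continue
        ckpt_shape, ckpt_dtype = checkpoint[name]
        mod_shape, mod_dtype = module[name]
        steps: list[str] = []
        if ckpt_dtype != mod_dtype:
            steps.append(f"cast {ckpt_dtype} -> {mod_dtype}")
        if ckpt_shape != mod_shape:
            # A reversed 2-D shape suggests a transpose; equal element
            # counts suggest a reshape; anything else needs repacking.
            if len(ckpt_shape) == 2 and ckpt_shape[::-1] == mod_shape:
                steps.append("transpose")
            elif math.prod(ckpt_shape) == math.prod(mod_shape):
                steps.append("reshape")
            else:
                steps.append("repack (element counts differ)")
        plan[name] = steps or ["none"]
    return plan


checkpoint = {
    "layers.0.fc1.weight": ((64, 16), "float16"),
    "layers.0.fc2.weight": ((64, 16), "bfloat16"),
    "layers.0.attention.k_scale": ((), "float32"),
}
module = {
    "layers.0.fc1.weight": ((64, 16), "bfloat16"),
    "layers.0.fc2.weight": ((16, 64), "bfloat16"),
}
for name, steps in plan_transformations(checkpoint, module).items():
    print(name, steps)
```

Shapes alone can't reveal a transposed square matrix or a permutation within one dimension, such as the RoPE case under Tensor transformations. Treat the output as a starting point, not a complete plan.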

Name mapping

The most common job of the weight adapter is to map checkpoint parameter names to the names expected by your MAX Module.

To do this, define a weight adapter function that:

  1. Accepts the checkpoint state dictionary.
  2. Iterates over the checkpoint state dictionary.
  3. Computes the parameter name to match your Module.
  4. Builds and returns a new state dictionary that maps the new parameter names to the original weight tensors.

In many cases, the mapping is straightforward because both Hugging Face and MAX derive parameter names from a module hierarchy. If the model architectures are similar, the adapter may only need to remove a common prefix or rename a small number of modules.

For example, many Hugging Face safetensors checkpoints store weights under a top-level "model." prefix that doesn't exist in your MAX Module, and some module names differ between the two. The following convert_safetensor_state_dict() adapter strips that prefix and renames self_attn.q_proj to attention.q_proj:

from max.graph.weights import WeightData, Weights

SAFETENSOR_MAPPING = {
    "model.": "",                            # Strip this string
    "self_attn.q_proj": "attention.q_proj",  # Rename this string
}


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()
    return new_state_dict
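
The example above uses str.replace(), which substitutes every occurrence of a pattern, not only a leading prefix. A key such as language_model.layers.0.mlp.weight would therefore lose its inner "model." and become language_layers.0.mlp.weight. If your checkpoint has nested names like that, one option is to anchor the rules with regular expressions and fail loudly when two checkpoint keys map to the same Module name. This sketch works on names only. RENAME_RULES, map_name(), and map_names() are hypothetical; you would call map_name() inside the adapter loop in place of the replace() calls.

```python
import re

# Ordered (pattern, replacement) rules; hypothetical, for illustration.
RENAME_RULES: list[tuple[re.Pattern, str]] = [
    (re.compile(r"^model\."), ""),  # strip "model." only at the start
    (re.compile(r"\.self_attn\.q_proj\."), ".attention.q_proj."),
]


def map_name(checkpoint_name: str) -> str:
    """Apply each rule, in order, to one checkpoint key."""
    name = checkpoint_name
    for pattern, replacement in RENAME_RULES:
        name = pattern.sub(replacement, name)
    return name


def map_names(checkpoint_names: list[str]) -> dict[str, str]:
    """Return {max_name: checkpoint_name}, raising if two keys collide."""
    mapping: dict[str, str] = {}
    for checkpoint_name in checkpoint_names:
        max_name = map_name(checkpoint_name)
        if max_name in mapping:
            raise ValueError(
                f"{checkpoint_name!r} and {mapping[max_name]!r} both map to {max_name!r}"
            )
        mapping[max_name] = checkpoint_name
    return mapping


names = [
    "model.layers.0.self_attn.q_proj.weight",
    "language_model.layers.0.mlp.weight",
]
print(map_names(names))
```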

Dtype casting

Every weight in a Module has an expected dtype. You set that dtype when you construct the Module, and it must match the dtype required by the MAX kernel that consumes the weight. A checkpoint might store the same weight in a different dtype than the kernel expects.

When the dtypes differ, cast the weight in your weight adapter before you load it into the Module. If you don't, load_state_dict() raises a ValueError.

The following weight adapter renames keys as before, then casts float16 weights to bfloat16 while leaving bias and scale tensors in their original dtype:

from max.dtype import DType
from max.graph.weights import WeightData, Weights


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Cast weights to a dtype the kernel supports
    for key, weight_data in new_state_dict.items():
        if weight_data.dtype == DType.float16 and not (
            key.endswith("bias") or key.endswith("scales")
        ):
            new_state_dict[key] = weight_data.astype(DType.bfloat16)

    return new_state_dict
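
It helps to know what a float16 to bfloat16 cast does to the values. Both formats use 16 bits, but they divide them differently:
  • float16 has 5 exponent bits and 10 mantissa bits.
  • bfloat16 has 8 exponent bits and 7 mantissa bits.
As a result, every float16 value fits within bfloat16's range but may lose precision. The NumPy sketch below illustrates round-to-nearest-even conversion through float32. NumPy has no native bfloat16 dtype, so the sketch works with raw uint16 bit patterns. It is an illustration only, not how WeightData.astype() is implemented, and to_bfloat16_bits() and from_bfloat16_bits() are hypothetical helpers.

```python
import numpy as np


def to_bfloat16_bits(x: np.ndarray) -> np.ndarray:
    """Round values to bfloat16 (nearest, ties to even) and return uint16 bit patterns."""
    f32 = np.atleast_1d(np.asarray(x, dtype=np.float32))  # float16 -> float32 is exact
    bits = f32.view(np.uint32)
    # bfloat16 keeps the high 16 bits of a float32. Adding 0x7FFF plus the
    # lowest kept bit before truncating rounds to nearest, with ties to even.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = ((bits + np.uint32(0x7FFF) + lsb) >> np.uint32(16)).astype(np.uint16)
    # Rounding can turn some NaN payloads into inf, so emit a quiet NaN instead.
    return np.where(np.isnan(f32), np.uint16(0x7FC0), rounded)


def from_bfloat16_bits(b: np.ndarray) -> np.ndarray:
    """Widen bfloat16 bit patterns back to float32 (exact)."""
    return (np.asarray(b, dtype=np.uint32) << np.uint32(16)).view(np.float32)


x = np.array([1.0, 1.00390625, 1.01171875, 65504.0], dtype=np.float16)
y = from_bfloat16_bits(to_bfloat16_bits(x))
print(y.tolist())  # [1.0, 1.0, 1.015625, 65536.0]
```

For example, 65504, the largest finite float16 value, rounds up to 65536 in bfloat16.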

Use the weight adapter for tensor-level dtype casts like this one. Leave layer-level quantization encoding decisions to QuantFormat and QuantConfig.

Remove unused weights

A checkpoint and your Module don't always contain the same set of weights. A mismatch can occur in two ways:

  • Extra weights. The checkpoint contains weights that your Module doesn't declare. Remove these in the weight adapter.
  • Missing weights. Your Module declares weights that aren't present in the checkpoint. These values typically come from somewhere other than the checkpoint (for example, parameters that MAX derives from other weights), so you usually don't handle them in the weight adapter.

For example, FP8 checkpoints often include .k_scale and .v_scale tensors used to quantize the key and value caches. MAX computes these scales at runtime, so the adapter drops them from the state dictionary it returns:

from max.graph.weights import WeightData, Weights

UNUSED_SUFFIXES = (".k_scale", ".v_scale")


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Drop FP8 KV-cache scales
    for key in list(new_state_dict):
        if key.endswith(UNUSED_SUFFIXES):
            del new_state_dict[key]

    return new_state_dict
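
You may want a record of what the adapter discards. One way is to filter in the same loop that renames keys and return the dropped names alongside the result. Logging or asserting on that list makes it easier to notice when a checkpoint variant contains unexpected extra tensors. This plain-Python sketch uses lists in place of WeightData and takes a rename callable. split_unused() is hypothetical, not part of MAX.

```python
from typing import Callable, TypeVar

T = TypeVar("T")

UNUSED_SUFFIXES = (".k_scale", ".v_scale")


def split_unused(
    state_dict: dict[str, T],
    rename: Callable[[str], str],
    unused_suffixes: tuple[str, ...] = UNUSED_SUFFIXES,
) -> tuple[dict[str, T], list[str]]:
    """Rename the kept keys and collect the checkpoint names that were dropped."""
    kept: dict[str, T] = {}
    dropped: list[str] = []
    for checkpoint_name, value in state_dict.items():
        if checkpoint_name.endswith(unused_suffixes):
            dropped.append(checkpoint_name)
            continue
        kept[rename(checkpoint_name)] = value
    return kept, dropped


ckpt = {
    "model.layers.0.self_attn.k_proj.weight": [1.0],
    "model.layers.0.self_attn.k_scale": [0.5],
    "model.layers.0.self_attn.v_scale": [0.25],
}
kept, dropped = split_unused(ckpt, lambda name: name.removeprefix("model."))
print(list(kept), dropped)
```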

Tensor transformations

Transform tensor values only when the checkpoint representation differs from the format expected by a MAX kernel. When possible, adapt your Module instead of modifying checkpoint data. Tensor transformations encode assumptions about a specific kernel implementation, and you may need to update them if the kernel interface changes.

Common cases include:

  • Partial-RoPE permutation. Some models use a partial-RoPE layout where only a subset of attention dimensions participate in rotary embeddings. If the checkpoint layout differs from the layout expected by the MAX kernel, the adapter must permute the query and key projection weights before loading.
  • Quantization-specific transformations. Quantized checkpoints may store weights in a layout that differs from the dequantization kernel's expected format. For example, MXFP4 checkpoints require weight repacking, and GPTQ checkpoints store a g_idx permutation that you must convert to the inverse perm_idx representation expected by the kernel.
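
For the GPTQ case, NumPy can illustrate how a per-column group index relates to a permutation. The sketch assumes g_idx[i] is the quantization group of input column i, as in act-order GPTQ checkpoints. A stable argsort gives a permutation that brings columns of the same group together, and argsort of that permutation gives its inverse. This page doesn't specify which of the two a particular MAX kernel expects, or in what dtype. Treat gidx_to_perms() and its outputs as hypothetical and check the kernel's requirements.

```python
import numpy as np


def gidx_to_perms(g_idx: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (perm, inverse_perm) for a per-column group index array."""
    g_idx = np.asarray(g_idx)
    # A stable sort keeps the original column order within each group.
    perm = np.argsort(g_idx, kind="stable")
    # inverse_perm[perm[i]] == i, so indexing with inverse_perm undoes perm.
    inverse_perm = np.argsort(perm, kind="stable")
    return perm, inverse_perm


g_idx = np.array([1, 0, 1, 0])
perm, inverse_perm = gidx_to_perms(g_idx)
print(perm, inverse_perm)  # [1 3 0 2] [2 0 3 1]
```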

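This example permutes the query and key projection weights so their layout matches what the MAX RoPE kernel expects. It reuses SAFETENSOR_MAPPING from earlier and calls build_rope_perm(), an architecture-specific helper that you write for your model: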

import numpy as np
from max.graph.weights import WeightData, Weights
from transformers import AutoConfig


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights],
    huggingface_config: AutoConfig,
    **unused_kwargs,
) -> dict[str, WeightData]:
    num_heads = huggingface_config.num_attention_heads
    head_dim = huggingface_config.hidden_size // num_heads
    num_kv_heads = getattr(
        huggingface_config, "num_key_value_heads", num_heads
    )
    rope_perm = build_rope_perm(head_dim)  # architecture-specific helper

    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Permute weights
    for key, weight_data in new_state_dict.items():
        if key.endswith("q_proj.weight"):
            heads = num_heads
        elif key.endswith("k_proj.weight"):
            heads = num_kv_heads
        else:
            continue
        # np.from_dlpack() needs a dtype NumPy supports; NumPy has no native bfloat16
        arr = np.from_dlpack(weight_data.data)
        arr = arr.reshape(heads, head_dim, -1)
        arr = arr[:, rope_perm, :]
        arr = np.ascontiguousarray(arr.reshape(heads * head_dim, -1))
        new_state_dict[key] = WeightData.from_numpy(arr, key)

    return new_state_dict
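
The example leaves build_rope_perm() to you because the correct permutation depends on both the checkpoint layout and the kernel layout. Consider one hypothetical case with three assumptions:
  • The checkpoint stores each rotary pair interleaved (dimensions 0 and 1 form a pair, then 2 and 3, and so on).
  • The kernel expects the first elements of all pairs followed by the second elements.
  • Only the first rotary_dim dimensions of each head are rotary (partial RoPE).
Under those assumptions the helper could look like the following sketch. The rotary_dim parameter is an assumed addition that defaults to head_dim, so the one-argument call above still works. Verify the direction against your kernel before relying on it.

```python
from typing import Optional

import numpy as np


def build_rope_perm(head_dim: int, rotary_dim: Optional[int] = None) -> np.ndarray:
    """Hypothetical: reorder interleaved rotary pairs into a split-halves layout.

    Output position j reads input position perm[j], which matches the
    arr[:, rope_perm, :] gather in the adapter.
    """
    rotary_dim = head_dim if rotary_dim is None else rotary_dim
    if rotary_dim % 2 or not 0 < rotary_dim <= head_dim:
        raise ValueError("rotary_dim must be even and in (0, head_dim]")
    rotary = np.concatenate([
        np.arange(0, rotary_dim, 2),  # first element of each pair
        np.arange(1, rotary_dim, 2),  # second element of each pair
    ])
    passthrough = np.arange(rotary_dim, head_dim)  # non-rotary dims keep their order
    return np.concatenate([rotary, passthrough])


# Apply it the same way the adapter does, on a toy (heads * head_dim, hidden) matrix.
heads, head_dim, hidden = 2, 4, 3
w = np.arange(heads * head_dim * hidden).reshape(heads * head_dim, hidden)
perm = build_rope_perm(head_dim)
out = w.reshape(heads, head_dim, -1)[:, perm, :].reshape(heads * head_dim, -1)
print(out[:, 0].tolist())  # [0, 6, 3, 9, 12, 18, 15, 21]
```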

Correctness checks

Compare the keys produced by the weight adapter with the keys your Module expects:

from pathlib import Path

from max.graph.weights import load_weights
from transformers import AutoConfig

from my_arch.weight_adapters import convert_safetensor_state_dict

weights = load_weights([Path("/path/to/model.safetensors")])
hf_config = AutoConfig.from_pretrained("org/model-id")

adapted = convert_safetensor_state_dict(
    dict(weights.items()),
    huggingface_config=hf_config,
)
for name in sorted(adapted):
    print(name)

Compare this list against the names returned by your Module's raw_state_dict(). Intentional differences may include:

  • Tensors removed because MAX computes the value at runtime.
  • Parameters omitted because MAX derives them from other weights.
  • Format-specific tensors that your Module doesn't use.
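
Rather than comparing the two lists by eye, you can compute the set differences and exclude the ones you expect. This sketch takes plain iterables of names, such as the keys of the adapted dictionary and of model.raw_state_dict(). It therefore runs without MAX. diff_keys() and its expected_missing and expected_extra parameters are hypothetical.

```python
from collections.abc import Iterable


def diff_keys(
    adapted: Iterable[str],
    expected: Iterable[str],
    expected_missing: Iterable[str] = (),
    expected_extra: Iterable[str] = (),
) -> dict[str, list[str]]:
    """Report unexpected differences between adapter output and Module names."""
    adapted_set, expected_set = set(adapted), set(expected)
    # Declared by the Module but absent from the adapter output.
    missing = expected_set - adapted_set - set(expected_missing)
    # Produced by the adapter but not declared by the Module.
    extra = adapted_set - expected_set - set(expected_extra)
    return {"missing": sorted(missing), "extra": sorted(extra)}


report = diff_keys(
    adapted=["layers.0.fc1.weight", "layers.0.attention.q_proj.weight", "layers.0.fc3.weight"],
    expected=["layers.0.fc1.weight", "layers.0.attention.q_proj.weight", "lm_head.weight"],
    # Hypothetical: this model derives lm_head.weight from another weight.
    expected_missing=["lm_head.weight"],
)
print(report)  # {'missing': [], 'extra': ['layers.0.fc3.weight']}
```

Because iterating a dict yields its keys, you can pass the adapted dictionary and model.raw_state_dict() to diff_keys() directly.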

Next steps

Once your adapter loads successfully, complete the remaining model bring-up steps.

For the complete workflow, see Model bring-up workflow.
