Write a weight adapter
Nearly all model architectures require a weight adapter. Write a weight adapter to convert checkpoint weights (most often published on Hugging Face) into the format expected by your MAX Module.
What a weight adapter does
A weight adapter is a function you define in weight_adapters.py that converts checkpoint weights into a state dictionary that you can load into a MAX Module.
Depending on how the checkpoint differs from your implementation, a weight adapter can perform one or more of the following transformations:
- Rename keys to match the parameter names expected by the Module. In many architectures, renaming keys is the only transformation required.
- Cast dtypes to the types expected by the model.
- Remove weights that exist only in the checkpoint and not in the Module.
- Transform tensor layouts by reshaping, reordering, or repacking tensors into the format expected by the model.
After you implement a weight adapter, register it in the architecture's SupportedArchitecture definition in arch.py, keyed by the corresponding WeightsFormat:
```python
weight_adapters={
    WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
    WeightsFormat.gguf: weight_adapters.convert_gguf_state_dict,
},
```
During serving, MAX identifies the checkpoint format, selects the corresponding adapter, and provides it to the pipeline model. The pipeline model applies the adapter to the loaded checkpoint and passes the resulting state dictionary to load_state_dict(), which loads the tensors into the model's Module hierarchy. See Model pipelines to learn more about implementing the pipeline model class.
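Conceptually, the weight_adapters mapping is a lookup table from checkpoint format to converter function. The following stdlib-only sketch mimics that flow under stated assumptions: CheckpointFormat is a stand-in for WeightsFormat, detect_format() is a hypothetical helper that guesses the format from the file extension, and both adapters are toys. It illustrates the idea only and is not MAX's actual dispatch code.

```python
from enum import Enum
from pathlib import Path
from typing import Callable


class CheckpointFormat(Enum):
    # Stand-in for MAX's WeightsFormat enum (hypothetical, for illustration only).
    safetensors = "safetensors"
    gguf = "gguf"


def convert_safetensor_state_dict(
    state_dict: dict[str, object], **unused_kwargs
) -> dict[str, object]:
    # Toy adapter: strip a leading "model." from every key.
    return {k.removeprefix("model."): v for k, v in state_dict.items()}


def convert_gguf_state_dict(
    state_dict: dict[str, object], **unused_kwargs
) -> dict[str, object]:
    # Toy adapter: GGUF names pass through unchanged in this sketch.
    return dict(state_dict)


# Registration: one adapter per checkpoint format.
WEIGHT_ADAPTERS: dict[CheckpointFormat, Callable[..., dict[str, object]]] = {
    CheckpointFormat.safetensors: convert_safetensor_state_dict,
    CheckpointFormat.gguf: convert_gguf_state_dict,
}


def detect_format(path: Path) -> CheckpointFormat:
    # Hypothetical helper: infer the format from the file extension.
    return CheckpointFormat(path.suffix.lstrip("."))


def adapt_checkpoint(path: Path, state_dict: dict[str, object]) -> dict[str, object]:
    # Select the adapter registered for this format and apply it.
    adapter = WEIGHT_ADAPTERS[detect_format(path)]
    return adapter(state_dict)


adapted = adapt_checkpoint(
    Path("model.safetensors"), {"model.layers.0.mlp.weight": "tensor"}
)
print(adapted)  # {'layers.0.mlp.weight': 'tensor'}
```

Keeping the adapters in a dictionary keyed by format means that supporting a new checkpoint format only requires registering one more function.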
Checkpoint and module weight formats
Before you write a weight adapter, inspect both the checkpoint and your Module
to understand what transformations you need to implement to make the checkpoint
compatible.
Print the checkpoint
To inspect the checkpoint, pass the checkpoint file path to load_weights() and iterate over the entries to print each tensor's name, shape, and dtype:
```python
from pathlib import Path

from max.graph.weights import load_weights

weights = load_weights([Path("/path/to/model.safetensors")])  # or .gguf
for name, weight in weights.items():
    data = weight.data()
    shape = [int(d) for d in data.shape]
    print(f"{name}: {shape} {data.dtype}")
```
The output may look similar to the following:
```
model.embed_tokens.weight: [vocab_size, hidden_size] DType.bfloat16
model.layers.0.input_layernorm.weight: [hidden_size] DType.bfloat16
model.layers.0.self_attn.q_proj.weight: [hidden_size, hidden_size] DType.bfloat16
...
```
Print the module state dictionary
To inspect the target format your Module expects, call raw_state_dict() on your module instance (often named model). It returns a dictionary that maps each parameter name to its Weight, which records the weight's shape, dtype, and quantization encoding:
```python
for name, weight in model.raw_state_dict().items():
    print(f"{name}: {weight.shape} {weight.dtype} {weight.quantization_encoding}")
```
The output may look similar to the following:
```
layers.0.fc1.weight: [ffn_size, hidden_size] DType.bfloat16 None
layers.0.fc2.weight: [hidden_size, ffn_size] DType.bfloat16 None
layers.0.attention.q_proj.weight: [hidden_size, hidden_size] DType.bfloat16 None
...
```
Name mapping
The most common job of the weight adapter is to map checkpoint parameter names
to the names expected by your MAX Module.
To do this, define a weight adapter function that:
- Accepts the checkpoint state dictionary.
- Iterates over the checkpoint state dictionary.
- Computes the matching parameter name in your Module.
- Builds and returns a new state dictionary that maps the new parameter names to the original weight tensors.
In many cases, the mapping is straightforward because both Hugging Face and MAX derive parameter names from a module hierarchy. If the model architectures are similar, the adapter may only need to remove a common prefix or rename a small number of modules.
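Because renaming is ordinary string manipulation, you can prototype the rules on plain key names before touching any tensors. The sketch below uses made-up key names and illustrative rules. One detail worth checking: str.replace() rewrites every occurrence of a substring, so a rule that maps "model." to "" also alters keys such as vision_model.encoder..., whereas str.removeprefix() (Python 3.9+) only strips a leading prefix. Which behavior you want depends on your Module's naming.

```python
# Illustrative renaming rules; real rules depend on your architecture.
PREFIX = "model."
RENAMES = {"self_attn.q_proj": "attention.q_proj"}


def to_max_name(checkpoint_name: str) -> str:
    """Strip a leading prefix, then apply substring renames."""
    name = checkpoint_name.removeprefix(PREFIX)  # only touches the start of the key
    for before, after in RENAMES.items():
        name = name.replace(before, after)
    return name


checkpoint_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "vision_model.encoder.layers.0.fc1.weight",
]
for key in checkpoint_keys:
    print(key, "->", to_max_name(key))

# str.replace() would also have rewritten the second key:
print("vision_model.encoder.layers.0.fc1.weight".replace(PREFIX, ""))
# vision_encoder.layers.0.fc1.weight
```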
For example, many Hugging Face safetensors checkpoints store weights under a top-level "model." prefix that doesn't exist in your MAX Module. The following weight adapter, convert_safetensor_state_dict(), removes that prefix and renames the query projection to match the Module:
```python
from max.graph.weights import WeightData, Weights

SAFETENSOR_MAPPING = {
    "model.": "",  # Strip this string
    "self_attn.q_proj": "attention.q_proj",  # Rename this string
}


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()
    return new_state_dict
```
Dtype casting
Every weight in a Module has an expected dtype. You set that dtype when you
construct the Module, matching what is required by the MAX kernel that
consumes the weight. A checkpoint might store the same weight in a different
dtype than the kernel expects.
When the dtypes differ, cast the weight in your weight adapter before you load
it into the Module. If you don't, load_state_dict() raises a ValueError.
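Because a dtype mismatch only surfaces when load_state_dict() runs, it can help to compare dtypes up front. This sketch works on plain name-to-dtype-string dictionaries, a hypothetical stand-in for the dtypes you printed from the adapted checkpoint and from raw_state_dict(), and reports every mismatch at once instead of stopping at the first one.

```python
def find_dtype_mismatches(
    adapted: dict[str, str], expected: dict[str, str]
) -> dict[str, tuple[str, str]]:
    """Return {name: (adapted_dtype, expected_dtype)} for shared keys whose dtypes differ."""
    return {
        name: (adapted[name], expected[name])
        for name in sorted(adapted.keys() & expected.keys())
        if adapted[name] != expected[name]
    }


# Hypothetical dtypes: the checkpoint ships float16, the kernel expects bfloat16.
adapted_dtypes = {
    "layers.0.fc1.weight": "float16",
    "layers.0.fc1.bias": "float16",
}
expected_dtypes = {
    "layers.0.fc1.weight": "bfloat16",
    "layers.0.fc1.bias": "float16",
}
print(find_dtype_mismatches(adapted_dtypes, expected_dtypes))
# {'layers.0.fc1.weight': ('float16', 'bfloat16')}
```

Each reported key is a candidate for a cast in the weight adapter.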
The following weight adapter casts float16 weights to bfloat16 and leaves bias and scale tensors in their original dtype:
```python
from max.dtype import DType
from max.graph.weights import WeightData, Weights


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Cast float16 weights to bfloat16, a dtype the kernel supports.
    # Bias and scale tensors keep their original dtype.
    for key, weight_data in new_state_dict.items():
        if weight_data.dtype == DType.float16 and not key.endswith(("bias", "scales")):
            new_state_dict[key] = weight_data.astype(DType.bfloat16)
    return new_state_dict
```
Handle tensor-level dtype casts like this one in the weight adapter. Defer layer-level encoding decisions to QuantFormat and QuantConfig instead.
Remove unused weights
A checkpoint and your Module don't always contain the same set of weights. A
mismatch can occur in two ways:
- Extra weights. The checkpoint contains weights that your Module doesn't declare. Remove these in the weight adapter.
- Missing weights. Your Module declares weights that aren't present in the checkpoint. These values come from somewhere other than the checkpoint (for example, MAX derives them from other weights), so you usually don't handle them in the weight adapter.
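Both cases fall out of two set differences between the adapted checkpoint names and the Module's names. The following sketch uses hypothetical names; in practice you would build the sets from the adapted state dictionary and from raw_state_dict().keys().

```python
def diff_weight_names(
    checkpoint_names: set[str], module_names: set[str]
) -> tuple[set[str], set[str]]:
    """Split name mismatches into extra (checkpoint only) and missing (Module only)."""
    extra = checkpoint_names - module_names
    missing = module_names - checkpoint_names
    return extra, missing


checkpoint = {
    "layers.0.attention.q_proj.weight",
    "layers.0.self_attn.k_scale",
    "layers.0.self_attn.v_scale",
}
module = {
    "layers.0.attention.q_proj.weight",
    "layers.0.attention.derived_weight",  # hypothetical weight derived from others
}
extra, missing = diff_weight_names(checkpoint, module)
print(sorted(extra))    # ['layers.0.self_attn.k_scale', 'layers.0.self_attn.v_scale']
print(sorted(missing))  # ['layers.0.attention.derived_weight']
```

The extra set is what the weight adapter should drop; the missing set is what you expect MAX to supply some other way.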
For example, FP8 checkpoints often include .k_scale and .v_scale tensors used to quantize the key and value caches. MAX computes these scales at runtime, so drop them in the weight adapter:
```python
from max.graph.weights import WeightData, Weights

UNUSED_SUFFIXES = (".k_scale", ".v_scale")


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights], **unused_kwargs
) -> dict[str, WeightData]:
    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Drop FP8 KV-cache scales
    for key in list(new_state_dict):
        if key.endswith(UNUSED_SUFFIXES):
            del new_state_dict[key]
    return new_state_dict
```
Tensor transformations
Transform tensor values only when the checkpoint representation differs from the
format expected by a MAX kernel. When possible, adapt your Module instead of
modifying checkpoint data. Tensor transformations encode assumptions about a
specific kernel implementation, and you may need to update them if the kernel
interface changes.
Common cases include:
- Partial-RoPE permutation. Some models use a partial-RoPE layout where only a subset of attention dimensions participate in rotary embeddings. If the checkpoint layout differs from the layout expected by the MAX kernel, the adapter must permute the query and key projection weights before loading.
- Quantization-specific transformations. Quantized checkpoints may store weights in a layout that differs from the dequantization kernel's expected format. For example, MXFP4 checkpoints require weight repacking, and GPTQ checkpoints store a g_idx permutation that you must convert to the inverse perm_idx representation expected by the kernel.
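As a concrete, hypothetical illustration of the partial-RoPE case, suppose a checkpoint stores the rotary dimensions of each head interleaved as (x0, y0, x1, y1, ...) while the kernel expects them split into halves (x0, x1, ..., y0, y1, ...), and the dimensions past rotary_dim pass through unchanged. Under those assumptions, a per-head row permutation like the one below could fill the role of an architecture-specific RoPE permutation helper; the function name and both layouts are assumptions, so confirm what your checkpoint and kernel actually use before reusing it.

```python
def interleaved_to_half_split_perm(head_dim: int, rotary_dim: int) -> list[int]:
    """Row order that maps an interleaved rotary layout to a half-split layout.

    Hypothetical helper. Assumes the first ``rotary_dim`` dimensions of each head
    are rotary and stored interleaved; the remaining dimensions stay in place.
    Indexing a head's rows with the result gives new[i] = old[perm[i]].
    """
    if rotary_dim % 2 != 0 or not 0 < rotary_dim <= head_dim:
        raise ValueError("rotary_dim must be even and in (0, head_dim]")
    evens = list(range(0, rotary_dim, 2))  # first element of each rotary pair
    odds = list(range(1, rotary_dim, 2))   # second element of each rotary pair
    passthrough = list(range(rotary_dim, head_dim))
    return evens + odds + passthrough


perm = interleaved_to_half_split_perm(head_dim=6, rotary_dim=4)
print(perm)  # [0, 2, 1, 3, 4, 5]

# One head: interleaved rotary pairs, then pass-through dimensions.
labels = ["x0", "y0", "x1", "y1", "p0", "p1"]
print([labels[i] for i in perm])  # ['x0', 'x1', 'y0', 'y1', 'p0', 'p1']
```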
This example permutes the query and key projection weights so their layout matches what the MAX RoPE kernel expects:
```python
import numpy as np
from max.graph.weights import WeightData, Weights
from transformers import AutoConfig


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights],
    huggingface_config: AutoConfig,
    **unused_kwargs,
) -> dict[str, WeightData]:
    num_heads = huggingface_config.num_attention_heads
    head_dim = huggingface_config.hidden_size // num_heads
    num_kv_heads = getattr(huggingface_config, "num_key_value_heads", num_heads)
    rope_perm = build_rope_perm(head_dim)  # architecture-specific helper

    new_state_dict: dict[str, WeightData] = {}
    for checkpoint_name, value in state_dict.items():
        max_name = checkpoint_name
        for before, after in SAFETENSOR_MAPPING.items():  # defined earlier
            max_name = max_name.replace(before, after)
        new_state_dict[max_name] = value.data()

    # Permute the rows of each query and key head into the kernel's RoPE layout
    for key, weight_data in new_state_dict.items():
        if key.endswith("q_proj.weight"):
            heads = num_heads
        elif key.endswith("k_proj.weight"):
            heads = num_kv_heads
        else:
            continue
        # NumPy has no native bfloat16 dtype, so np.from_dlpack() can fail on
        # bfloat16 weights; convert them to a NumPy-supported dtype first if needed.
        arr = np.from_dlpack(weight_data.data)
        arr = arr.reshape(heads, head_dim, -1)
        arr = arr[:, rope_perm, :]
        arr = np.ascontiguousarray(arr.reshape(heads * head_dim, -1))
        new_state_dict[key] = WeightData.from_numpy(arr, key)
    return new_state_dict
```
Correctness checks
Compare the keys produced by the weight adapter with the keys your Module
expects:
```python
from pathlib import Path

from max.graph.weights import load_weights
from transformers import AutoConfig

from my_arch.weight_adapters import convert_safetensor_state_dict

weights = load_weights([Path("/path/to/model.safetensors")])
hf_config = AutoConfig.from_pretrained("org/model-id")
adapted = convert_safetensor_state_dict(
    dict(weights.items()),
    huggingface_config=hf_config,
)
for name in sorted(adapted):
    print(name)
```
Compare the output against the names returned by your Module's raw_state_dict().
Intentional differences may include:
- Tensors removed because MAX computes the value at runtime.
- Parameters omitted because MAX derives them from other weights.
- Format-specific tensors that your Module doesn't use.
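You can automate this comparison. The sketch below assumes you have already collected the adapted names (for example, from the loop above) and the Module's names (for example, from raw_state_dict().keys()) as sets of strings; the function name, key names, and allowlisted suffix are hypothetical. It raises only for differences that the allowlist of intentional differences doesn't cover.

```python
def check_weight_names(
    adapted: set[str],
    expected: set[str],
    allowed_missing_suffixes: tuple[str, ...] = (),
) -> None:
    """Raise ValueError if the adapted keys and the Module's keys disagree unexpectedly."""
    # Keys the adapter produced that the Module never declares.
    extra = sorted(adapted - expected)
    # Module weights with no checkpoint value, minus the allowlisted ones.
    missing = sorted(
        name
        for name in expected - adapted
        if not name.endswith(allowed_missing_suffixes)
    )
    problems = []
    if extra:
        problems.append(f"unexpected keys from the adapter: {extra}")
    if missing:
        problems.append(f"Module weights with no checkpoint value: {missing}")
    if problems:
        raise ValueError("; ".join(problems))


adapted_names = {"layers.0.fc1.weight", "layers.0.fc2.weight"}
# Hypothetical weight that MAX derives from other weights.
module_names = adapted_names | {"layers.0.attention.derived_weight"}
check_weight_names(
    adapted_names, module_names, allowed_missing_suffixes=(".derived_weight",)
)
print("weight names match")
```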
Next steps
Once your adapter loads successfully, complete the remaining model bring-up steps:
- Register the architecture. Add the adapter to your SupportedArchitecture definition so MAX can load checkpoints in the supported formats.
- Implement the pipeline model. The pipeline model class connects weight loading, graph construction, compilation, and execution.
- Configure quantization support. If your checkpoint uses a quantized format, implement the corresponding QuantConfig and QuantFormat integration.
For the complete workflow, see Model bring-up workflow.