
Python function

load_weights


max.graph.weights.load_weights(paths)


Loads neural network weights from checkpoint files.

Detects the checkpoint format from the file extensions in paths and returns the matching Weights implementation. Supported formats:

  • .safetensors (Safetensors)
  • .gguf (GGUF)
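Extension-based detection can be sketched as follows. This is a minimal, hypothetical illustration: detect_checkpoint_format and its string labels are not part of MAX, and the sketch assumes that all paths must share one supported extension and that anything else is an error. The real load_weights returns a Weights object and may handle edge cases differently.

```python
from pathlib import Path

# Hypothetical mapping from file extension to a format label.
# MAX's real dispatch returns a Weights implementation, not a string.
_EXTENSION_TO_FORMAT = {
    ".safetensors": "safetensors",
    ".gguf": "gguf",
}


def detect_checkpoint_format(paths: list[Path]) -> str:
    """Return one format label for a list of checkpoint files.

    Raises ValueError if the list is empty, an extension is unsupported,
    or the files mix formats (an assumption of this sketch).
    """
    if not paths:
        raise ValueError("expected at least one checkpoint path")
    formats = set()
    for path in paths:
        # Compare extensions case-insensitively, e.g. ".GGUF" -> ".gguf".
        fmt = _EXTENSION_TO_FORMAT.get(path.suffix.lower())
        if fmt is None:
            raise ValueError(f"unsupported checkpoint extension: {path.suffix!r}")
        formats.add(fmt)
    if len(formats) > 1:
        raise ValueError(f"mixed checkpoint formats: {sorted(formats)}")
    return formats.pop()
```

With this sketch, the three sharded files from the example below all map to "safetensors", while mixing a .safetensors file with a .gguf file raises ValueError.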

The following example loads a Safetensors checkpoint that is sharded across three files and allocates one of its tensors:

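from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import load_weights

# Pass every shard of the checkpoint in a single list.
sharded_paths = [
    Path("model-00001-of-00003.safetensors"),
    Path("model-00002-of-00003.safetensors"),
    Path("model-00003-of-00003.safetensors"),
]
weights = load_weights(sharded_paths)

# Attribute and index access select the tensor named
# "model.layers.23.mlp.gate_proj.weight"; allocate() creates a graph weight for it.
layer_weight = weights.model.layers[23].mlp.gate_proj.weight.allocate(
    dtype=DType.float32,
    shape=[4096, 14336],
    device=DeviceRef.GPU(0),
)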
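In the example above, chained attribute and index access such as weights.model.layers[23] presumably builds up the checkpoint key "model.layers.23.mlp.gate_proj.weight". The class below is a hypothetical, self-contained illustration of that naming idea only; WeightName is not a MAX class and does not load any data.

```python
class WeightName:
    """Hypothetical helper that records attribute and index access as a dotted key."""

    def __init__(self, name: str = "") -> None:
        self.name = name

    def _child(self, part: str) -> "WeightName":
        # Join segments with "." as in keys like "model.layers.23.mlp".
        return WeightName(f"{self.name}.{part}" if self.name else part)

    def __getattr__(self, attr: str) -> "WeightName":
        # Only called for attributes not found normally; private and dunder
        # names raise so that introspection does not create bogus keys.
        if attr.startswith("_"):
            raise AttributeError(attr)
        return self._child(attr)

    def __getitem__(self, index: int) -> "WeightName":
        # Integer indexes become key segments, e.g. layers[23] -> "layers.23".
        return self._child(str(index))
```

For example, WeightName().model.layers[23].mlp.gate_proj.weight.name evaluates to "model.layers.23.mlp.gate_proj.weight".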

Parameters:

paths (list[Path]) – List of pathlib.Path objects pointing to checkpoint files. For multi-file checkpoints (for example, sharded Safetensors), provide all file paths in the list. For single-file checkpoints, provide a list with one path.
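Because paths must list every shard of a multi-file checkpoint, it can help to collect them from a directory. The helper below is a hypothetical sketch (collect_safetensors_shards is not part of MAX); it assumes all shards sit directly in one directory and sorts them by name so the order is deterministic.

```python
from pathlib import Path


def collect_safetensors_shards(checkpoint_dir: Path) -> list[Path]:
    """Return every .safetensors file directly inside checkpoint_dir, sorted.

    Sorting keeps names such as model-00001-of-00003.safetensors in order.
    Raises FileNotFoundError if no shard is found.
    """
    # glob("*.safetensors") is non-recursive; other files (e.g. config.json) are ignored.
    shards = sorted(checkpoint_dir.glob("*.safetensors"))
    if not shards:
        raise FileNotFoundError(f"no .safetensors files in {checkpoint_dir}")
    return shards
```

The result can be passed straight to load_weights, for example load_weights(collect_safetensors_shards(Path("checkpoints/my-model"))) with a placeholder directory. For a single-file checkpoint, wrap the one path in a list instead, as in load_weights([Path("model.gguf")]).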

Return type:

Weights