load_weights()
max.graph.weights.load_weights(paths)
Loads neural network weights from checkpoint files.
Automatically detects the checkpoint format from the file extensions and returns the appropriate Weights implementation. Supported formats:
- .safetensors (Safetensors)
- .gguf (GGUF)
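Extension-based detection can be pictured as a lookup from file suffix to format. The sketch below is a hypothetical illustration, not MAX's implementation: CheckpointFormat and detect_checkpoint_format are invented names, and rejecting a list that mixes formats is a choice made in this sketch, not documented behavior of load_weights.

```python
from enum import Enum
from pathlib import Path


class CheckpointFormat(Enum):
    """Hypothetical enum for illustration; not part of max.graph.weights."""

    SAFETENSORS = "safetensors"
    GGUF = "gguf"


# Hypothetical mapping from a file's final suffix to its checkpoint format.
_SUFFIX_TO_FORMAT = {
    ".safetensors": CheckpointFormat.SAFETENSORS,
    ".gguf": CheckpointFormat.GGUF,
}


def detect_checkpoint_format(paths: list[Path]) -> CheckpointFormat:
    """Return the single format shared by every path in the list.

    Raises ValueError if the list is empty, a suffix is unsupported,
    or the paths mix formats (an assumption of this sketch).
    """
    if not paths:
        raise ValueError("expected at least one checkpoint path")
    formats = set()
    for path in paths:
        # Path.suffix is the last dotted part, e.g. ".safetensors".
        fmt = _SUFFIX_TO_FORMAT.get(path.suffix.lower())
        if fmt is None:
            raise ValueError(f"unsupported checkpoint extension: {path.name}")
        formats.add(fmt)
    if len(formats) > 1:
        raise ValueError("all checkpoint files must share one format")
    return formats.pop()
```

For example, detect_checkpoint_format([Path("model.gguf")]) returns CheckpointFormat.GGUF, while a list containing a .bin file raises ValueError.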
The following example shows how to load weights from a sharded Safetensors checkpoint and allocate one layer's weight:
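from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import load_weights

# A sharded checkpoint: pass every shard in a single list.
sharded_paths = [
    Path("model-00001-of-00003.safetensors"),
    Path("model-00002-of-00003.safetensors"),
    Path("model-00003-of-00003.safetensors"),
]
weights = load_weights(sharded_paths)

# Navigate to a tensor by attribute and index, then allocate it as a graph weight.
layer_weight = weights.model.layers[23].mlp.gate_proj.weight.allocate(
    dtype=DType.float32,
    shape=[4096, 14336],
    device=DeviceRef.GPU(0),
)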
Parameters:
- paths (list[Path]) – List of pathlib.Path objects pointing to checkpoint files. For multi-file checkpoints (for example, sharded Safetensors), provide all file paths in the list. For single-file checkpoints, provide a list with one path.
Return type:
Weights
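Because the paths parameter must include every shard of a multi-file checkpoint, it can help to gather and validate the shard list programmatically. The following is a minimal sketch, assuming shards follow the prefix-NNNNN-of-NNNNN.safetensors naming used in the example and that a directory holds a single checkpoint; collect_safetensors_shards is a hypothetical helper, not part of max.graph.weights.

```python
import re
from pathlib import Path

# Assumed naming convention, e.g. "model-00001-of-00003.safetensors".
# load_weights itself is not documented to require this pattern.
_SHARD_RE = re.compile(r"-(\d+)-of-(\d+)\.safetensors$")


def collect_safetensors_shards(directory: Path) -> list[Path]:
    """Hypothetical helper: return all shards in `directory`, in shard order.

    Raises ValueError if no consistent shard set is found or a shard is missing.
    """
    found: dict[int, Path] = {}
    totals: set[int] = set()
    for path in directory.glob("*.safetensors"):
        match = _SHARD_RE.search(path.name)
        if match is None:
            continue  # Not a sharded file name; ignored in this sketch.
        found[int(match.group(1))] = path
        totals.add(int(match.group(2)))
    if len(totals) != 1:
        raise ValueError(f"expected one consistent shard set in {directory}")
    total = totals.pop()
    # Every index from 1 to total must be present before loading.
    missing = [i for i in range(1, total + 1) if i not in found]
    if missing:
        raise ValueError(f"missing shard(s) {missing} in {directory}")
    return [found[i] for i in range(1, total + 1)]
```

The result could then be passed straight to load_weights, for example load_weights(collect_safetensors_shards(Path("checkpoints/my-model"))), where the directory name is a placeholder. Checking for missing shards up front surfaces an incomplete download as a clear error before any weights are read.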