Python module
weights
Weights are the learned parameters that store a neural network’s knowledge. They’re multi-dimensional arrays (tensors) of numerical values that determine how the model transforms inputs into outputs. These weights contain all the information needed for a model to perform its task - whether that’s text generation, image classification, or any other capability.
GGUFWeights
class max.graph.weights.GGUFWeights(source, tensors=None, prefix='', allocated=None)
Implementation for loading weights from GGUF (GPT-Generated Unified Format) files.
GGUFWeights provides an interface to load model weights from GGUF files, which are optimized for quantized large language models. GGUF is the successor to the GGML format and is commonly used in the llama.cpp ecosystem for efficient storage and loading of quantized models.
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.quantization import QuantizationEncoding
from max.graph.weights import GGUFWeights

gguf_path = Path("model-q4_k.gguf")
weights = GGUFWeights(gguf_path)

# Check if a weight exists
if weights.model.layers[0].attention.wq.exists():
    # Allocate quantized attention weight
    wq_weight = weights.model.layers[0].attention.wq.allocate(
        dtype=DType.uint8,  # GGUF quantized weights use uint8
        device=DeviceRef.CPU(),
    )

    # Access weight data with quantization info
    weight_data = weights.model.layers[0].attention.wq.data()
    print(f"Quantization: {weight_data.quantization_encoding}")
    print(f"Shape: {weight_data.shape}")

# Allocate with quantization validation
ffn_weight = weights.model.layers[0].feed_forward.w1.allocate(
    quantization_encoding=QuantizationEncoding.Q4_K,
    device=DeviceRef.GPU(0),
)

# Iterate through all weights in a layer
for name, weight in weights.model.layers[0].items():
    if weight.exists():
        print(f"Found weight: {name}")
Creates a GGUF weights reader.
Parameters:
- source (Union[PathLike, gguf.GGUFReader]) – Path to a GGUF file or a GGUFReader object.
- tensors – List of tensors in the GGUF checkpoint.
- prefix (str) – Weight name or prefix.
- allocated – Dictionary of allocated values.
allocate()
allocate(dtype=None, shape=None, quantization_encoding=None, device=cpu:0)
Creates and optionally validates a new Weight.
allocated_weights
property allocated_weights: dict[str, DLPackArray]
Gets the values of all weights that were allocated previously.
data()
data()
Get weight data with metadata.
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
Returns:
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
Raises:
KeyError – If no weight exists at the current hierarchical name.
Return type:
WeightData
exists()
exists()
Check if a weight with this exact name exists.
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
Returns:
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
Return type:
bool
items()
items()
Iterate through all allocable weights that start with the prefix.
name
property name: str
The current weight name or prefix.
PytorchWeights
class max.graph.weights.PytorchWeights(filepath, tensor_infos=None, prefix='', allocated=None)
Implementation for loading weights from PyTorch checkpoint files.
PytorchWeights provides an interface to load model weights from PyTorch checkpoint files (.bin or .pt format). These files contain serialized PyTorch tensors using Python's pickle protocol, making them widely compatible with the PyTorch ecosystem.
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import PytorchWeights

# Load weights from PyTorch checkpoint
checkpoint_path = Path("pytorch_model.bin")
weights = PytorchWeights(checkpoint_path)

# Check if a weight exists before allocation
if weights.model.decoder.layers[0].self_attn.q_proj.weight.exists():
    # Allocate the attention weight
    q_weight = weights.model.decoder.layers[0].self_attn.q_proj.weight.allocate(
        dtype=DType.float32,
        device=DeviceRef.CPU(),
    )

# Access weight properties
if weights.embeddings.weight.exists():
    print(f"Embedding shape: {weights.embeddings.weight.shape}")
    print(f"Embedding dtype: {weights.embeddings.weight.dtype}")

    # Allocate with validation
    embedding_weight = weights.embeddings.weight.allocate(
        dtype=DType.float16,
        shape=(50257, 768),  # Validate expected shape
    )
allocate()
allocate(dtype=None, shape=None, quantization_encoding=None, device=cpu:0)
Creates and optionally validates a new Weight.
allocated_weights
property allocated_weights: dict[str, DLPackArray]
Gets the values of all weights that were allocated previously.
data()
data()
Get weight data with metadata.
Return type:
WeightData
dtype
property dtype: DType
The current weight dtype, if this weight exists.
exists()
exists()
Check if a weight with this exact name exists.
Return type:
bool
items()
items()
Iterate through all allocable weights that start with the prefix.
name
property name: str
The current weight name or prefix.
quantization_encoding
property quantization_encoding: QuantizationEncoding | None
The current weight quantization encoding, if this weight exists.
shape
property shape: Shape
The current weight shape, if this weight exists.
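These read-only properties (dtype, shape, quantization_encoding, and name) make it possible to inspect a checkpoint entry before deciding how to allocate it. A minimal sketch, assuming a local pytorch_model.bin and a hypothetical embedding weight path:
from pathlib import Path

from max.graph.weights import PytorchWeights

# Sketch: inspect weight metadata before allocating.
# "pytorch_model.bin" and the weight path below are hypothetical.
weights = PytorchWeights(Path("pytorch_model.bin"))

w = weights.model.embeddings.weight
if w.exists():
    print(f"name:  {w.name}")
    print(f"dtype: {w.dtype}")
    print(f"shape: {w.shape}")
    print(f"quantization: {w.quantization_encoding}")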
SafetensorWeights
class max.graph.weights.SafetensorWeights(filepaths, *, tensors=None, tensors_to_file_idx=None, prefix='', allocated=None, _st_weight_map=None, _st_file_handles=None)
Implementation for loading weights from safetensors files.
SafetensorWeights provides a secure and efficient way to load model weights from safetensors format files. Safetensors is designed by Hugging Face for safe serialization that prevents arbitrary code execution and supports memory-mapped loading for fast access.
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import SafetensorWeights

# Load weights from safetensors files
weight_files = [Path("model.safetensors")]
weights = SafetensorWeights(weight_files)

# Check if a weight exists
if weights.model.embeddings.weight.exists():
    # Allocate the embedding weight
    embedding_weight = weights.model.embeddings.weight.allocate(
        dtype=DType.float32,
        device=DeviceRef.CPU(),
    )

# Access weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float16
)
allocate()
allocate(dtype=None, shape=None, quantization_encoding=None, device=cpu:0)
Creates a Weight that can be added to a graph.
allocate_as_bytes()
allocate_as_bytes(dtype=None)
Creates a Weight that can be added to the graph, using a uint8 representation instead of the original data type. The last dimension of the shape is scaled by the number of bytes needed to represent the original data type. For example, [512, 256] float32 weights become [512, 1024] uint8 weights. Scalar weights are interpreted as weights with shape [1].
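As a rough illustration of the shape change described above, consider a float32 weight viewed as raw bytes. A minimal sketch, with a hypothetical file name and weight path:
from pathlib import Path

from max.graph.weights import SafetensorWeights

# Sketch: allocate a weight with a uint8 byte-level representation.
# "model.safetensors" and the weight path below are hypothetical.
weights = SafetensorWeights([Path("model.safetensors")])

# If the original weight is float32 with shape [512, 256], the byte view
# is uint8 with shape [512, 256 * 4] = [512, 1024].
byte_weight = weights.model.layers[0].ffn.w1.allocate_as_bytes()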
allocated_weights
property allocated_weights: dict[str, DLPackArray]
Gets the values of all weights that were allocated previously.
data()
data()
Get weight data with metadata.
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
Returns:
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
Raises:
KeyError – If no weight exists at the current hierarchical name.
Return type:
WeightData
exists()
exists()
Check if a weight with this exact name exists.
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
Returns:
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
Return type:
bool
items()
items()
Iterate through all allocable weights that start with the prefix.
name
property name: str
The current weight name or prefix.
WeightData
class max.graph.weights.WeightData(data, name, dtype, shape, quantization_encoding=None)
Container for weight tensor data with metadata.
WeightData encapsulates a weight tensor along with its metadata, providing utilities for type conversion and format compatibility. It supports the DLPack protocol for efficient tensor sharing between frameworks.
Parameters:
- data (DLPackArray)
- name (str)
- dtype (DType)
- shape (Shape)
- quantization_encoding (QuantizationEncoding | None)
astype()
astype(dtype)
Convert the weight data to a different dtype.
This method performs actual data conversion, unlike view(), which reinterprets the underlying bytes. Special handling is provided for bfloat16 conversions using PyTorch when available.
# Convert float32 weights to float16 for reduced memory
weight_data = weights.model.layer.weight.data()
fp16_data = weight_data.astype(DType.float16)
Parameters:
dtype (DType) – Target data type for conversion.
Returns:
A new WeightData instance with the converted data.
Return type:
WeightData
data
data: DLPackArray
The weight tensor as a DLPack array.
dtype
dtype: DType
Data type of the tensor (for example, DType.float32, DType.uint8).
from_numpy()
classmethod from_numpy(arr, name)
Create WeightData from a numpy array.
Parameters:
- arr – Numpy array containing the weight data.
- name – Name to assign to this weight.
Returns:
A new WeightData instance with dtype and shape inferred from the numpy array.
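For example, from_numpy can wrap an existing array so it carries the same metadata as weights loaded from a checkpoint. A minimal sketch; the array contents and weight name are made up for illustration:
import numpy as np

from max.graph.weights import WeightData

# Sketch: wrap a numpy array as WeightData.
# The array values and the name are hypothetical.
arr = np.zeros((50257, 768), dtype=np.float32)
weight_data = WeightData.from_numpy(arr, "model.embeddings.weight")

print(weight_data.name)   # model.embeddings.weight
print(weight_data.dtype)  # inferred from the numpy array
print(weight_data.shape)  # inferred from the numpy array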
name
name: str
Hierarchical name of the weight (for example, "model.layers.0.weight").
quantization_encoding
quantization_encoding: QuantizationEncoding | None = None
Optional quantization scheme applied to the weight.
shape
shape: Shape
Shape of the tensor as a Shape object.
Weights
class max.graph.weights.Weights(*args, **kwargs)
Protocol for managing and accessing model weights hierarchically.
The Weights protocol provides a convenient interface for loading and organizing neural network weights. It supports hierarchical naming through attribute and index access, making it easy to work with complex model architectures.
Weights in MAX are tensors backed by external memory (buffers or memory-mapped files) that remain separate from the compiled graph.
from max.graph import Graph
from max.dtype import DType

# Create a graph and get its weights interface
graph = Graph("my_model")
weights = graph.weights()

# Allocate weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float32,
    shape=(768, 768),
)
# Creates weight named "transformer.layers.0.attention.weight"

# Check if a weight exists before allocating
if weights.transformer.layers[0].mlp.weight.exists():
    mlp_weight = weights.transformer.layers[0].mlp.weight.allocate(
        dtype=DType.float16,
        shape=(768, 3072),
    )
allocate()
allocate(dtype=None, shape=None, quantization_encoding=None, device=cpu:0)
Create a Weight object for this tensor.
# Allocate a weight with specific configuration
weight = weights.model.layers[0].weight.allocate(
    dtype=DType.float16,  # Convert to half precision
    shape=(768, 768),
    device=DeviceRef.GPU(0),  # Place on first GPU
)

# Add to graph
with graph:
    weight_tensor = graph.add_weight(weight)
Parameters:
- dtype (DType | None) – Data type for the weight. If None, uses the original dtype.
- shape (Iterable[int | str | Dim | integer] | None) – Shape of the weight tensor. If None, uses the original shape.
- quantization_encoding (QuantizationEncoding | None) – Quantization scheme to apply (for example, Q4_K, Q8_0).
- device (DeviceRef) – Target device for the weight (CPU or GPU).
Returns:
A Weight object that can be added to a graph using graph.add_weight().
Return type:
Weight
allocated_weights
property allocated_weights: dict[str, DLPackArray]
Get all previously allocated weights. This only includes weights that were explicitly allocated using the allocate() method, not all available weights.
Returns:
A dictionary mapping weight names to their numpy arrays for all weights that have been allocated through this interface.
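Continuing the graph.weights() example above, only weights that have gone through allocate() show up in this dictionary. A minimal sketch with hypothetical weight names:
# Sketch: only explicitly allocated weights appear in allocated_weights.
# The weight names below are hypothetical.
weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float32, shape=(768, 768)
)
weights.transformer.layers[0].mlp.weight.allocate(
    dtype=DType.float16, shape=(768, 3072)
)

for name in weights.allocated_weights:
    print(name)
# transformer.layers.0.attention.weight
# transformer.layers.0.mlp.weight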
data()
data()
Get weight data with metadata.
weight_data = weights.model.embeddings.weight.data()
print(f"Shape: {weight_data.shape}")
print(f"Dtype: {weight_data.dtype}")
# Convert to different dtype
fp16_data = weight_data.astype(DType.float16)
Returns:
A WeightData object containing the tensor data along with metadata like name, dtype, shape, and quantization encoding.
Raises:
KeyError – If no weight exists at the current hierarchical name.
Return type:
WeightData
exists()
exists()
Check if a weight with this exact name exists.
if weights.model.classifier.weight.exists():
    classifier = weights.model.classifier.weight.allocate(...)
else:
    print("Classifier weight not found")
Returns:
True if a weight with the current hierarchical name exists in the loaded weights, False otherwise.
Return type:
bool
items()
items()
Iterate through all weights that start with the current prefix.
# Iterate through all weights in a specific layer
for name, weight in weights.transformer.layers[0].items():
    print(f"Found weight: {name}")
name
property name: str
Get the current weight name or prefix.
Returns:
The hierarchical name built from attribute and index access. For example, if accessed as weights.model.layers[0], returns "model.layers.0".
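A minimal sketch of how the prefix grows with attribute and index access, continuing the example above (the attribute path is illustrative):
# Sketch: the name property mirrors the access path.
layer = weights.model.layers[0]
print(layer.name)            # model.layers.0
print(layer.attention.name)  # model.layers.0.attention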
WeightsFormat
class max.graph.weights.WeightsFormat(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of supported weight file formats.
MAX supports multiple weight formats to accommodate different model sources and use cases.
gguf
gguf = 'gguf'
GGUF (GPT-Generated Unified Format) for quantized models.
File extension: .gguf
Optimized for quantized large language models, particularly those from the llama.cpp ecosystem. Supports multiple quantization schemes (Q4_K, Q5_K, Q8_0, etc.) and includes model metadata in the file.
pytorch
pytorch = 'pytorch'
PyTorch checkpoint format for model weights.
File extensions: .bin, .pt, .pth
Standard PyTorch format using Python’s pickle protocol. Widely supported but requires caution as pickle files can execute arbitrary code.
safetensors
safetensors = 'safetensors'
Safetensors format for secure and efficient tensor storage.
File extension: .safetensors
Designed by Hugging Face for safe serialization that prevents arbitrary code execution. Uses memory-mapped files for fast loading and supports sharding across multiple files.
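The enum values can be compared against the result of weights_format() (documented below) to branch on checkpoint type. A minimal sketch with a hypothetical file path:
from pathlib import Path

from max.graph.weights import WeightsFormat, weights_format

# Sketch: branch on the detected checkpoint format.
# "model.safetensors" is a hypothetical path.
fmt = weights_format([Path("model.safetensors")])
if fmt is WeightsFormat.safetensors:
    print("Loading a safetensors checkpoint")
elif fmt is WeightsFormat.gguf:
    print("Loading a GGUF checkpoint")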
load_weights()
max.graph.weights.load_weights(paths)
Loads neural network weights from checkpoint files.
Automatically detects checkpoint formats based on file extensions and returns the appropriate Weights implementation, creating a seamless interface for loading weights from different formats.
Supported formats:
- Safetensors: .safetensors
- PyTorch: .bin, .pt, .pth
- GGUF: .gguf
The following example shows how to load weights from a Safetensors file:
from pathlib import Path

from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import load_weights

# Load multi-file checkpoints
sharded_paths = [
    Path("model-00001-of-00003.safetensors"),
    Path("model-00002-of-00003.safetensors"),
    Path("model-00003-of-00003.safetensors"),
]
weights = load_weights(sharded_paths)

layer_weight = weights.model.layers[23].mlp.gate_proj.weight.allocate(
    dtype=DType.float32,
    shape=[4096, 14336],
    device=DeviceRef.GPU(0),
)
weights_format()
max.graph.weights.weights_format(weight_paths)
Detect the format of weight files based on their extensions.
This function examines the file extensions of all provided paths to determine the weight format. All files must have the same format; mixed formats are not supported.
from pathlib import Path

from max.graph.weights import weights_format

# Detect format for safetensor files
paths = [Path("model-00001.safetensors"), Path("model-00002.safetensors")]
format = weights_format(paths)
print(format)  # WeightsFormat.safetensors
Parameters:
weight_paths (list[Path]) – List of file paths containing model weights. All files must have the same extension/format.
Returns:
The detected WeightsFormat enum value.
Raises:
ValueError – If weight_paths is empty, contains mixed formats, or has unsupported file extensions.
Return type:
WeightsFormat