Python class

SafetensorWeights

class max.graph.weights.SafetensorWeights(filepaths, *, tensors=None, tensors_to_file_idx=None, prefix='', allocated=None, _st_weight_map=None, _st_file_handles=None)

Bases: Weights

Implementation for loading weights from safetensors files.

SafetensorWeights loads model weights from files in the safetensors format. Safetensors, designed by Hugging Face, is a serialization format that cannot trigger arbitrary code execution when a file is loaded, and it supports memory-mapped loading for fast access.

from pathlib import Path
from max.dtype import DType
from max.graph import DeviceRef
from max.graph.weights import SafetensorWeights

weight_files = [Path("model.safetensors")]
weights = SafetensorWeights(weight_files)

# Check if a weight exists
if weights.model.embeddings.weight.exists():
    # Allocate the embedding weight
    embedding_weight = weights.model.embeddings.weight.allocate(
        dtype=DType.float32,
        device=DeviceRef.CPU()
    )

# Access weights with hierarchical naming
attn_weight = weights.transformer.layers[0].attention.weight.allocate(
    dtype=DType.float16
)
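The safety claim above comes from the file layout: a safetensors file starts with an 8-byte little-endian header length, followed by a plain JSON header and then the raw tensor bytes, so reading it never unpickles or runs code. The following is a minimal standard-library sketch of that layout, not how SafetensorWeights reads files internally. The helper names `write_minimal_safetensors` and `read_safetensors_header` are hypothetical and exist only for this illustration.

```python
import json
import os
import struct
import tempfile


def write_minimal_safetensors(path, name, values):
    """Write a single 1-D float32 tensor using the safetensors layout.

    Hypothetical helper for illustration only.
    """
    data = struct.pack(f"<{len(values)}f", *values)
    header = {
        name: {
            "dtype": "F32",
            "shape": [len(values)],
            # Byte range of this tensor inside the data section.
            "data_offsets": [0, len(data)],
        }
    }
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte header size
        f.write(header_bytes)                          # JSON metadata
        f.write(data)                                  # raw tensor bytes


def read_safetensors_header(path):
    """Read only the JSON header; no tensor data is loaded or executed."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny.safetensors")
    write_minimal_safetensors(path, "model.embeddings.weight", [1.0, 2.0, 3.0])
    header = read_safetensors_header(path)

print(header)
```

Because the header is ordinary JSON, tensor names, dtypes, and shapes can be inspected before any weight data is touched.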

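Parameters:

filepaths
tensors (default: None)
tensors_to_file_idx (default: None)
prefix (default: '')
allocated (default: None)
_st_weight_map (default: None)
_st_file_handles (default: None)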

allocate()

allocate(dtype=None, shape=None, quantization_encoding=None, device=cpu:0)

Creates a Weight that can be added to a graph.

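Parameters:

dtype (default: None)
shape (default: None)
quantization_encoding (default: None)
device (default: cpu:0)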

Return type:

Weight

allocate_as_bytes()

allocate_as_bytes(dtype=None)

Creates a Weight that can be added to the graph with uint8 representation.

The last dimension is scaled by the number of bytes of the original dtype (for example, [512, 256] float32 becomes [512, 1024] uint8). Scalars are interpreted as shape [1].
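The shape rule above can be sketched in plain Python. This is only an illustration of the arithmetic, not the library's implementation. The helper `uint8_shape` and the `ITEMSIZE` table are hypothetical names introduced here.

```python
# Bytes per element for a few common dtypes (illustrative table).
ITEMSIZE = {"float32": 4, "float16": 2, "int8": 1}


def uint8_shape(shape, dtype):
    """Return the shape a tensor would have when viewed as uint8.

    Hypothetical helper: scalars (empty shape) are treated as shape [1],
    then the last dimension is multiplied by the dtype's size in bytes.
    """
    dims = list(shape) if len(shape) > 0 else [1]
    dims[-1] *= ITEMSIZE[dtype]
    return dims


print(uint8_shape([512, 256], "float32"))  # [512, 1024]
print(uint8_shape([], "float32"))          # [4]
print(uint8_shape([2, 3], "float16"))      # [2, 6]
```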

Parameters:

dtype (DType | None)

Return type:

Weight

allocated_weights

property allocated_weights: dict[str, DLPackArray]

Gets the values of all weights that were allocated previously.

data()

data()

Loads and returns the weight data for this tensor.

Return type:

WeightData

exists()

exists()

Returns True if a tensor exists for the current name.

Return type:

bool

items()

items()

Iterates through all allocatable weights whose names start with the current prefix.

name

property name: str

The current weight name or prefix.
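To show how name, exists(), and items() relate to the hierarchical access used in the class example, here is a small pure-Python stand-in. It assumes that attribute and index access append dot-separated parts to the prefix (so `layers[0]` contributes `layers.0`). The class `WeightNamespace` is hypothetical and is not the MAX implementation, and the real items() may yield different values.

```python
class WeightNamespace:
    """Hypothetical stand-in that mimics prefix-based weight lookup."""

    def __init__(self, tensor_names, prefix=""):
        self._tensor_names = set(tensor_names)
        self._prefix = prefix

    @property
    def name(self):
        # The current weight name or prefix.
        return self._prefix

    def _child(self, part):
        full = f"{self._prefix}.{part}" if self._prefix else str(part)
        return WeightNamespace(self._tensor_names, full)

    def __getattr__(self, attr):
        # Only called for attributes that are not found normally.
        if attr.startswith("_"):
            raise AttributeError(attr)
        return self._child(attr)

    def __getitem__(self, index):
        return self._child(index)

    def exists(self):
        # True only when the full name matches a stored tensor.
        return self._prefix in self._tensor_names

    def items(self):
        # Yield (name, namespace) pairs for names under the prefix.
        for tensor_name in sorted(self._tensor_names):
            if tensor_name.startswith(self._prefix):
                yield tensor_name, WeightNamespace(self._tensor_names, tensor_name)


names = [
    "model.embeddings.weight",
    "transformer.layers.0.attention.weight",
    "transformer.layers.1.attention.weight",
]
weights = WeightNamespace(names)

leaf = weights.transformer.layers[0].attention.weight
print(leaf.name)
print(leaf.exists())
print([name for name, _ in weights.transformer.items()])
```

In this sketch an intermediate prefix such as `weights.transformer` reports `exists()` as False, because only complete tensor names are stored.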