Weight
class max.graph.Weight(*args, **kwargs)
Bases: TensorValue
Represents a value in a Graph that can be loaded at a later time.
Weights can be initialized outside of a Graph and are lazily added to the parent graph when they are used. If a weight is used when there is no parent graph, an error is raised.
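The lazy-registration behavior described above can be modeled in a few lines of plain Python. This is a hypothetical sketch, not the MAX implementation: `ToyGraph`, `ToyWeight`, and `use()` are illustrative names, and the real `Graph` and `Weight` classes work differently internally. It only shows the contract. A weight can exist before any graph. It is added to the active graph the first time it is used. It raises an error when no graph is active.

```python
# Hypothetical toy model of lazy weight registration; not the MAX API.
from __future__ import annotations


class ToyGraph:
    """Minimal stand-in for a graph that is built inside a `with` block."""

    _current: ToyGraph | None = None  # graph currently being built, if any

    def __init__(self, name: str) -> None:
        self.name = name
        self.weights: list[str] = []  # names of weights registered here
        self._previous: ToyGraph | None = None

    def __enter__(self) -> ToyGraph:
        self._previous = ToyGraph._current
        ToyGraph._current = self
        return self

    def __exit__(self, *exc) -> None:
        ToyGraph._current = self._previous


class ToyWeight:
    """A weight declared outside any graph and registered on first use."""

    def __init__(self, name: str) -> None:
        self.name = name

    def use(self) -> str:
        graph = ToyGraph._current
        if graph is None:
            raise RuntimeError(f"weight {self.name!r} used outside of a graph")
        if self.name not in graph.weights:
            graph.weights.append(self.name)  # lazily add to the parent graph
        return self.name


w = ToyWeight("embedding")  # created while no graph is active
with ToyGraph("demo") as g:
    w.use()
    w.use()  # a second use does not register the weight again
print(g.weights)  # ['embedding']
```

Keeping registration in `use()` rather than in the constructor is what allows the weight to be declared before the graph exists.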
align
dtype
property dtype: DType
Returns the tensor data type.
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32)
    print(f"Data type: {tensor.dtype}")  # Output: Data type: DType.float32
original_dtype_and_shape
property original_dtype_and_shape: tuple[max._core.dtype.DType, max.graph.type.Shape]
The original dtype and shape of this weight.
This property stores the weight's original dtype and shape when the quantization encoding forces the weight to be loaded as uint8.
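To see why this metadata matters, here is a hypothetical plain-Python sketch (not the MAX API) in which a float32 tensor is stored as a flat uint8 byte buffer. Real quantization encodings pack values into blocks with scale factors rather than copying raw float bytes. This sketch skips that step and only shows that the stored dtype and shape no longer describe the logical tensor, so the original values must be kept separately, as `original_dtype_and_shape` does.

```python
# Hypothetical sketch: keep the logical dtype/shape next to a uint8 buffer.
import struct

# A 2x2 float32 tensor, flattened in row-major order.
values = [1.0, 2.0, 3.0, 4.0]
original_dtype_and_shape = ("float32", (2, 2))

# Store the data as raw little-endian bytes (uint8). No real quantization
# happens here; this only mimics a weight whose stored dtype differs from
# its logical one.
raw = struct.pack(f"<{len(values)}f", *values)
stored_dtype_and_shape = ("uint8", (len(raw),))
print(stored_dtype_and_shape)  # ('uint8', (16,))

# The 16 bytes alone do not say the tensor was a 2x2 float32 matrix.
# With the saved metadata, the original layout can be recovered.
dtype, shape = original_dtype_and_shape
if dtype != "float32":
    raise ValueError(f"this sketch only handles float32, got {dtype}")
count = shape[0] * shape[1]
restored = list(struct.unpack(f"<{count}f", raw))
rows = [restored[r * shape[1]:(r + 1) * shape[1]] for r in range(shape[0])]
print(rows)  # [[1.0, 2.0], [3.0, 4.0]]
```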
quantization_encoding
quantization_encoding: QuantizationEncoding | None
set_sharding_strategy()
set_sharding_strategy(sharding_strategy: Callable[[Weight, int], TensorValue]) → None
Set the weight sharding strategy.
Parameters:
sharding_strategy – A callable that takes the host weight and shard index, and returns the sharded value.
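The callable passed to `set_sharding_strategy()` maps `(weight, shard_index)` to a sharded value. Because it receives only the weight and an index, the number of shards usually has to be captured in a closure. The following is a hypothetical plain-Python sketch that uses nested lists in place of `Weight` and `TensorValue`. `rowwise_strategy` and `num_shards` are illustrative names and are not part of MAX. With a real `Weight`, the body would presumably slice the tensor with graph operations rather than list slicing.

```python
# Hypothetical sketch of a row-wise sharding strategy; not the MAX API.
from typing import Callable, List

Matrix = List[List[float]]


def rowwise_strategy(num_shards: int) -> Callable[[Matrix, int], Matrix]:
    """Return a callable (weight, shard_idx) -> shard that splits rows evenly."""

    def strategy(weight: Matrix, shard_idx: int) -> Matrix:
        if not 0 <= shard_idx < num_shards:
            raise IndexError(f"shard_idx {shard_idx} is out of range")
        if len(weight) % num_shards != 0:
            raise ValueError("row count must be divisible by num_shards")
        rows_per_shard = len(weight) // num_shards
        start = shard_idx * rows_per_shard
        return weight[start:start + rows_per_shard]

    return strategy


weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
split_in_two = rowwise_strategy(num_shards=2)
print(split_in_two(weight, 0))  # [[1.0, 2.0], [3.0, 4.0]]
print(split_in_two(weight, 1))  # [[5.0, 6.0], [7.0, 8.0]]
```

Using a factory function keeps the two-argument signature required by the callable while still letting the caller choose the shard count.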
shape
property shape: Shape
Returns the shape of the TensorValue.
import numpy as np
from max.dtype import DType
from max.graph import Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)
# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32)
    # Access tensor properties
    print(f"Shape: {tensor.shape}")  # Output: Shape: (2, 2)
shard()
shard(shard_idx: int, device: DeviceRef | None = None) → Weight
Gets a specific shard from the Weight.
This Weight must have a sharding strategy defined (see set_sharding_strategy()). The returned shard is also a Weight, but it cannot be sharded further.
Parameters:
- shard_idx – Index of the shard to return.
- device – Optional device on which to place the shard.
Returns:
The sharded weight.
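The two rules stated above (a sharding strategy must be set first, and a shard cannot be sharded again) can be modeled as follows. This is a hypothetical plain-Python sketch: `ToyWeight`, `is_shard`, and the error messages are illustrative and do not reflect how MAX implements `Weight.shard()`. The optional `device` argument is omitted.

```python
# Hypothetical model of the shard() rules; not the MAX implementation.
from __future__ import annotations

from typing import Callable, List, Optional

Matrix = List[List[float]]


class ToyWeight:
    def __init__(self, data: Matrix, is_shard: bool = False) -> None:
        self.data = data
        self.is_shard = is_shard
        self.sharding_strategy: Optional[Callable[[ToyWeight, int], Matrix]] = None

    def set_sharding_strategy(self, strategy: Callable[[ToyWeight, int], Matrix]) -> None:
        self.sharding_strategy = strategy

    def shard(self, shard_idx: int) -> ToyWeight:
        if self.is_shard:
            raise ValueError("a shard cannot be sharded further")
        if self.sharding_strategy is None:
            raise ValueError("set a sharding strategy before calling shard()")
        # The result is itself a ToyWeight, flagged so it cannot be re-sharded.
        return ToyWeight(self.sharding_strategy(self, shard_idx), is_shard=True)


w = ToyWeight([[1.0, 2.0], [3.0, 4.0]])
# Two shards, one row each: shard i holds row i.
w.set_sharding_strategy(lambda weight, i: weight.data[i:i + 1])
first = w.shard(0)
print(first.data)  # [[1.0, 2.0]]
```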
shard_idx
sharding_strategy
sharding_strategy: _ShardingStrategyContainer | None