class max.nn.Module
Base class for model components with weight management.
Provides functionality to create custom layers and construct networks with automatic weight tracking.
The following example uses the Module class to create custom layers and build a neural network:
```python
from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dims, out_dims), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight.T

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Buffer([5, 10]), ...}
```

Constructing a graph without Module can result in name collisions among the weights (in this example, there would be three weights named "weight"). With Module, you can use state_dict() or load_state_dict() to initialize or set the weight values, and Module finalizes the weight names so they are unique within the model.
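The dotted names in state_dict() come from the attribute path of each submodule. As a rough illustration of the naming scheme only (plain Python, not the max.nn implementation), the toy classes below show how recursing through attributes produces unique dotted names:

```python
# Sketch of how hierarchical weight names are derived from attribute
# paths (illustration only; not the max.nn.Module implementation).
class ToyModule:
    def named_weights(self, prefix=""):
        names = {}
        for attr, value in vars(self).items():
            if isinstance(value, ToyModule):
                # Recurse into submodules, extending the dotted prefix.
                names.update(value.named_weights(prefix + attr + "."))
            elif attr == "weight":
                names[prefix + attr] = value
        return names

class ToyLinear(ToyModule):
    def __init__(self, shape):
        self.weight = shape  # stand-in for a real Weight

class ToyMLP(ToyModule):
    def __init__(self):
        self.up = ToyLinear((5, 10))
        self.gate = ToyLinear((5, 10))
        self.down = ToyLinear((10, 5))

print(ToyMLP().named_weights())
# {'up.weight': (5, 10), 'gate.weight': (5, 10), 'down.weight': (10, 5)}
```

Each of the three Linear layers names its weight "weight", but the attribute path ("up.", "gate.", "down.") disambiguates them, which is the collision avoidance described above.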
build_subgraph()
build_subgraph(name, input_types, weight_prefix='')
Builds a subgraph encapsulating this layer’s computation.
Call this method once on a representative layer, then call the returned
subgraph once per layer using call() with a unique
prefix. This pattern lets the compiler process the layer definition
once rather than once per repetition, which significantly reduces compile
time for models with many identical layers.
Examples:
Build a subgraph from layer 0 and call it once per layer with layer-specific weights:
```python
input_types = [hidden.type for hidden in h]
subgraph = self.layers[0].build_subgraph(
    "transformer_block",
    input_types=input_types,
    weight_prefix="layers.0.",
)

# Call it once per layer with the correct weight prefix.
for idx in range(len(self.layers)):
    outputs = ops.call(subgraph, *h, prefix=f"layers.{idx}.")
    h = [x.tensor for x in outputs]
```
Parameters:

- name (str) – The name of the subgraph. Must be unique within the containing graph.
- input_types (Sequence[Type[Any] | list[Type[Any]]]) – The input types for the subgraph. Pass a flat Type for a single tensor, or a list of Type objects for a group of tensors that should be passed together (for example, KV-cache blocks).
- weight_prefix (str) – A prefix string to strip from weight names before registering them as placeholder weights. At call time, the caller supplies the same prefix via the prefix argument of call() to re-resolve each weight to the correct entry in the weights registry.
Returns:

A Graph instance representing the subgraph.

Return type:

Graph
Notes:
Weights with names that start with weight_prefix are marked as
placeholders. Any call() invocation for this
subgraph must supply a matching prefix.
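The prefix round trip can be illustrated with plain string handling (an assumption-level sketch of the mechanics described above, not the MAX internals): build_subgraph strips weight_prefix to produce placeholder names, and the prefix passed to call() is re-attached to look each weight up again.

```python
# Sketch of the placeholder-name round trip (illustration only;
# the weight names and registry here are hypothetical).
weights = {
    "layers.0.attn.weight": "w0a", "layers.0.mlp.weight": "w0m",
    "layers.1.attn.weight": "w1a", "layers.1.mlp.weight": "w1m",
}

def to_placeholder(name, weight_prefix):
    # Build time: strip the representative layer's prefix.
    assert name.startswith(weight_prefix)
    return name[len(weight_prefix):]

def resolve(placeholder, prefix, registry):
    # Call time: re-attach the caller's prefix to find the weight.
    return registry[prefix + placeholder]

placeholders = [to_placeholder(n, "layers.0.") for n in weights if n.startswith("layers.0.")]
print(placeholders)                                   # ['attn.weight', 'mlp.weight']
print(resolve("attn.weight", "layers.1.", weights))   # 'w1a'
```

Because the placeholder names carry no layer index, one compiled subgraph can serve every layer, which is what saves compile time.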
layer_weights
Returns a mapping from weight name to Weight for this layer.
load_state_dict()
load_state_dict(state_dict, *, override_quantization_encoding=False, weight_alignment=None, strict=True)
Sets the values of all weights in this model.
Parameters:

- state_dict (Mapping[str, DLPackArray | WeightData]) – A map from weight name to a numpy array or Buffer.
- override_quantization_encoding (bool) – Whether to override the weight quantization based on the loaded value.
- weight_alignment (int | None) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state_dict must be aligned to the default alignment for its dtype.
- strict (bool) – If True, raises an error if any weights required by the Module are missing from state_dict, or if any keys in state_dict were not used by the Module. If False, both missing and unexpected keys are tolerated and reported only through logging.

Raises:

ValueError – If strict is True and any required weight is missing from state_dict, or if state_dict contains keys not used by the Module.

Return type:

None
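The strict check amounts to comparing the set of weights the Module requires against the keys provided. A minimal sketch of that validation (hypothetical helper, not the library code):

```python
def check_state_dict(required, state_dict, strict=True):
    # Compare required weight names against the provided keys
    # (hypothetical helper illustrating strict-mode semantics).
    missing = set(required) - set(state_dict)
    unexpected = set(state_dict) - set(required)
    if strict and (missing or unexpected):
        raise ValueError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
    return missing, unexpected

required = ["up.weight", "gate.weight", "down.weight"]
# Tolerated when strict=False; the mismatches are just reported.
missing, unexpected = check_state_dict(required, {"up.weight": 0, "extra": 1}, strict=False)
print(sorted(missing), sorted(unexpected))  # ['down.weight', 'gate.weight'] ['extra']
```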
raw_state_dict()
raw_state_dict()
Returns all weights objects in the model.
Unlike state_dict(), this returns Weight objects instead of the assigned values. Some parameters inside the Weight can be configured before a graph is built; do not change those attributes after a graph has been built.
set_shared_weight()
set_shared_weight(name, weight)
Registers a Weight as shared on this layer.
Sets name as an attribute on this layer and marks the weight as
shared so that raw_state_dict() and load_state_dict() skip
it when iterating over owned weights.
state_dict()
state_dict(auto_initialize=True)
Returns the values of all weights in the model.
The values returned are the same as the values set in load_state_dict().
If load_state_dict() has not been called and none of the weights have
values, then they are initialized to zero.
sublayers