layer

Layer

class max.nn.layer.Layer

Deprecated

Deprecated since version 25.2.

Base class for neural network components. Use Module instead.

Provides functionality for adding hooks to the call function of each layer to support testing, debugging or profiling.

LayerList

class max.nn.layer.LayerList(layers)

Stores a list of layers.

Can be used as a regular Python list.

Parameters:

layers (Sequence[Layer])
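
For illustration, a minimal sketch of list-style usage (it reuses the Linear module defined in the Module example below, which is an assumption of this snippet):

from max.nn.layer import LayerList

blocks = LayerList([Linear(5, 10), Linear(10, 5)])
blocks.append(Linear(5, 5))      # list-style mutation
blocks.insert(0, Linear(5, 5))   # insert at an index
print(len(blocks))               # 4
for block in blocks:             # iterates like a regular list
    print(type(block).__name__)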

append()

append(layer)

Parameters:

layer (Layer)

extend()

extend(layer)

Parameters:

layer (Layer)

insert()

insert(i, layer)

Parameters:

  • i (int)
  • layer (Layer)

sublayers

property sublayers: dict[str, Module]

Module

class max.nn.layer.Module

Base class for model components with weight management.

Provides functionality to create custom layers and construct networks with automatic weight tracking.

The following example uses the Module class to create custom layers and build a neural network:

from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dims, out_dims), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight.T

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Tensor([5, 10]), ...}

Constructing a graph without Module can result in name collisions between the weights (in this example, there would be three weights named weight). With Module, you can use state_dict() or load_state_dict() to initialize or set the weight values, and finalize the weight names so they are unique within the model.

build_subgraph()

build_subgraph(name, input_types, weight_prefix='')

Builds a subgraph for this module.

This method creates a subgraph that encapsulates the module’s logic, handling input types, weights, and creating a graph with the module’s computation.

Once the subgraph is built, it can be called using the ops.call op.

Parameters:

  • name (str) – The name of the subgraph to create.
  • input_types (Sequence[Type | list[Type]]) – A list of input types for the subgraph. Each element can be either a single Type or a list of Type objects.
  • weight_prefix (str) – Optional prefix for weight names in the subgraph. If provided, weights whose names start with this prefix have the prefix removed and are marked as placeholders.

Returns:

The created subgraph containing the module’s computation.

Return type:

Graph

NOTE

  • Placeholder weights will require the prefix attribute of ops.call to be set.
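
As a rough sketch of the flow, reusing the MLP from the Module example above (the TensorType construction, building the subgraph inside the parent graph's context, and the shape of the ops.call result are assumptions, not guarantees of this page):

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

model = MLP()  # from the Module example above
input_type = TensorType(DType.float32, (1, 5), device=DeviceRef.CPU())

with Graph("parent", input_types=(input_type,)) as graph:
    # Encapsulate the module's computation as a named subgraph...
    subgraph = model.build_subgraph("mlp", [input_type])
    # ...then invoke it with the ops.call op (assumed to return a
    # sequence of output values).
    results = ops.call(subgraph, graph.inputs[0])
    graph.output(results[0])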

layer_weights

property layer_weights: dict[str, Weight]

load_state_dict()

load_state_dict(state_dict, *, override_quantization_encoding=False, weight_alignment=None, strict=True)

Sets the values of all weights in this model.

Parameters:

  • state_dict (Mapping[str, DLPackArray | ndarray | WeightData]) – A map from weight name to a numpy array or max.driver.Tensor.
  • override_quantization_encoding (bool) – Whether to override the weight quantization based on the loaded value.
  • weight_alignment (int | None) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state_dict must be aligned by the default dtype alignment.
  • strict (bool) – If True, raises an error if any keys in state_dict were not used by the Module.

Raises:

  • ValueError – If any weight in the model is not present in the state dict.
  • ValueError – If strict is True and state_dict contains keys not used by the Module.

Return type:

None
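
For example, a minimal sketch that loads zero-valued numpy arrays into the MLP from the Module example above (the key names mirror that example's state_dict() output):

import numpy as np

model = MLP()  # from the Module example above

# Keys must cover every weight in the model; with strict=True (the
# default), unused keys also raise a ValueError.
model.load_state_dict({
    "up.weight": np.zeros((5, 10), dtype=np.float32),
    "gate.weight": np.zeros((5, 10), dtype=np.float32),
    "down.weight": np.zeros((10, 5), dtype=np.float32),
})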

raw_state_dict()

raw_state_dict()

Returns all weight objects in the model. Unlike state_dict, this returns max.graph.Weight objects instead of the assigned values. Some parameters inside the Weight can be configured before a graph is built; do not change these attributes after building a graph.

Returns:

Map from weight name to the max.graph.Weight object.

Return type:

dict[str, Weight]
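
A short sketch of the difference from state_dict(), reusing the MLP from the Module example above (the dtype and shape attribute access on Weight is an assumption):

model = MLP()  # from the Module example above

# Yields max.graph.Weight objects rather than their assigned values.
for name, weight in model.raw_state_dict().items():
    print(name, weight.dtype, weight.shape)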

set_shared_weight()

set_shared_weight(name, weight)

Parameters:

  • name (str)
  • weight (Weight)

state_dict()

state_dict(auto_initialize=True)

Returns values of all weights in the model.

The values returned are the same as the values set in load_state_dict. If load_state_dict has not been called and none of the weights have values, then they are initialized to zero.

Parameters:

auto_initialize (bool) – Determines whether to initialize weights to zero if the weight value has not been loaded. If this is False, a ValueError is raised if an uninitialized weight is found.

Returns:

Map from weight name to the weight value (can be numpy array or max.driver.Tensor).

Return type:

dict[str, DLPackArray | ndarray]
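
For instance, a minimal sketch of the auto_initialize behavior described above, reusing the MLP from the Module example:

model = MLP()  # from the Module example above

# Without a prior load_state_dict(), weights are zero-initialized on read.
values = model.state_dict()

# With auto_initialize=False, an uninitialized weight raises instead.
try:
    MLP().state_dict(auto_initialize=False)
except ValueError:
    print("model has uninitialized weights")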

sublayers

property sublayers: dict[str, Module]

Shardable

class max.nn.layer.Shardable(*args, **kwargs)

Protocol for objects that support sharding across multiple devices.

This protocol defines the interface that all shardable components (like Linear layers and Weight objects) must implement to participate in distributed computation.

shard()

shard(shard_idx, device)

Creates a sharded view of this object for a specific device.

Parameters:

  • shard_idx (int) – The index of the shard (0 to num_devices - 1).
  • device (DeviceRef) – The device where this shard should reside.

Returns:

A sharded instance of this object.

Return type:

Shardable

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Gets the weight sharding strategy.
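
A minimal sketch of the protocol in use with a Weight, which implements it; the ShardingStrategy import path and the rowwise constructor are assumptions, not confirmed by this page:

from max.dtype import DType
from max.graph import DeviceRef, Weight
from max.graph.weight import ShardingStrategy  # import path is an assumption

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]
weight = Weight("w", DType.float32, (10, 5), DeviceRef.CPU())

# Pick a strategy, then materialize one shard per device.
weight.sharding_strategy = ShardingStrategy.rowwise(len(devices))
shards = [weight.shard(i, device) for i, device in enumerate(devices)]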

add_layer_hook()

max.nn.layer.add_layer_hook(fn)

Adds a hook to call a function after each layer’s __call__.

The function will be passed four inputs:

  • layer
  • input_args
  • input_kwargs
  • outputs

The function can either return None or return new outputs that will replace the layer's outputs.

Note that the inputs and outputs contain graph Values, which expose limited information (like shape and dtype). You can still see the computed values if you include the Value in the graph.ops.output op, or call graph.ops.print.

Example of printing debug inputs:

def print_info(layer, args, kwargs, outputs):
    print("Layer:", type(layer).__name__)
    print("Input args:", args)
    print("Input kwargs:", kwargs)
    print("Outputs:", outputs)
    return outputs

add_layer_hook(print_info)

Parameters:

fn (Callable[[Layer, tuple[Any, ...], dict[str, Any], Any], Any])

Return type:

None

clear_hooks()

max.nn.layer.clear_hooks()

Remove all hooks.

recursive_named_layers()

max.nn.layer.recursive_named_layers(parent, prefix='')

Recursively walks through the layers and generates names.

Parameters:

  • parent (Module)
  • prefix (str)

Return type:

Iterable[tuple[str, Module]]
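
For example, a minimal sketch over the MLP from the Module example above (the exact traversal order and the empty root name are assumptions):

from max.nn.layer import recursive_named_layers

model = MLP()  # from the Module example above
for name, layer in recursive_named_layers(model):
    print(name or "<root>", type(layer).__name__)
# e.g. <root> MLP, up Linear, gate Linear, down Linear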