
Python function

subgraphable

max.experimental.nn.subgraphable(module, *, name=None)

Lowers a repeated Module to one shared subgraph.

Inside Module.compile() or Module.trace(), each call emits one mo.call into a subgraph, and calls whose bodies trace to identical IR share a single subgraph definition instead of being inlined at every call site. A Module passes its parameters in as call operands, so identical sibling modules share one body while each computes with its own (distributed) weights. Calls nested inside a subgraph body are inlined. Calling it outside a compile or trace capture raises an error.
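The sharing behavior described above can be modeled with a small, self-contained Python sketch. It does not use MAX: ToyGraph, ToyLayer, TensorType, and emit_call are hypothetical names, and a "body" is just a hashable tuple of op names and operand types standing in for traced IR. It assumes bodies are deduplicated by exact IR equality, while every call still gets its own call op carrying that layer's weight as an operand.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a tensor type: dtype plus shape."""
    dtype: str
    shape: tuple


@dataclass
class ToyGraph:
    """Hypothetical model of a graph that emits calls into shared subgraphs."""
    # Traced body IR (any hashable key) -> name of its single subgraph definition.
    subgraphs: dict = field(default_factory=dict)
    # Emitted call ops, as (subgraph name, operand names).
    calls: list = field(default_factory=list)

    def emit_call(self, base_name, body_ir, operands):
        # Reuse the existing definition when this body has been traced before.
        name = self.subgraphs.get(body_ir)
        if name is None:
            name = f"{base_name}_{len(self.subgraphs)}"
            self.subgraphs[body_ir] = name
        # Every call still gets its own call op with its own operands.
        self.calls.append((name, tuple(operands)))
        return name


@dataclass
class ToyLayer:
    """Hypothetical layer whose weight is passed in as a call operand."""
    weight_name: str
    weight_type: TensorType

    def trace_body(self, x_type):
        # The body sees its input and weight only by type, not by value, so
        # layers with identically typed weights trace to the same IR.
        return (("matmul", "relu"), (x_type, self.weight_type))


graph = ToyGraph()
x_type = TensorType("float32", (8, 16))
layers = [ToyLayer(f"w{i}", TensorType("float32", (16, 16))) for i in range(4)]

value = "x"
for i, layer in enumerate(layers):
    graph.emit_call("Block", layer.trace_body(x_type), [value, layer.weight_name])
    value = f"h{i}"  # name of this call's result, fed to the next layer

print(len(graph.subgraphs))  # 1
print(graph.calls[1])  # ('Block_0', ('h0', 'w1'))
```

Four calls produce four call ops but only one subgraph definition, and each call op still names its own weight (w0 through w3).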

Use it as a class decorator so that an ordinary loop over layers automatically shares one body:

@subgraphable
@module_dataclass
class Block(Module[[Tensor], Tensor]): ...  # layer body elided

# forward() of an enclosing model whose self.layers holds Block instances:
def forward(self, x):
    for layer in self.layers:  # each call -> one shared subgraph
        x = layer(x)
    return x

Or wrap a single Module call directly: subgraphable(layer)(x).

Two calls share a body when their traced IR is identical: the same ops and the same operand types. Tensors (weights and Tensor arguments, whether positional or keyword) enter as operands, so only their structure matters; non-Tensor arguments are baked into the body. Anything the body bakes in, such as a different mix of ops or a constant read from a field or argument, yields a distinct body; a field or argument that the body never reads does not.
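The sharing rule above can be illustrated with a hypothetical "sharing key" in plain Python. This is a sketch, not the MAX implementation: body_key, ScaledBlock, and TensorType are illustrative names, and the key is assumed to consist of the op sequence, the operand types, and the values of any non-Tensor fields the body actually reads.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a tensor type: dtype plus shape."""
    dtype: str
    shape: tuple


def body_key(ops, operand_types, baked):
    """Hypothetical sharing key: op sequence, operand types, baked-in constants."""
    return (tuple(ops), tuple(operand_types), tuple(sorted(baked.items())))


@dataclass
class ScaledBlock:
    """Hypothetical module with Tensor and non-Tensor fields."""
    weight: TensorType  # Tensor field: enters the body as an operand (type only)
    scale: float        # non-Tensor field the body reads: baked into the body
    tag: str = ""       # non-Tensor field the body never reads: ignored

    def trace(self, x):
        ops = ["matmul", "mul_const"]
        # Only what the body actually reads contributes to the key.
        return body_key(ops, [x, self.weight], {"scale": self.scale})


x = TensorType("float32", (8, 16))
w = TensorType("float32", (16, 16))

a = ScaledBlock(w, 0.5, tag="first").trace(x)
b = ScaledBlock(w, 0.5, tag="second").trace(x)  # differs only in an unread field
c = ScaledBlock(w, 2.0).trace(x)                # different baked-in constant
d = ScaledBlock(TensorType("float32", (16, 32)), 0.5).trace(x)  # different operand type

print(a == b, a == c, a == d)  # True False False
```

Changing an unread field (tag) keeps the key equal, while changing a baked-in constant (scale) or an operand type (the weight shape) produces a distinct key, and so, in this model, a distinct body.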

Parameters:

  • module (Any) – The class to mark (class-decorator form), or the Module instance to wrap (call form).
  • name (str | None) – Subgraph name override. Defaults to the class name.

Returns:

The decorated class (class-decorator form) or a wrapper that emits the mo.call on each invocation (call form).

Return type:

Callable[[…], Any]
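The two forms described under Parameters and Returns can be modeled in plain Python. The sketch below is hypothetical: toy_subgraphable, EMITTED, Scale, and Block are illustrative names, and instead of emitting a mo.call it only records which subgraph name each call would target. It assumes the class form can be detected with inspect.isclass and implemented by patching __call__; the real decorator may work differently, and this sketch does not model the error raised outside a capture.

```python
import inspect

# Records the subgraph name each call would target (hypothetical bookkeeping).
EMITTED = []


def toy_subgraphable(module, *, name=None):
    """Hypothetical sketch of subgraphable's two forms; not the MAX implementation."""
    if inspect.isclass(module):
        # Class-decorator form: patch __call__ on the class and return the class.
        cls = module
        original_call = cls.__call__
        sub_name = name or cls.__name__

        def __call__(self, *args, **kwargs):
            # The real decorator would emit a mo.call into a shared subgraph here.
            EMITTED.append(sub_name)
            return original_call(self, *args, **kwargs)

        cls.__call__ = __call__
        return cls

    # Call form: wrap this one instance and return a callable.
    sub_name = name or type(module).__name__

    def wrapper(*args, **kwargs):
        EMITTED.append(sub_name)
        return module(*args, **kwargs)

    return wrapper


class Scale:
    def __init__(self, k):
        self.k = k

    def __call__(self, x):
        return self.k * x


@toy_subgraphable
class Block(Scale):
    pass


y = Block(2)(3)                                  # class form, default name
z = toy_subgraphable(Scale(10), name="attn")(3)  # call form, name override
print(y, z, EMITTED)  # 6 30 ['Block', 'attn']
```

The class form returns the class itself, so it composes with other class decorators, while the call form returns a plain callable around one instance; in both, name falls back to the class name when not given.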