Python class

RMSNorm

class max.experimental.nn.norm.RMSNorm(dim, eps=1e-06)

Bases: Module

Root mean square normalization over the last dimension of the input.

Unlike LayerNorm, the mean is not subtracted; only the root-mean-square is used to rescale. See Root Mean Square Layer Normalization for the formulation. For the Gemma variant that uses 1 + weight and multiplies before casting back, see GemmaRMSNorm.
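The contrast with LayerNorm can be written out as follows. This is a sketch of the standard formulation, not a transcription of the MAX source: the symbols d (for dim), w (for weight), and the LayerNorm parameters γ and β are notation chosen here, and ε is placed inside the square root to match the eps parameter description ("added to the mean of squares").

```latex
% RMSNorm: rescale by the root-mean-square only, no mean subtraction.
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon},
\qquad
y_i = \frac{x_i}{\mathrm{RMS}(x)}\, w_i

% LayerNorm, for comparison: subtract the mean, then divide by the standard deviation.
\mu = \frac{1}{d}\sum_{j=1}^{d} x_j,
\qquad
\sigma^{2} = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^{2},
\qquad
y_i = \frac{x_i - \mu}{\sqrt{\sigma^{2} + \epsilon}}\, \gamma_i + \beta_i
```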

For example:

from max.dtype import DType
from max.experimental.nn.norm import RMSNorm
from max.experimental.realization_context import (
    GraphRealizationContext,
    realization_context,
)
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, Graph, TensorType

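# Declare a graph with one float32 input on the GPU. "batch" and "seq" are
# named dimensions; the last dimension is fixed at 2048.
graph = Graph(
    "rms",
    input_types=[
        TensorType(DType.float32, ("batch", "seq", 2048), DeviceRef.GPU()),
    ],
)
ctx = GraphRealizationContext(graph)
with realization_context(ctx), ctx:
    # Wrap the graph input as a Tensor so the module can be called on it.
    x = Tensor.from_graph_value(graph.inputs[0])
    # dim must match the size of the input's last dimension (2048 here).
    norm = RMSNorm(2048, eps=1e-6)
    y = norm(x)
    graph.output(y)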

Parameters:

  • dim (int) – The size of the last dimension of the input.
  • eps (float) – A small positive constant added to the mean of squares for numerical stability. Defaults to 1e-6.
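To see why eps matters, the following plain-Python sketch applies the rescaling step to an all-zero row. It does not use MAX; `rms_rescale` is a hypothetical helper written only for this illustration, and it assumes eps is added to the mean of squares before the square root, as the parameter description states.

```python
import math


def rms_rescale(row: list[float], eps: float) -> list[float]:
    """Illustrative helper: divide each element by sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in row) / len(row)
    return [v / math.sqrt(mean_sq + eps) for v in row]


zeros = [0.0, 0.0, 0.0, 0.0]

# With eps > 0 the denominator is sqrt(eps), so an all-zero row stays zero.
stable = rms_rescale(zeros, eps=1e-6)

# With eps = 0 the denominator is exactly zero and the division fails.
try:
    rms_rescale(zeros, eps=0.0)
    failed_without_eps = False
except ZeroDivisionError:
    failed_without_eps = True
```

Plain Python raises ZeroDivisionError here; array and tensor libraries typically produce NaN or infinity instead, which then propagates silently through later layers.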

dim

property dim: Dim

The size of the last dimension over which normalization runs.

eps

eps: float

The small positive constant added to the mean of squares for numerical stability.

forward()

forward(x)

Returns x normalized by its root-mean-square over the last axis.

Parameters:

  • x (Tensor) – The input tensor; its last dimension has size dim.

Return type:

Tensor
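As a rough reference for what forward() computes on a single vector along the last axis, here is a minimal plain-Python sketch. `rms_norm_reference` is a hypothetical name, not a MAX function, and the sketch assumes the learned weight is multiplied in after the rescaling; it does not model dtype casting or device placement.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    """Plain-Python reference: normalize one vector by its RMS, then scale."""
    if len(x) != len(weight):
        raise ValueError("weight must have one entry per element of x")
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [3.0, 4.0]
# Unit weights and eps = 0 isolate the normalization step itself.
y = rms_norm_reference(x, weight=[1.0, 1.0], eps=0.0)

# The output has unit root-mean-square, but its mean is not zero,
# because (unlike LayerNorm) the mean is never subtracted.
rms_of_y = math.sqrt(sum(v * v for v in y) / len(y))
mean_of_y = sum(y) / len(y)
```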

weight

weight: Tensor

The learned per-element scale of shape [dim].
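Because the scale has shape [dim], the same vector is applied at every position along the leading dimensions. The plain-Python sketch below illustrates this with two rows of size 3; `scale_rows` and the weight values are hypothetical and chosen only for the example, and the sketch assumes the weight multiplies the normalized values.

```python
import math


def scale_rows(rows, weight, eps=1e-6):
    """Apply RMS rescaling to each row, then the shared per-element weight."""
    out = []
    for row in rows:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


# Two positions (rows) with dim = 3, and one weight vector of length dim.
rows = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
weight = [1.0, 0.5, 2.0]  # hypothetical values, one per feature

scaled = scale_rows(rows, weight)

# The second row is 10x the first, so after normalization both rows are
# nearly identical (up to the small effect of eps), and each feature is
# scaled by the same weight entry in both rows.
max_row_difference = max(abs(a - b) for a, b in zip(scaled[0], scaled[1]))
```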