Python class
RMSNorm
class max.experimental.nn.norm.RMSNorm(dim, eps=1e-06)
Bases: Module
Root mean square normalization over the last dimension of the input.
Unlike LayerNorm, the mean is not subtracted; only the
root-mean-square is used to rescale. See Root Mean Square Layer
Normalization for the formulation.
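As a rough guide, the formulation can be written out as below. Here d stands for dim, w for weight, and ε for eps. The placement of ε inside the square root is an assumption based on common implementations, not something this page states; the LayerNorm line (with mean μ, variance σ², and bias b) is included only for contrast.

```latex
% RMSNorm: rescale by the root-mean-square; no mean subtraction, no bias
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon},
\qquad
y_i = \frac{x_i}{\mathrm{RMS}(x)}\, w_i

% LayerNorm, for comparison: subtracts the mean before rescaling
y_i = \frac{x_i - \mu}{\sqrt{\sigma^{2} + \epsilon}}\, w_i + b_i
```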
For the Gemma variant that uses 1 + weight and multiplies before
casting back, see GemmaRMSNorm.
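The difference between the two variants can be sketched in plain Python. This is an illustration under stated assumptions, not the MAX implementation: the helper names are hypothetical, float16 rounding is emulated with the standard struct module, and the plain variant is assumed to cast back to the input precision before multiplying by weight.

```python
import math
import struct

def to_half(v):
    """Round a float to float16 precision (stand-in for casting back to a low-precision dtype)."""
    return struct.unpack("e", struct.pack("e", v))[0]

def inv_rms(x, eps):
    # Reciprocal root-mean-square of x, computed in full precision.
    return 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)

def rms_norm_cast_then_scale(x, weight, eps=1e-6):
    # Assumed ordering for the plain variant: cast back first, then multiply by weight.
    r = inv_rms(x, eps)
    return [to_half(to_half(v * r) * w) for v, w in zip(x, weight)]

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma-style: multiply by (1 + weight) in full precision, then cast once.
    r = inv_rms(x, eps)
    return [to_half(v * r * (1.0 + w)) for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
plain = rms_norm_cast_then_scale(x, [1.0] * 4)  # unit scale
gemma = gemma_rms_norm(x, [0.0] * 4)            # zero weight also means unit scale
```

With the 1 + weight parameterization, a zero weight is the identity scale, so the two calls above produce the same values.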
For example:
from max.dtype import DType
from max.experimental.nn.norm import RMSNorm
from max.experimental.realization_context import (
    GraphRealizationContext,
    realization_context,
)
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, Graph, TensorType

# Declare a graph with one float32 input of shape (batch, seq, 2048) on GPU.
graph = Graph(
    "rms",
    input_types=[
        TensorType(DType.float32, ("batch", "seq", 2048), DeviceRef.GPU()),
    ],
)

ctx = GraphRealizationContext(graph)
with realization_context(ctx), ctx:
    x = Tensor.from_graph_value(graph.inputs[0])
    # Normalize over the last dimension, which has size 2048.
    norm = RMSNorm(2048, eps=1e-6)
    y = norm(x)
    graph.output(y)
Parameters:
dim
property dim: Dim
The size of the last dimension over which normalization runs.
eps
eps: float
The variance epsilon used for numerical stability.
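To see why the epsilon matters, consider an all-zero input, whose root-mean-square is exactly zero. The sketch below uses a hypothetical pure-Python rms_norm helper and assumes eps is added to the mean square before the square root; it is not the library's code.

```python
import math

def rms_norm(x, eps=1e-6):
    # Assumption: eps is added to the mean square before taking the square root.
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

zeros = [0.0, 0.0, 0.0, 0.0]

# With eps, the denominator is sqrt(eps), so the result stays finite.
safe = rms_norm(zeros)

# Without eps, the denominator is zero.
try:
    rms_norm(zeros, eps=0.0)
    failed = False
except ZeroDivisionError:
    # Pure Python raises here; array libraries typically yield NaN instead.
    failed = True
```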
forward()
forward(x)
Returns x normalized by its root-mean-square over the last axis.
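Two consequences of normalizing by the root-mean-square can be checked on a single vector: the output has a root-mean-square close to 1, and rescaling the input by a positive constant barely changes the output. The helpers below are hypothetical, ignore the learned weight, and assume eps sits inside the square root.

```python
import math

def rms(x):
    # Root-mean-square of a vector.
    return math.sqrt(sum(v * v for v in x) / len(x))

def rms_normalize(x, eps=1e-6):
    # Hypothetical helper: divide by the RMS of x, with no learned scale applied.
    denom = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / denom for v in x]

x = [1.0, 2.0, 3.0, 4.0]
y = rms_normalize(x)
y_scaled = rms_normalize([10.0 * v for v in x])

out_rms = rms(y)            # close to 1
out_mean = sum(y) / len(y)  # stays nonzero because the mean is not subtracted
```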
weight
weight: Tensor
The learned per-element scale of shape [dim].
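A weight of shape [dim] scales each position along the last axis and is shared across all leading positions. The sketch below illustrates this on a small 2-D input; the rms_norm helper is hypothetical and the example weight values are arbitrary, chosen only to make the per-element effect visible.

```python
import math

def rms_norm(batch, weight, eps=1e-6):
    """Normalize every row (the last axis) by its RMS, then scale by weight."""
    out = []
    for row in batch:
        assert len(row) == len(weight)  # weight has shape [dim]
        inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv * w for v, w in zip(row, weight)])
    return out

weight = [1.0, 2.0, 0.5, 0.0]
batch = [
    [1.0, 1.0, 1.0, 1.0],
    [3.0, 3.0, 3.0, 3.0],
]
y = rms_norm(batch, weight)
```

Each constant row normalizes to approximately all ones, so every output row ends up approximately equal to weight itself.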