Python function

rms_norm

rms_norm()

max.experimental.nn.norm.rms_norm(input, weight, epsilon, weight_offset=0.0, multiply_before_cast=False)

Performs Root Mean Square layer normalization.

Computes output = input / rms(input) * (weight + weight_offset), where rms(x) = sqrt(mean(x^2) + epsilon) and the mean is taken over the last dimension of input.

When multiply_before_cast is False (Llama-style), the normalized input is cast to the output dtype before it is multiplied by the weight. When True (Gemma-style), the multiplication is performed first and the result is then cast to the output dtype.
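The formula and the two cast orders can be illustrated with a NumPy reference. This is a minimal sketch, not the MAX implementation: the function name rms_norm_reference is hypothetical, and computing the normalization in float32 before casting is an assumption made here for illustration.

```python
import numpy as np


def rms_norm_reference(x, weight, epsilon, weight_offset=0.0,
                       multiply_before_cast=False):
    """NumPy sketch of RMS normalization over the last dimension.

    Hypothetical helper for illustration only; it is not part of MAX.
    """
    x = np.asarray(x)
    weight = np.asarray(weight)
    out_dtype = x.dtype

    # Assumption: normalize in float32, whatever the input dtype is.
    x32 = x.astype(np.float32)
    # epsilon is added to the mean of squares, inside the square root.
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + epsilon)
    normed = x32 / rms

    # weight_offset shifts the weight (1 for Gemma-like, 0 otherwise).
    w = weight.astype(np.float32) + np.float32(weight_offset)

    if multiply_before_cast:
        # Gemma-style: multiply at float32 precision, then cast.
        return (normed * w).astype(out_dtype)
    # Llama-style: cast the normalized input first, then multiply.
    return normed.astype(out_dtype) * w.astype(out_dtype)


x = np.array([[3.0, 4.0]], dtype=np.float16)
w = np.ones(2, dtype=np.float16)
llama_style = rms_norm_reference(x, w, 1e-6)
gemma_style = rms_norm_reference(x, w - 1, 1e-6, weight_offset=1.0,
                                 multiply_before_cast=True)
```

For float32 inputs the two orders give the same result in this sketch; they can differ in the last bits only when the output dtype is narrower than the dtype used for the computation.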

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to normalize.
  • weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The weight tensor whose shape must match the last dimension of input.
  • epsilon (float) – A small value added to the mean of squares, inside the square root, for numerical stability.
  • weight_offset (float) – A value added to the weight before it is applied to the normalized input. Typically 1 for Gemma-like normalization and 0 otherwise.
  • multiply_before_cast (bool) – Whether to multiply before casting to the output dtype.

Returns:

A normalized tensor with the same shape and dtype as input.

Raises:

ValueError – If weight shape doesn’t match the last dimension of input.
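The shape requirement can be checked ahead of time in plain Python. This is a sketch under an assumption: it reads "match the last dimension" as requiring a one-dimensional weight whose length equals the size of the last dimension of input. The helper name check_weight_shape is hypothetical and is not part of MAX.

```python
def check_weight_shape(input_shape, weight_shape):
    """Raise ValueError if weight does not match input's last dimension.

    Hypothetical pre-check; assumes weight must be 1-D.
    """
    input_shape = tuple(input_shape)
    weight_shape = tuple(weight_shape)
    if not input_shape:
        raise ValueError("input must have at least one dimension")
    expected = (input_shape[-1],)
    if weight_shape != expected:
        raise ValueError(
            f"weight shape {weight_shape} does not match the last "
            f"dimension of input {input_shape}; expected {expected}"
        )


# A [batch, seq, hidden] input with a matching [hidden] weight passes.
check_weight_shape((2, 8, 64), (64,))
```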

Return type:

TensorValue