

rms_norm()

max.experimental.nn.norm.rms_norm(input, weight, epsilon, weight_offset=0.0, multiply_before_cast=False)


Performs Root Mean Square layer normalization.

Computes output = input / rms(input) * (weight + weight_offset), where rms(x) = sqrt(mean(x^2) + epsilon) and the mean is taken over the last dimension of input.
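The formula can be checked against a small NumPy reference. This is a minimal sketch and not the MAX implementation. It assumes the mean is taken over the last axis and that weight_offset is added to the weight, as the parameter list describes. The name rms_norm_reference is hypothetical.

```python
import numpy as np

def rms_norm_reference(x, weight, epsilon, weight_offset=0.0):
    """Hypothetical NumPy reference for RMS normalization (not the MAX API).

    Normalizes each row over the last axis, then scales by (weight + weight_offset).
    """
    x = np.asarray(x, dtype=np.float32)
    weight = np.asarray(weight, dtype=np.float32)
    # rms(x) = sqrt(mean(x^2) + epsilon), computed per row over the last axis
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + epsilon)
    return x / rms * (weight + weight_offset)

x = np.array([[3.0, 4.0]])
w = np.ones(2)
print(rms_norm_reference(x, w, epsilon=0.0))
```

With a unit weight and epsilon of zero, each output row has a mean square of 1.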

When multiply_before_cast is False (Llama-style), the normalized input is cast to the output dtype before it is multiplied by the weight. When True (Gemma-style), the multiplication by the weight is performed first and the product is then cast to the output dtype.
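The two cast orders can be illustrated with a NumPy sketch. This is not the MAX kernel. It makes two assumptions: the statistics are accumulated in float32, and the output dtype is the dtype of x. The name rms_norm_sketch is hypothetical.

```python
import numpy as np

def rms_norm_sketch(x, weight, epsilon, weight_offset=0.0,
                    multiply_before_cast=False):
    """Hypothetical sketch of the two cast orders (not the MAX API).

    Assumes statistics are computed in float32 and the output dtype is x.dtype.
    """
    out_dtype = x.dtype
    x32 = x.astype(np.float32)
    # Normalize over the last axis in float32.
    normed = x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + epsilon)
    scale = weight.astype(np.float32) + weight_offset
    if multiply_before_cast:
        # Gemma-style: multiply in float32, then round once to the output dtype.
        return (normed * scale).astype(out_dtype)
    # Llama-style: round the normalized values first, then multiply in out_dtype.
    return normed.astype(out_dtype) * scale.astype(out_dtype)

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float16)
w = np.array([0.5, 1.0, 1.5], dtype=np.float16)
llama_out = rms_norm_sketch(x, w, epsilon=1e-6)
gemma_out = rms_norm_sketch(x, w, epsilon=1e-6, multiply_before_cast=True)
print(llama_out.dtype, gemma_out.dtype)
```

Both paths return float16 with nearly equal values. They differ only in where the rounding to the lower precision happens.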

Parameters:

  • input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to normalize.
  • weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The scale tensor. Its shape must match the size of the last dimension of input.
  • epsilon (float) – A small value added to mean(x^2) inside the square root for numerical stability.
  • weight_offset (float) – A constant added to the weight before it scales the normalized input. Typically 1.0 for Gemma-style normalization and 0.0 otherwise.
  • multiply_before_cast (bool) – Whether to multiply by the weight before casting to the output dtype (Gemma-style) rather than after (Llama-style).
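The following NumPy sketch shows what weight_offset does. In Gemma-style models the stored weight is an offset from 1, so the effective scale is 1 + weight. Under that assumption, passing weight_offset=1.0 with the stored weight gives the same result as passing weight_offset=0.0 with weight + 1. The reference function is hypothetical and is not the MAX API.

```python
import numpy as np

def rms_norm_reference(x, weight, epsilon, weight_offset=0.0):
    # Hypothetical NumPy reference (not the MAX API): normalize over the
    # last axis, then scale by (weight + weight_offset).
    x = np.asarray(x, dtype=np.float32)
    w = np.asarray(weight, dtype=np.float32)
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + epsilon)
    return x / rms * (w + weight_offset)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8)).astype(np.float32)
# Gemma-style stored weight: zeros mean an effective scale of 1.
delta = np.zeros(8, dtype=np.float32)

gemma = rms_norm_reference(x, delta, epsilon=1e-6, weight_offset=1.0)
llama = rms_norm_reference(x, delta + 1.0, epsilon=1e-6, weight_offset=0.0)
print(np.allclose(gemma, llama))
```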

Returns:

A normalized tensor with the same shape and dtype as input.

Raises:

ValueError – If weight shape doesn’t match the last dimension of input.
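The documented ValueError condition can be mirrored with a small shape check. This sketch assumes weight is expected to be one-dimensional with shape (input.shape[-1],). The helper check_rms_norm_shapes is hypothetical and is not part of MAX.

```python
def check_rms_norm_shapes(input_shape, weight_shape):
    """Hypothetical helper mirroring the documented ValueError condition."""
    if len(input_shape) == 0:
        raise ValueError("input must have at least one dimension")
    # Assumption: weight must be 1-D with length equal to input's last dimension.
    if tuple(weight_shape) != (input_shape[-1],):
        raise ValueError(
            f"weight shape {tuple(weight_shape)} does not match the last "
            f"dimension of input ({input_shape[-1]})"
        )

check_rms_norm_shapes((4, 16, 4096), (4096,))  # passes silently
try:
    check_rms_norm_shapes((4, 16, 4096), (2048,))
except ValueError as err:
    print("rejected:", err)
```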

Return type:

TensorValue