Python function
rms_norm
rms_norm()
max.experimental.nn.norm.rms_norm(input, weight, epsilon, weight_offset=0.0, multiply_before_cast=False)
Performs Root Mean Square layer normalization.
Computes output = input / rms(input) * weight where
rms(x) = sqrt(mean(x^2) + epsilon).
When multiply_before_cast is False (Llama-style), the input is
cast to the output dtype before multiplication by the weight. When
True (Gemma-style), the multiplication is performed before the cast.
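The computation and the effect of the cast order can be sketched in NumPy. This is an illustrative reference only, not the MAX implementation: the name rms_norm_reference, the float32 accumulation, and the way weight_offset is added to the weight are all assumptions made for the example.

```python
import numpy as np


def rms_norm_reference(x, weight, epsilon, weight_offset=0.0,
                       multiply_before_cast=False):
    """Hypothetical NumPy reference; not the MAX implementation."""
    out_dtype = x.dtype
    # Assumption: statistics are accumulated in float32.
    x32 = x.astype(np.float32)
    # rms(x) = sqrt(mean(x^2) + epsilon), taken over the last dimension.
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + epsilon)
    normed = x32 / rms
    # Assumption: the offset is simply added to the weight.
    w32 = weight.astype(np.float32) + np.float32(weight_offset)
    if multiply_before_cast:
        # Gemma-style: scale in float32, then cast once.
        return (normed * w32).astype(out_dtype)
    # Llama-style: cast the normalized values first, then scale.
    return normed.astype(out_dtype) * w32.astype(out_dtype)


x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16)
w = np.ones(4, dtype=np.float16)
llama = rms_norm_reference(x, w, epsilon=1e-6)
gemma = rms_norm_reference(x, w, epsilon=1e-6, multiply_before_cast=True)
```

With unit weights the two orderings agree. They can differ in the last bits for low-precision dtypes once the weight is not exactly representable, which is why the flag exists.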
Parameters:
- input (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The input tensor to normalize.
- weight (Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray) – The weight tensor whose shape must match the last dimension of input.
- epsilon (float) – A small value added to the denominator for numerical stability.
- weight_offset (float) – A value added to the weight before normalization. Typically 1 for Gemma-like normalization and 0 otherwise.
- multiply_before_cast (bool) – Whether to multiply before casting to the output dtype.
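The weight shape requirement and the role of weight_offset can be illustrated with a small NumPy helper. This is a sketch under stated assumptions: effective_weight is a hypothetical name, and it assumes that "match the last dimension" means a one-dimensional weight of length input.shape[-1] and that the offset is added elementwise.

```python
import numpy as np


def effective_weight(input_shape, weight, weight_offset=0.0):
    """Hypothetical helper: validate the weight and apply the offset."""
    # Assumption: the weight is 1-D with length equal to the last input dim.
    expected = (input_shape[-1],)
    if weight.shape != expected:
        raise ValueError(
            f"weight shape {weight.shape} does not match "
            f"last input dimension {expected}"
        )
    return weight + weight_offset


# Gemma-like: weights stored around zero, offset 1 shifts them to scale by ~1.
gemma_w = effective_weight((2, 4), np.zeros(4, dtype=np.float32), weight_offset=1.0)
# Otherwise: weights are used as stored (offset 0).
plain_w = effective_weight((2, 4), np.ones(4, dtype=np.float32))
```

In this sketch a zero weight with weight_offset=1 yields the same scale as a unit weight with weight_offset=0.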
Returns:
A normalized tensor with the same shape and dtype as input.
Raises:
ValueError – If the weight shape does not match the last dimension of input.
Return type: