Python module
rms_norm
Root mean square layer normalization.
RMSNorm
class max.nn.norm.rms_norm.RMSNorm(dim, eps=1e-06)
Computes root mean square (RMS) normalization on its input.
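As an illustration only, the following pure-Python sketch mirrors the shape of this class for a single input row. RMSNormSketch is a hypothetical stand-in, not the MAX implementation. It assumes three things: the weight starts as ones, normalization runs over a dimension of size dim, and eps is added inside the square root.

```python
import math


class RMSNormSketch:
    """Hypothetical pure-Python stand-in for RMSNorm(dim, eps); not the MAX API."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        self.dim = dim
        self.eps = eps
        # Assumption: the learnable scale starts as ones, one entry per feature.
        self.weight = [1.0] * dim

    def forward(self, x: list[float]) -> list[float]:
        # Treat x as one row; its length must match dim (the normalized dimension).
        if len(x) != self.dim:
            raise ValueError(f"expected {self.dim} features, got {len(x)}")
        # Root mean square of the row, with eps added inside the square root.
        mean_sq = sum(v * v for v in x) / self.dim
        inv_rms = 1.0 / math.sqrt(mean_sq + self.eps)
        # Normalize each element, then apply the per-feature scale.
        return [v * inv_rms * w for v, w in zip(x, self.weight)]


norm = RMSNormSketch(dim=4)
y = norm.forward([1.0, 2.0, 3.0, 4.0])
```

Because the sketch's weight starts at ones, the output row has an RMS of almost exactly 1. Any per-feature rescaling comes only from weight.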
dim
property dim: Dim
eps
eps: float
forward()
forward(x)
Applies root mean square normalization to x using this module's weight and eps.
Parameters:
- x (Tensor) – The input tensor.
Returns:
The normalized tensor, with the same shape as x.
Return type:
Tensor
weight
weight: Tensor
rms_norm()
max.nn.norm.rms_norm.rms_norm(x, weight, eps, weight_offset=0.0, multiply_before_cast=False)
Applies Root Mean Square layer normalization to an input tensor.
See https://arxiv.org/abs/1910.07467
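For reference, the normalization from the paper, extended with this function's weight_offset, can be written as follows. Here n is the size of the normalized dimension (assumed to be the last one), w is weight, and δ is weight_offset. Placing ε inside the square root is an assumption based on common implementations.

```latex
y_i = \frac{x_i}{\sqrt{\dfrac{1}{n}\sum_{j=1}^{n} x_j^{2} + \epsilon}}\,\bigl(w_i + \delta\bigr)
```

With δ = 0 this reduces to the original RMSNorm. With δ = 1, as in Gemma-like normalization, a weight of all zeros yields an identity scale.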
Parameters:
- x (Tensor) – The input tensor.
- weight (Tensor) – The learned scale applied to the normalized values.
- eps (float) – A small value added to the denominator of the normalization for numerical stability.
- weight_offset (float) – A value added to weight before it scales the normalized input. Typically 1 for Gemma-like normalization and 0 otherwise.
- multiply_before_cast (bool) – If True, multiply by the weight before casting the result to the output dtype; if False, cast first and then multiply. Typically True for Gemma-like normalization and False otherwise.
Returns:
A layer-normalized tensor with the same shape and dtype as x.
Return type:
Tensor
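The following minimal pure-Python sketch shows how weight_offset and multiply_before_cast interact for a single row with a float16 output dtype. It makes three assumptions that are not taken from the MAX source: the normalization is computed at higher precision than the output dtype, eps sits inside the square root, and float16 casts are simulated with the struct module's "e" (half-precision) format. rms_norm_sketch and to_half are hypothetical helpers, not part of max.nn.

```python
import math
import struct


def to_half(v: float) -> float:
    """Round a float to the nearest IEEE 754 float16 value, simulating a dtype cast."""
    return struct.unpack("e", struct.pack("e", v))[0]


def rms_norm_sketch(x, weight, eps, weight_offset=0.0, multiply_before_cast=False):
    """Hypothetical single-row reference; x and weight hold float16-representable values."""
    # Normalize in full (float64) precision, with eps inside the square root.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [v * inv_rms for v in x]
    # Offset the weight, e.g. (1 + weight) for Gemma-like normalization.
    scale = [w + weight_offset for w in weight]
    if multiply_before_cast:
        # Scale at high precision, then cast once to the float16 output dtype.
        return [to_half(n * s) for n, s in zip(normed, scale)]
    # Cast the normalized values (and scale) to float16 first, then multiply;
    # rounding the exact product to float16 models a float16 multiply.
    return [to_half(to_half(n) * to_half(s)) for n, s in zip(normed, scale)]


x = [to_half(v) for v in [0.1, -0.7, 1.3, 2.9]]
w = [to_half(v) for v in [0.05, -0.2, 0.3, 0.11]]
gemma_style = rms_norm_sketch(x, w, 1e-6, weight_offset=1.0, multiply_before_cast=True)
default_style = rms_norm_sketch(x, w, 1e-6)
```

Multiplying before the cast rounds only once, at the end. Casting first rounds the normalized values to float16 before scaling, which adds a second rounding step.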