Python module

rms_norm

Root mean square layer normalization.

RMSNorm

class max.nn.norm.rms_norm.RMSNorm(dim, eps=1e-06)

Applies Root Mean Square normalization to its inputs.

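Parameters:

  • dim – The size of the dimension to normalize.
  • eps (float) – A value added to the denominator of the normalization for numerical stability. Defaults to 1e-06.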

dim

property dim: Dim

eps

eps: float

forward()

forward(x)

Applies Root Mean Square normalization to the input, scaling the result by weight and using eps for numerical stability.

Parameters:

  • x (Tensor) – The input tensor to normalize.

Returns:

The normalized tensor, with the same shape as x.

Return type:

Tensor

weight

weight: Tensor
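To make the module's state and forward pass concrete, here is a minimal pure-Python stand-in for RMSNorm(dim, eps=1e-06). It is a hypothetical sketch, not the MAX implementation: the class name RMSNormSketch is invented, inputs are nested lists rather than Tensor objects, normalization runs over the last dimension, and initializing weight to ones is an assumption.

```python
import math
from typing import List


class RMSNormSketch:
    """Hypothetical stand-in for RMSNorm(dim, eps=1e-06); not the MAX API.

    Operates on nested lists, normalizes over the last dimension, and
    initializes the weight to ones (an assumption made for illustration).
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        self.dim = dim
        self.eps = eps
        self.weight: List[float] = [1.0] * dim

    def forward(self, x: List[List[float]]) -> List[List[float]]:
        out = []
        for row in x:
            if len(row) != self.dim:
                raise ValueError(f"expected last dimension {self.dim}, got {len(row)}")
            # Root mean square of the row, with eps added inside the square root.
            rms = math.sqrt(sum(v * v for v in row) / self.dim + self.eps)
            # Normalize each element, then scale it by the matching weight.
            out.append([v / rms * w for v, w in zip(row, self.weight)])
        return out

    # Calling the instance runs forward, mirroring typical module usage.
    __call__ = forward


norm = RMSNormSketch(dim=3)
print(norm([[3.0, 0.0, 4.0]]))
```

With the weight left at ones, each output row has a root mean square very close to 1, which is a quick sanity check for any RMSNorm implementation.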

rms_norm()

max.nn.norm.rms_norm.rms_norm(x, weight, eps, weight_offset=0.0, multiply_before_cast=False)

Applies Root Mean Square layer normalization to an input tensor.

See https://arxiv.org/abs/1910.07467
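Written out for a single vector x of length n along the normalized dimension, a common form of the computation is shown below, with w the weight and δ the weight_offset. Placing eps inside the square root and adding δ to the weight follow the parameter descriptions on this page; treat this as a sketch rather than the kernel's exact arithmetic or cast order.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^{2} + \epsilon},
\qquad
y_i = \frac{x_i}{\mathrm{RMS}(x)}\,\bigl(w_i + \delta\bigr)
```

With δ = 0 this is the standard scaling by w; with δ = 1 it becomes the Gemma-like scaling by (1 + w).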

Parameters:

  • x (Tensor) – The input tensor.
  • weight (Tensor) – The weights that scale the normalized input.
  • eps (float) – A value added to the denominator of the normalization for numerical stability.
  • weight_offset (float) – A value added to the weights before they scale the normalized input. Typically 1 for Gemma-like normalization and 0 otherwise.
  • multiply_before_cast (bool) – If True, multiply the normalized input by the weights before casting to the output dtype; if False, cast to the output dtype first and then multiply. Typically True for Gemma-like normalization and False otherwise.

Returns:

An RMS-normalized tensor with the same shape and dtype as x.

Return type:

Tensor
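The effect of weight_offset and multiply_before_cast can be illustrated with a small pure-Python reference. This is a hypothetical sketch, not MAX code: rms_norm_reference and to_float16 are invented names, inputs are nested lists, normalization runs over the last dimension, and float16 (emulated with the struct module's "e" half-precision format) is assumed as the output dtype.

```python
import math
import struct
from typing import List, Sequence


def to_float16(value: float) -> float:
    # Round to the nearest IEEE 754 half-precision value; this stands in for
    # casting to a low-precision output dtype (float16 is an assumption).
    return struct.unpack("e", struct.pack("e", value))[0]


def rms_norm_reference(
    x: Sequence[Sequence[float]],
    weight: Sequence[float],
    eps: float,
    weight_offset: float = 0.0,
    multiply_before_cast: bool = False,
) -> List[List[float]]:
    """Hypothetical reference for rms_norm, normalizing over the last axis."""
    out: List[List[float]] = []
    for row in x:
        if len(row) != len(weight):
            raise ValueError("weight length must match the last dimension of x")
        # Mean of squares and inverse RMS, computed in full (float64) precision.
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # weight_offset is added to every weight, e.g. (1 + w) for Gemma-like models.
        scale = [w + weight_offset for w in weight]
        if multiply_before_cast:
            # Multiply in full precision, then cast the product once.
            out.append([to_float16(v * inv_rms * s) for v, s in zip(row, scale)])
        else:
            # Cast the normalized value first, then multiply in the output dtype.
            out.append(
                [to_float16(to_float16(v * inv_rms) * to_float16(s)) for v, s in zip(row, scale)]
            )
    return out


x = [[1.0, 2.0, 3.0, 4.0]]
y = rms_norm_reference(x, [1.0] * 4, eps=1e-6)
print(y)
```

Here the mean of squares is 7.5, so each output is approximately x_i / sqrt(7.5), rounded to half precision. Passing weight=[0.0] * 4 with weight_offset=1.0 gives the same result, which shows why Gemma-like checkpoints that store w rather than (1 + w) use an offset of 1. The two cast orders differ only by rounding, which matters most for low-precision dtypes.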
