
Python class

RMSNorm


class max.experimental.nn.norm.RMSNorm(dim, eps=1e-06)

Bases: Module

Applies Root Mean Square (RMS) normalization to its inputs.

Constructs an RMSNorm module.

Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • eps (float) – Value added to the denominator for numerical stability.

dim​

property dim: Dim

Returns the embedding dimension.

eps​

eps: float

Value added to the denominator for numerical stability.

forward()​

forward(x)

Applies RMS normalization to the input.

Parameters:

  • x (Tensor) – Input tensor whose last dimension has size dim.

Return type:

Tensor
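To make the computation concrete, here is a minimal pure-Python sketch of RMS normalization over a single vector (one slice along the last dimension). It is a reference illustration, not this module's implementation: the function name rms_norm_reference is hypothetical, and the sketch assumes the common formulation in which eps is added to the mean of squares before the square root and weight is applied as an elementwise scale. The page does not state exactly where eps is placed or how weight is initialized.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    """Reference RMS normalization of one vector (hypothetical helper).

    Assumptions: eps is added to the mean of squares before the square
    root, and weight is an elementwise scale with the same length as x.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length (dim)")
    # Mean of squared elements over the last dimension.
    mean_square = sum(v * v for v in x) / len(x)
    # Reciprocal of the root mean square, stabilized by eps.
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    # Scale every element by 1/RMS, then by its per-element weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [3.0, 4.0]
weight = [1.0, 1.0]  # unit scale, chosen only for illustration
y = rms_norm_reference(x, weight)
```

With a unit weight, the output has a root mean square close to 1 regardless of the input's scale, which is the property the normalization is meant to provide.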

weight​

weight: Tensor
