
Python class

GemmaRMSNorm


class max.nn.module_v3.norm.GemmaRMSNorm(dim, eps=1e-06)

Applies Root Mean Square (RMS) normalization to its input, using the Gemma variant of RMSNorm.

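Differences from traditional RMSNorm:

  • x * (1 + w) instead of x * w, so the learned weight w acts as an offset from a scale of 1.
  • (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w, so the result is cast back to the original dtype only after the weight has been applied.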
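Written out, the two variants differ as follows. This is a sketch under assumptions taken from common RMSNorm implementations rather than from this page: x is normalized over its last dimension d in higher precision, ε is added to the mean of squares inside the square root, and cast(·) denotes conversion back to the input's original dtype.

```latex
\[
\hat{x}_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}
\]
\[
y_i^{\text{RMSNorm}} = \operatorname{cast}(\hat{x}_i)\, w_i
\qquad
y_i^{\text{Gemma}} = \operatorname{cast}\bigl(\hat{x}_i\,(1 + w_i)\bigr)
\]
```

With w = 0, the Gemma form reduces to the plain normalized input, so its weight can be read as a correction to an identity scale.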

Constructs a GemmaRMSNorm module.

Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • eps (float) – Value added to denominator for numerical stability.

forward()

forward(x)

Applies Gemma-style RMS normalization to the input.

Parameters:

x (Tensor) – Input tensor; its last dimension is expected to have size dim.

Return type:

Tensor
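For illustration, the computation that forward() describes can be sketched in NumPy. This is a hypothetical reference sketch, not the MAX implementation or API: the function names gemma_rms_norm and rms_norm, the float32 upcast, and placing eps inside the square root are assumptions modeled on common RMSNorm implementations. The weight array of length dim plays the role of w.

```python
import numpy as np


def gemma_rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Gemma-style RMSNorm over the last axis (hypothetical NumPy sketch, not the MAX API)."""
    if x.shape[-1] != weight.shape[0]:
        raise ValueError(f"last dimension {x.shape[-1]} does not match dim {weight.shape[0]}")
    orig_dtype = x.dtype
    # Assumption: normalize in float32 regardless of the input dtype.
    xf = x.astype(np.float32)
    # Mean of squares over the last axis; eps is added inside the square root.
    mean_sq = np.mean(xf * xf, axis=-1, keepdims=True)
    normed = xf / np.sqrt(mean_sq + eps)
    # Gemma: scale by (1 + w) first, then cast back to the input dtype.
    out = normed * (1.0 + weight.astype(np.float32))
    return out.astype(orig_dtype)


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Traditional RMSNorm, for comparison (hypothetical NumPy sketch)."""
    orig_dtype = x.dtype
    xf = x.astype(np.float32)
    mean_sq = np.mean(xf * xf, axis=-1, keepdims=True)
    normed = xf / np.sqrt(mean_sq + eps)
    # Traditional: cast back to the input dtype first, then scale by w.
    return normed.astype(orig_dtype) * weight.astype(orig_dtype)


# Usage: with w = 0 the Gemma scale (1 + w) is 1, so each row is just RMS-normalized.
x = np.array([[3.0, 4.0], [1.0, -1.0]], dtype=np.float32)
w = np.zeros(2, dtype=np.float32)
y = gemma_rms_norm(x, w)
print(np.sqrt(np.mean(y ** 2, axis=-1)))  # each row's RMS is close to 1
```

The two functions share the normalization step and differ only in the two points listed for GemmaRMSNorm: the (1 + w) scale and casting after, rather than before, the weight is applied. For float32 inputs, a zero Gemma weight gives the same result as a traditional weight of ones.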
