Python class

GemmaRMSNorm

class max.nn.norm.rms_norm.GemmaRMSNorm(dim, eps=1e-06)

Applies Root Mean Square (RMS) normalization to the input, using the Gemma variant of the scaling step.

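Differences from traditional RMSNorm (here x is the normalized input and w is the learned weight):

  • The scale is offset by one: x * (1 + w) instead of x * w.
  • The cast back to the original dtype happens after the multiply: (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.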
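The two differences listed above can be illustrated with a small NumPy reference sketch. This is not the MAX implementation and does not use the MAX API. The function names are hypothetical, and the float32 upcast and the placement of eps inside the square root are assumptions based on common RMSNorm implementations.

```python
import numpy as np


def rms_norm_traditional(x, w, eps=1e-6):
    """Traditional RMSNorm sketch: cast back to the input dtype, then scale by w."""
    orig_dtype = x.dtype
    x32 = x.astype(np.float32)  # assumption: normalize in float32
    # Divide by the root mean square over the last axis.
    normed = x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    # x.to(orig_dtype) * w: the cast happens before the multiply.
    return normed.astype(orig_dtype) * w.astype(orig_dtype)


def rms_norm_gemma(x, w, eps=1e-6):
    """Gemma-style RMSNorm sketch: scale by (1 + w), then cast back."""
    orig_dtype = x.dtype
    x32 = x.astype(np.float32)  # assumption: normalize in float32
    normed = x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    # (x * (1 + w)).to(orig_dtype): the multiply happens before the cast.
    return (normed * (1.0 + w.astype(np.float32))).astype(orig_dtype)


# Example inputs: one row of two half-precision values, zero weights.
x = np.array([[3.0, 4.0]], dtype=np.float16)
w = np.zeros(2, dtype=np.float16)

gemma_out = rms_norm_gemma(x, w)
traditional_out = rms_norm_traditional(x, w)
```

With w set to zero, the Gemma-style variant returns the normalized input unchanged, while the traditional variant returns zeros. A practical consequence is that a zero weight vector acts as an identity scale in the Gemma formulation.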

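Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • eps (float) – Value added to the denominator for numerical stability. Defaults to 1e-06.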

forward()

forward(x)

Defines the computation performed by the module. For GemmaRMSNorm, this applies the normalization described above to the input x.

This method overrides the base module's forward(), which every subclass is required to implement.

Parameters:

  • x (Tensor) – The input tensor to normalize.

Returns:

The result of applying the module to the input.

Raises:

NotImplementedError – If the subclass does not override this method.

Return type:

Tensor
