IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

GemmaRMSNorm

class max.experimental.nn.norm.GemmaRMSNorm(dim, eps=1e-06)

Bases: RMSNorm

Gemma-style root mean square normalization.

Subclasses RMSNorm with two differences:

  • Scales by 1 + weight rather than weight.
  • Multiplies by the scale before casting back to the input dtype, instead of after.

The constructor signature is identical to that of RMSNorm. Used by the Gemma model family.
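The following pure-Python sketch illustrates the two differences listed above on a single 1-D input. It is not the MAX implementation and does not use the MAX API. The helper names (`to_half`, `rms_normalize`, `rms_norm_standard`, `rms_norm_gemma`) are hypothetical. The sketch makes three further assumptions:

  • float16 stands in for the input dtype.
  • Python's float64 stands in for the higher-precision intermediate.
  • Normalization runs over the whole (last) dimension.

```python
import math
import struct


def to_half(v: float) -> float:
    # Round a Python float to the nearest IEEE 754 float16 value.
    # Assumption: stands in for "cast back to the input dtype" with a float16 input.
    return struct.unpack("<e", struct.pack("<e", v))[0]


def rms_normalize(x: list[float], eps: float) -> list[float]:
    # x / sqrt(mean(x^2) + eps), computed in float64 (Python float) precision.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def rms_norm_standard(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # Standard ordering: cast the normalized values to float16 first,
    # then multiply by weight (the product is rounded to float16 again).
    normed = [to_half(n) for n in rms_normalize(x, eps)]
    return [to_half(n * w) for n, w in zip(normed, weight)]


def rms_norm_gemma(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # Gemma-style ordering: scale by (1 + weight) while still in full precision,
    # and cast to float16 only once, at the very end.
    normed = rms_normalize(x, eps)
    return [to_half(n * (1.0 + w)) for n, w in zip(normed, weight)]


x = [1.0, 2.0, 3.0, 4.0]
zeros = [0.0] * len(x)  # weights chosen to be exactly representable in float16
ones = [1.0] * len(x)

# A zero weight acts as the identity scale for the Gemma variant ...
print(rms_norm_gemma(x, zeros))
# ... but zeroes out the output of standard RMSNorm.
print(rms_norm_standard(x, zeros))  # → [0.0, 0.0, 0.0, 0.0]
```

Because of the `1 + weight` form, a Gemma weight of 0 behaves like a standard RMSNorm weight of 1. Applying the scale before the single final cast avoids one extra rounding step to the input dtype.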

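Parameters:

  • dim – Size of the normalized dimension, which is also the length of the learnable weight vector.
  • eps – Small constant added to the mean square before the square root, for numerical stability. Defaults to 1e-06.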

forward()

forward(x)

Returns x normalized using the Gemma-style RMS variant.
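For a single vector x of length d with weight w and ε = eps, the computation can be written roughly as below. This is a sketch: it assumes normalization over the last dimension, and this page does not specify the intermediate precision used before the cast.

```latex
% Gemma-style: scale by (1 + w_i), then cast once to dtype(x)
y_i = \operatorname{cast}_{\mathrm{dtype}(x)}\!\left( \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}} \,(1 + w_i) \right)

% Standard RMSNorm, for comparison: cast first, then scale by w_i
y_i = \operatorname{cast}_{\mathrm{dtype}(x)}\!\left( \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}} \right) w_i
```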

Parameters:

x (Tensor)

Return type:

Tensor