GemmaRMSNorm
class max.nn.norm.rms_norm.GemmaRMSNorm(dim, eps=1e-06)
Computes Root Mean Square (RMS) normalization of the input, using the Gemma variant of RMSNorm.
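For a single vector x of length d, the Gemma-style computation is commonly written as shown below. Two details are assumptions about this class, not something this page states: that d corresponds to the `dim` argument with the reduction taken over the last axis, and that `eps` sits inside the square root, as in common RMSNorm implementations.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon},
\qquad
y_i = \frac{x_i}{\mathrm{RMS}(x)}\,(1 + w_i), \quad i = 1, \dots, d
```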
Differences from standard RMSNorm:
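- The normalized input is scaled by (1 + w) rather than w: x * (1 + w) instead of x * w.
- The weight is applied before casting back to the original dtype: (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.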
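The two differences can be illustrated with a small NumPy reference sketch. This is not the MAX implementation and does not use the `max.nn` API. The helper names `rms_normalize`, `standard_rms_norm`, and `gemma_rms_norm` are hypothetical, and the sketch assumes normalization over the last axis with float32 as the working precision.

```python
import numpy as np


def rms_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Hypothetical helper: divide x by the RMS of its last axis, computed in float32."""
    x32 = x.astype(np.float32)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    return x32 / np.sqrt(mean_sq + eps)


def standard_rms_norm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Standard RMSNorm: cast back to the input dtype first, then scale by w.
    normed = rms_normalize(x, eps).astype(x.dtype)
    return normed * w.astype(x.dtype)


def gemma_rms_norm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Gemma variant: scale by (1 + w) while still in float32, then cast back.
    normed = rms_normalize(x, eps)
    return (normed * (1.0 + w.astype(np.float32))).astype(x.dtype)


x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16)
w = np.zeros(4, dtype=np.float16)
print(gemma_rms_norm(x, w))     # plain RMS-normalized values
print(standard_rms_norm(x, w))  # all zeros
```

With an all-zero weight, the Gemma form reduces to plain RMS normalization, while the standard form zeroes the output. Doing the scaling in float32 before the final cast also avoids an extra rounding step in low-precision dtypes such as float16.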
forward()
forward(x)
Applies Gemma-style RMS normalization to the input.