GemmaRMSNorm
class max.nn.module_v3.norm.GemmaRMSNorm(dim, eps=1e-06)
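Computes Root Mean Square (RMS) normalization on its inputs.
Differences from traditional RMSNorm:
- x * (1 + w) instead of x * w: the learned weight w acts as an offset from 1.
- (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w: the weight is applied before the result is cast back to the input's original dtype.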
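The two differences above can be contrasted with a minimal NumPy sketch. This is not the MAX implementation: the function names `rms_norm_standard` and `rms_norm_gemma` are hypothetical, and the sketch assumes normalization over the last axis, computed in float32, with the weight stored in float32.

```python
import numpy as np


def _normalize(x32, eps):
    # Divide by the root mean square over the last (normalized) axis.
    ms = np.mean(x32 * x32, axis=-1, keepdims=True)
    return x32 / np.sqrt(ms + eps)


def rms_norm_standard(x, w, eps=1e-6):
    """Traditional RMSNorm (hypothetical helper): cast back, then scale by w."""
    orig_dtype = x.dtype
    normed = _normalize(x.astype(np.float32), eps)
    # Cast to the original dtype first, then multiply by the weight.
    return normed.astype(orig_dtype) * w.astype(orig_dtype)


def rms_norm_gemma(x, w, eps=1e-6):
    """Gemma-style RMSNorm (hypothetical helper): scale by (1 + w), then cast back."""
    orig_dtype = x.dtype
    normed = _normalize(x.astype(np.float32), eps)
    # Multiply by (1 + w) in float32 before casting to the original dtype.
    return (normed * (1.0 + w.astype(np.float32))).astype(orig_dtype)


if __name__ == "__main__":
    x = np.array([[3.0, 4.0]], dtype=np.float16)
    w = np.zeros(2, dtype=np.float32)
    print(rms_norm_standard(x, w))  # all zeros: a zero weight erases the input
    print(rms_norm_gemma(x, w))     # unit scale: approximately x / rms(x)
```

With a zero weight, the traditional form produces all zeros, while the Gemma form leaves the normalized input unscaled, which is the practical effect of the (1 + w) offset.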
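Constructs a GemmaRMSNorm layer.
Parameters:
- dim: Size of the last (normalized) dimension of the input, which is also the length of the learned weight w.
- eps: Small constant added to the mean square for numerical stability. Defaults to 1e-06.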
forward()
forward(x)
Applies Gemma-style RMS normalization to the input.
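Written out as an equation, and assuming normalization runs over the last axis of size d = dim as in standard RMSNorm, the computation for each element i of that axis is:

```latex
y_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}\,\bigl(1 + w_i\bigr),
\qquad i = 1, \dots, d
```

Here \epsilon is the eps constructor argument and w is the learned weight of length dim; the result is then cast back to the input's original dtype.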