GemmaRMSNorm
class max.nn.norm.rms_norm.GemmaRMSNorm(dim, eps=1e-06)
Computes the Root Mean Square normalization on inputs.
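For reference, RMS normalization over a feature vector $x \in \mathbb{R}^d$ (here $d$ plays the role of `dim` and $\epsilon$ the role of `eps`) with a learned per-feature weight $w$ can be written as below. This is a sketch that assumes `eps` is added inside the square root, as in the common formulation, and that normalization runs over the last axis.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}
\qquad
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\, w_i
\qquad
\mathrm{GemmaRMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\,(1 + w_i)
```

With the $(1 + w)$ form, an all-zero weight leaves the normalized values unscaled, whereas the traditional form would zero them out.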
Differences from traditional RMSNorm:
- Scales by (1 + w) instead of w: x * (1 + w) rather than x * w.
- Casts back to the input dtype after applying the weight: (x * w).to(orig_dtype) rather than x.to(orig_dtype) * w.
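The two differences listed above can be illustrated with a small NumPy reference sketch. This is not the MAX implementation: the function names `_normalize`, `rms_norm_traditional`, and `gemma_rms_norm` are hypothetical, and the sketch assumes normalization over the last axis (of size `dim`), computation upcast to float32, and `eps` added inside the square root.

```python
import numpy as np


def _normalize(x: np.ndarray, eps: float) -> np.ndarray:
    # Hypothetical helper: RMS-normalize over the last axis, computed in float32.
    x32 = x.astype(np.float32)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    return x32 / np.sqrt(mean_sq + eps)


def rms_norm_traditional(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Traditional order (assumed): cast normalized x back to the input dtype,
    # then multiply by the weight w.
    normed = _normalize(x, eps).astype(x.dtype)
    return normed * w


def gemma_rms_norm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Gemma order (assumed): scale by (1 + w) while still in float32,
    # and only then cast the product back to the input dtype.
    normed = _normalize(x, eps)
    return (normed * (1.0 + w.astype(np.float32))).astype(x.dtype)


if __name__ == "__main__":
    x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16)
    w = np.zeros(4, dtype=np.float16)
    # A zero weight leaves Gemma's output equal to the plain normalized input...
    print(gemma_rms_norm(x, w))
    # ...but zeroes out the traditional RMSNorm output.
    print(rms_norm_traditional(x, w))
```

Because Gemma stores the scale as an offset from 1, a Gemma weight `w` behaves like a traditional weight of `w + 1`; the outputs differ only by where the rounding to the low-precision input dtype happens.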
forward(x)
Defines the computation performed by the module.
Users must override this method in their subclass to define the module’s computation.
Parameters:
- *args – Positional arguments for the computation.
- **kwargs – Keyword arguments for the computation.
- x (Tensor)

Returns:
The result of applying the module to the input.

Raises:
NotImplementedError – If the subclass does not override this method.

Return type: