Python module
rms_norm
Root Mean Square (RMS) normalization layers.
DistributedRMSNorm
class max.nn.norm.rms_norm.DistributedRMSNorm(*args, devices: list[max.graph.type.DeviceRef], **kwargs)
RMSNorm
class max.nn.norm.rms_norm.RMSNorm(dim: int, eps: float = 1e-06, weight_offset: float = 0.0)
Computes the Root Mean Square normalization on inputs.
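For reference, the standard RMSNorm formulation over the last dimension is given below. This is a sketch: placing eps inside the square root and adding weight_offset to the learned weight w are assumptions based on the common formulation and the parameter descriptions, not a transcription of the MAX kernel. Here d is dim, ε is eps, and δ is weight_offset.

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \, (w_i + \delta)
```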
Parameters:
- dim – Size of the last dimension of the expected input.
- eps – Value added to the denominator for numerical stability.
- weight_offset – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, this should be set to 1.0.
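To illustrate how eps and weight_offset enter the computation, here is a minimal pure-Python reference sketch for a single 1-D vector of length dim. The function name rms_norm_reference is hypothetical and this is not the MAX implementation; it assumes the standard RMSNorm formula with weight_offset added to the weights at runtime. With weight_offset=1.0 and zero-initialized weights (the Gemma-style convention), the scale factor is exactly 1.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6, weight_offset=0.0):
    """Hypothetical reference RMSNorm over a 1-D list of length dim.

    Assumes the standard formula x / sqrt(mean(x^2) + eps) * (weight + offset);
    this is an illustration, not the MAX kernel.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must both have length dim")
    # Mean of squares over the last (here, only) dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # eps is added inside the square root to avoid division by zero.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each normalized element by its weight plus the runtime offset.
    return [v * inv_rms * (w + weight_offset) for v, w in zip(x, weight)]


# Standard RMSNorm with unit weights versus Gemma-style zero weights + offset 1.0.
standard = rms_norm_reference([3.0, 4.0], [1.0, 1.0], eps=0.0)
gemma_style = rms_norm_reference([3.0, 4.0], [0.0, 0.0], eps=0.0, weight_offset=1.0)
print(standard, gemma_style)
```

For batched inputs, the same function would be applied independently to each row along the last dimension.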
RMSNormV1
class max.nn.norm.rms_norm.RMSNormV1(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, eps: float = 1e-06, weight_offset: float = 0.0)
Computes the Root Mean Square normalization on inputs.
Deprecated: Use RMSNorm instead.
eps
eps: float = 1e-06
weight
weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
weight_offset
weight_offset: float = 0.0