Python module
rms_norm
Normalization layer.
DistributedRMSNorm
class max.nn.norm.rms_norm.DistributedRMSNorm(*args, devices, **kwargs)
Parameters:
- devices (list[DeviceRef])
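This page does not say how DistributedRMSNorm divides work between devices. The sketch below assumes each device holds a replica of the same weight and applies an ordinary RMSNorm to its own rows of the input. Under that assumption the combined result matches a single-device RMSNorm, because RMSNorm normalizes each row over its last dimension independently of every other row. The helpers rms_norm_rows and sharded_rms_norm are hypothetical, written in pure Python, and are not MAX APIs.

```python
import math


def rms_norm_rows(rows, weight, eps=1e-6):
    """Row-wise RMSNorm: each row is normalized over its last dimension."""
    out = []
    for row in rows:
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v / rms * w for v, w in zip(row, weight)])
    return out


def sharded_rms_norm(rows, weight, num_devices, eps=1e-6):
    """Hypothetical: split rows into shards, one per device, each using a replicated weight."""
    shard_size = math.ceil(len(rows) / num_devices)
    shards = [rows[i:i + shard_size] for i in range(0, len(rows), shard_size)]
    out = []
    for shard in shards:  # in a real setup, each shard would run on its own device
        out.extend(rms_norm_rows(shard, weight, eps))
    return out
```

Because no row's output depends on any other row, concatenating the per-shard results reproduces the unsharded output exactly.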
RMSNorm
class max.nn.norm.rms_norm.RMSNorm(dim, dtype, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)
Computes the Root Mean Square normalization on inputs.
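For an input row x of size d (the dim parameter) and a learned weight w, the standard RMSNorm computation, extended with the constant offset δ that weight_offset describes, can be written as below. Here ε is eps. Placing ε inside the square root follows common Llama- and Gemma-style implementations and is an assumption about this layer, not a transcription of its kernel.

```latex
y_i = \frac{x_i}{\sqrt{\dfrac{1}{d}\sum_{j=1}^{d} x_j^{2} + \varepsilon}}\,\bigl(w_i + \delta\bigr), \qquad i = 1, \dots, d
```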
Parameters:
- dim (int) – Size of the last dimension of the expected input.
- dtype (DType)
- eps (float) – Small value added to the denominator for numerical stability.
- weight_offset (float) – Constant offset added to the learned weights at runtime. Set this to 1.0 for Gemma-style RMSNorm.
- multiply_before_cast (bool) – If True, the normalized inputs are multiplied by the learned weights before being cast back to the input dtype (Gemma3-style). If False, the normalized inputs are cast to the input dtype first and then multiplied by the learned weights (Llama-style).
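The following pure-Python sketch shows how eps, weight_offset, and multiply_before_cast interact. It uses the half-precision "e" format of Python's struct module to stand in for an fp16 input dtype, and it assumes the normalization itself runs at higher precision (here, Python's double-precision float). The helpers to_fp16 and rms_norm are hypothetical and are not MAX APIs.

```python
import math
import struct


def to_fp16(v: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision (fp16) value."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def rms_norm(x, weight, eps=1e-6, weight_offset=0.0, multiply_before_cast=True):
    """Hypothetical RMSNorm over one row x, with fp16 assumed as the input dtype."""
    # Normalize in double precision, standing in for a higher-precision accumulator.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed = [v / rms for v in x]
    # Effective per-element scale: learned weight plus the constant offset.
    scale = [w + weight_offset for w in weight]
    if multiply_before_cast:
        # Gemma3-style: multiply at high precision, then round to fp16 once.
        return [to_fp16(n * s) for n, s in zip(normed, scale)]
    # Llama-style: round the normalized values to fp16 first, then multiply
    # and round the product to fp16 again.
    return [to_fp16(to_fp16(n) * s) for n, s in zip(normed, scale)]
```

With all-ones weights and weight_offset=0.0 the two orderings give identical results. With other weights the Llama-style path rounds twice, so its output may differ from the Gemma3-style path in the last bit.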
RMSNormV1
class max.nn.norm.rms_norm.RMSNormV1(weight, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)
Computes the Root Mean Square normalization on inputs.
Deprecated: Use RMSNorm instead.
Parameters:
- eps: float = 1e-06
- multiply_before_cast: bool = True
- weight: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
- weight_offset: float = 0.0
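Unlike RMSNorm, which takes dim and dtype, RMSNormV1 receives its weight directly. The hypothetical pure-Python class below mirrors the four fields listed above to show what weight_offset does. A Gemma-style weight stored as a delta from 1.0, used with weight_offset=1.0, behaves exactly like a weight of (delta + 1.0) used with the default offset of 0.0. RMSNormV1Sketch is an illustration only and is not the MAX class.

```python
import math
from dataclasses import dataclass


@dataclass
class RMSNormV1Sketch:
    """Hypothetical stand-in mirroring the RMSNormV1 fields; operates on one row."""

    weight: list
    eps: float = 1e-6
    weight_offset: float = 0.0
    multiply_before_cast: bool = True  # kept for parity; this sketch does no dtype casting

    def __call__(self, x):
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        # The effective scale is the stored weight plus the constant offset.
        return [v / rms * (w + self.weight_offset) for v, w in zip(x, self.weight)]
```

For example, RMSNormV1Sketch(weight=[0.0] * 4, weight_offset=1.0) and RMSNormV1Sketch(weight=[1.0] * 4) produce the same output for any 4-element row.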