
Python module

rms_norm

Normalization layer.

DistributedRMSNorm

class max.nn.norm.rms_norm.DistributedRMSNorm(*args, devices, **kwargs)

Parameters:

devices (list[DeviceRef])
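A construction sketch follows, assuming DistributedRMSNorm forwards its positional and keyword arguments to RMSNorm and only adds the devices list; the import paths and the DeviceRef.GPU constructor below are assumptions based on the type names on this page, not confirmed API.

    # Hypothetical construction sketch; assumes *args/**kwargs are forwarded
    # to RMSNorm and the learned weight is made available on each device in
    # `devices`.
    from max.dtype import DType
    from max.graph import DeviceRef
    from max.nn.norm.rms_norm import DistributedRMSNorm

    norm = DistributedRMSNorm(
        dim=4096,                  # size of the last input dimension, as in RMSNorm
        dtype=DType.bfloat16,
        devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    )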

RMSNorm

class max.nn.norm.rms_norm.RMSNorm(dim, dtype, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)

Computes the Root Mean Square normalization on inputs.

Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • eps (float) – Value added to the denominator for numerical stability.
  • weight_offset (float) – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, this should be set to 1.0.
  • multiply_before_cast (bool) – True if the inputs are multiplied by the learned weights before casting to the input type (Gemma3-style). False if the inputs are cast to the input type first and then multiplied by the learned weights (Llama-style).
  • dtype (DType)
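The effect of eps, weight_offset, and multiply_before_cast can be illustrated with a small NumPy reference sketch. This is illustrative only: rms_norm_reference is a made-up helper, and the MAX kernel's exact casting and accumulation behavior may differ.

    # NumPy approximation of RMS normalization over the last dimension.
    import numpy as np

    def rms_norm_reference(x, weight, eps=1e-6, weight_offset=0.0,
                           multiply_before_cast=True):
        in_dtype = x.dtype
        x32 = x.astype(np.float32)
        # Normalize by the root mean square of the last dimension.
        rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
        normed = x32 / rms
        # weight_offset is added to the learned weights at runtime.
        scale = weight.astype(np.float32) + weight_offset
        if multiply_before_cast:
            # Gemma3-style: scale in float32, then cast back to the input dtype.
            return (normed * scale).astype(in_dtype)
        # Llama-style: cast back to the input dtype first, then scale.
        return normed.astype(in_dtype) * scale.astype(in_dtype)

    x = np.random.randn(2, 8).astype(np.float16)
    w = np.ones(8, dtype=np.float32)
    y = rms_norm_reference(x, w)   # y.shape == (2, 8), y.dtype == float16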

RMSNormV1

class max.nn.norm.rms_norm.RMSNormV1(weight, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)

Computes the Root Mean Square normalization on inputs.

Deprecated: Use RMSNorm instead.

Parameters:

  • weight (Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray)
  • eps (float)
  • weight_offset (float)
  • multiply_before_cast (bool)

eps

eps: float = 1e-06

multiply_before_cast

multiply_before_cast: bool = True

weight

weight: Value[TensorType] | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

weight_offset

weight_offset: float = 0.0