
Python class

RMSNorm

class max.nn.RMSNorm(dim, dtype, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)

Bases: Module, Shardable

Computes root mean square (RMS) normalization over its inputs.

When called, RMSNorm normalizes the input using only the root mean square statistic, without centering by the mean. It accepts a TensorValue of shape (..., dim) and returns a normalized TensorValue of the same shape.

This is more efficient than LayerNorm while maintaining comparable performance in transformer models. For more information, see Root Mean Square Layer Normalization.

Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • dtype (DType) – Data type of the learned weights.
  • eps (float) – Value added to the denominator for numerical stability.
  • weight_offset (float) – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, set this to 1.0.
  • multiply_before_cast (bool) – If True, multiply the inputs by the learned weights before casting to the input type (Gemma3-style); if False, cast the inputs to the input type first, then multiply by the learned weights (Llama-style).
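
The computation these parameters control can be sketched in NumPy. This is a hand-rolled illustration of the math, not the MAX kernel; the function name and the float32 accumulation are assumptions:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6, weight_offset=0.0, multiply_before_cast=True):
    """Illustrative RMS normalization over the last axis (not MAX code)."""
    in_dtype = x.dtype
    # Accumulate in float32 for stability (an assumption of this sketch).
    x32 = x.astype(np.float32)
    # Normalize by the root mean square of the last dimension;
    # unlike LayerNorm, there is no mean-centering.
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    normed = x32 / rms
    # weight_offset is added to the learned weights at runtime
    # (e.g. 1.0 for Gemma-style weights stored as deltas around 1).
    scale = weight.astype(np.float32) + weight_offset
    if multiply_before_cast:
        # Gemma3-style: scale first, then cast back to the input dtype.
        return (normed * scale).astype(in_dtype)
    # Llama-style: cast back to the input dtype first, then scale.
    return normed.astype(in_dtype) * scale.astype(in_dtype)

x = np.random.randn(2, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
y = rms_norm(x, w)
print(y.shape)  # same shape as the input: (2, 8)
```

With unit weights, the output of each row has a root mean square of approximately 1, which is the invariant the layer enforces.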

shard()

shard(devices)

Creates sharded views of this RMSNorm across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded RMSNorm instances, one for each device.

Return type:

Sequence[RMSNorm]
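
A toy sketch of what "sharded views" can mean for a norm layer. Everything here is hypothetical (ToyRMSNorm, string device names standing in for DeviceRef), and the replicated strategy shown, where each device gets a full copy, is only one common choice; the actual MAX behavior is governed by the configured sharding_strategy:

```python
from dataclasses import dataclass, replace
from typing import Iterable, List, Optional

@dataclass
class ToyRMSNorm:
    """Hypothetical stand-in for RMSNorm, to illustrate shard()."""
    dim: int
    eps: float = 1e-6
    device: Optional[str] = None  # stand-in for DeviceRef

    def shard(self, devices: Iterable[str]) -> List["ToyRMSNorm"]:
        # One view per device. Norm weights are small, so replicating
        # the full weight onto every device is a common strategy.
        return [replace(self, device=d) for d in devices]

norm = ToyRMSNorm(dim=4096)
shards = norm.shard(["gpu:0", "gpu:1"])
print([s.device for s in shards])  # one shard per device
```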

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the RMSNorm sharding strategy.