Python module
norm
ConstantLayerNorm
class max.nn.norm.ConstantLayerNorm(dims, device, dtype, eps=1e-05)
Layer normalization block with constant gamma and beta values.
beta
beta: npt.NDArray[np.floating[Any]]
device
device: DeviceRef
dtype
dtype: DType
eps
eps: float = 1e-05
gamma
gamma: npt.NDArray[np.floating[Any]]
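A minimal construction sketch based on the signature above. It assumes dims is the size of the normalized dimension and that DType and DeviceRef are importable from max.dtype and max.graph; the concrete values are illustrative, not defaults from this reference.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.norm import ConstantLayerNorm

# Normalize a 768-wide hidden dimension with fixed (non-learned) gamma/beta.
# dims=768 and the CPU device are illustrative choices, not defaults.
norm = ConstantLayerNorm(
    dims=768,                # assumed: size of the normalized dimension
    device=DeviceRef.CPU(),  # device the constant gamma/beta live on
    dtype=DType.float32,     # dtype of the gamma/beta constants
    eps=1e-5,                # documented default for numerical stability
)
```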
GroupNorm
class max.nn.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=gpu:0)
Group normalization block.
Divides channels into groups and computes normalization stats per group. Follows the implementation pattern from PyTorch’s group_norm.
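A construction sketch using the signature above. The group and channel counts are illustrative, and the DeviceRef import path is an assumption.

```python
from max.graph import DeviceRef
from max.nn.norm import GroupNorm

# Split 64 channels into 8 groups; stats are computed per group, as in
# PyTorch's group_norm, so num_channels should be divisible by num_groups.
gn = GroupNorm(
    num_groups=8,
    num_channels=64,
    eps=1e-5,                # documented default
    affine=True,             # keep the learned per-channel scale and shift
    device=DeviceRef.GPU(),  # documented default is the first GPU (gpu:0)
)
```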
LayerNorm
class max.nn.norm.LayerNorm(dims, devices, dtype, eps=1e-05, use_bias=True)
Layer normalization block.
shard()
shard(devices)
Creates sharded views of this LayerNorm across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the LayerNorm sharding strategy.
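A sketch of constructing a LayerNorm and creating sharded views with shard(). It assumes dims is an int, devices takes a list of DeviceRef values, and DeviceRef.GPU(id) selects a specific GPU; depending on the configuration, a sharding strategy may need to be assigned to sharding_strategy before shard() is called.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.norm import LayerNorm

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]  # assumed DeviceRef.GPU(id) form

ln = LayerNorm(
    dims=4096,             # assumed: size of the normalized (last) dimension
    devices=devices,       # devices the weights are placed on
    dtype=DType.bfloat16,
    eps=1e-5,              # documented default
    use_bias=True,         # include the learned bias (beta) term
)

# shard() returns one view of this LayerNorm per device. A sharding strategy
# may need to be set on ln.sharding_strategy first (it is None until assigned).
shards = ln.shard(devices)
```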
RMSNorm
class max.nn.norm.RMSNorm(dim, dtype, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)
Computes the Root Mean Square normalization on inputs.
Parameters:
- dim (int) – Size of the last dimension of the expected input.
- eps (float) – Value added to the denominator for numerical stability.
- weight_offset (float) – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, this should be set to 1.0.
- multiply_before_cast (bool) – If True, multiply by the learned weights before casting the result back to the input dtype (Gemma3-style). If False, cast back to the input dtype first, then multiply by the learned weights (Llama-style).
- dtype (DType)
shard()
shard(devices)
Creates sharded views of this RMSNorm across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the RMSNorm sharding strategy.
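A sketch contrasting the two documented weight conventions: Llama-style (weight_offset=0.0, multiply after the cast) and Gemma-style (weight_offset=1.0, multiply before the cast). Dimensions and dtypes are illustrative.

```python
from max.dtype import DType
from max.nn.norm import RMSNorm

# Llama-style: no weight offset; cast back to the input dtype, then multiply.
llama_norm = RMSNorm(
    dim=4096,
    dtype=DType.bfloat16,
    eps=1e-6,                    # documented default
    weight_offset=0.0,
    multiply_before_cast=False,
)

# Gemma-style: a constant 1.0 is added to the learned weights at runtime,
# and the multiply happens before casting back to the input dtype.
gemma_norm = RMSNorm(
    dim=2048,
    dtype=DType.bfloat16,
    eps=1e-6,
    weight_offset=1.0,
    multiply_before_cast=True,
)
```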