
LayerNorm


class max.nn.LayerNorm(dims, devices, dtype, eps=1e-05, use_bias=True)


Bases: Module, Shardable

Layer normalization over the last dimension.

When called, LayerNorm accepts a TensorValue of shape (..., dims) and returns a normalized TensorValue of the same shape. More specifically, it normalizes inputs across the feature dimension by computing the mean and variance for each sample independently.

The computation is: $\text{output} = \gamma \cdot \frac{x - \mu}{\sigma} + \beta$, where $\mu$ is the mean, $\sigma = \sqrt{\text{var}(x) + \epsilon}$ is the standard deviation, and $\gamma, \beta$ are learned affine parameters.
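The formula above can be sketched in plain NumPy. This is an illustration of the math only, not the MAX implementation; the parameter names mirror the symbols in the formula:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Reference layer normalization over the last dimension.

    Mean and variance are computed per sample (per row), so each
    sample is normalized independently of the others.
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

x = np.random.randn(2, 4).astype(np.float32)
out = layer_norm(
    x,
    gamma=np.ones(4, dtype=np.float32),   # scale (no-op here)
    beta=np.zeros(4, dtype=np.float32),   # bias (no-op here)
)
# Each row of `out` has approximately zero mean and unit variance.
```

With `gamma` set to ones and `beta` to zeros, the output is the purely normalized input; in a trained model these parameters rescale and shift each feature.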

Parameters:

  • dims (int) – The size of the feature dimension to normalize over.
  • devices (Sequence[DeviceRef]) – The target DeviceRef instances for computation.
  • dtype (DType) – The DType for the layer.
  • eps (float) – A small value added to the denominator for numerical stability.
  • use_bias (bool) – Whether to include a learnable bias term (beta).

shard()

shard(devices)


Creates sharded views of this LayerNorm across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LayerNorm instances, one for each device.

Return type:

Sequence[LayerNorm]

sharding_strategy

property sharding_strategy: ShardingStrategy | None


Gets the LayerNorm sharding strategy, or None if no strategy has been set.