LayerNorm
class max.nn.LayerNorm(dims, devices, dtype, eps=1e-05, use_bias=True)
Layer normalization over the last dimension.
When called, LayerNorm accepts a TensorValue of shape
(..., dims) and returns a normalized TensorValue of
the same shape. More specifically, it normalizes inputs across the feature
dimension by computing the mean and variance for each sample independently.
The computation is: y = (x − E[x]) / √(Var[x] + ε) · γ + β, where E[x] is the mean, Var[x] is the variance, ε is the eps constant, and γ and β are learned affine parameters.
Parameters:

- dims (int) – The size of the feature dimension to normalize over.
- devices (Sequence[DeviceRef]) – The target DeviceRef instances for computation.
- dtype (DType) – The DType for the layer.
- eps (float) – A small value added to the denominator for numerical stability.
- use_bias (bool) – Whether to include a learnable bias term (beta).
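The normalization described above can be sketched in plain NumPy. This is an illustrative reference for the math only, not the MAX implementation; the function name `layer_norm` and the parameter names `gamma` and `beta` are assumptions for the sketch:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Reference layer normalization over the last dimension of shape (..., dims)."""
    mean = x.mean(axis=-1, keepdims=True)      # per-sample mean
    var = x.var(axis=-1, keepdims=True)        # per-sample variance
    normed = (x - mean) / np.sqrt(var + eps)   # zero mean, ~unit variance
    return normed * gamma + beta               # learned affine parameters

x = np.random.randn(2, 4).astype(np.float32)
dims = x.shape[-1]
# With gamma = 1 and beta = 0, each row of the output has
# approximately zero mean and unit variance.
out = layer_norm(x, gamma=np.ones(dims), beta=np.zeros(dims))
```

Because mean and variance are computed per sample (over the last axis only), no statistics are shared across the batch, unlike batch normalization.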
shard()
shard(devices)
Creates sharded views of this LayerNorm across multiple devices, according to the configured sharding strategy.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
The sharding strategy for this LayerNorm, or None if no strategy has been set.