Python module
layer_norm
Layer Normalization layer.
ConstantLayerNorm
class max.nn.norm.layer_norm.ConstantLayerNorm(dims, device, dtype, eps=1e-05)
Layer normalization block whose gamma (scale) and beta (shift) are fixed NumPy arrays rather than learned weights.
beta
beta: npt.NDArray[np.floating[Any]]
device
device: DeviceRef
dtype
dtype: DType
eps
eps: float = 1e-05
gamma
gamma: npt.NDArray[np.floating[Any]]
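As a rough illustration of what a constant-parameter layer norm computes, here is a minimal pure-Python sketch. It assumes the standard layer normalization formula, normalizing over the last dimension with the biased variance: y = (x - mean) / sqrt(var + eps) * gamma + beta. The function name `constant_layer_norm` is hypothetical and is not part of MAX. The sketch uses plain lists where the class stores `gamma` and `beta` as NumPy arrays, and it reuses the documented default `eps=1e-05`.

```python
import math
from typing import Sequence


def constant_layer_norm(
    x: Sequence[float],
    gamma: Sequence[float],
    beta: Sequence[float],
    eps: float = 1e-5,
) -> list:
    """Hypothetical sketch: normalize one vector, then apply fixed gamma/beta."""
    n = len(x)
    if not (len(gamma) == len(beta) == n):
        raise ValueError("gamma and beta must match the normalized dimension")
    mean = sum(x) / n
    # Biased (population) variance, as in the usual layer norm definition.
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    # gamma and beta are constants here: they scale and shift but are never updated.
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]
```

With `gamma` set to all ones and `beta` set to all zeros, the output has mean zero and (up to `eps`) unit variance. Any other constant values then rescale and shift that normalized vector.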
LayerNorm
class max.nn.norm.layer_norm.LayerNorm(dims, devices, dtype, eps=1e-05, use_bias=True)
Layer normalization block.
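The signature includes a `use_bias=True` flag. A reasonable reading is that it controls whether the additive beta term exists, but that is an assumption and is not stated on this page. The sketch below models that reading with a hypothetical `SketchLayerNorm` class in pure Python. This class is not the MAX `LayerNorm`. It also assumes the common initialization of gamma to ones and beta to zeros.

```python
import math
from typing import Optional


class SketchLayerNorm:
    """Hypothetical stand-in for a layer norm with an optional bias (not the MAX class)."""

    def __init__(self, dims: int, eps: float = 1e-5, use_bias: bool = True) -> None:
        self.eps = eps
        # Assumed initialization: scale of ones.
        self.gamma = [1.0] * dims
        # Assumption: use_bias=False means there is no beta parameter at all.
        self.beta: Optional[list] = [0.0] * dims if use_bias else None

    def __call__(self, x: list) -> list:
        n = len(x)
        mean = sum(x) / n
        var = sum((v - mean) ** 2 for v in x) / n
        inv_std = 1.0 / math.sqrt(var + self.eps)
        y = [(v - mean) * inv_std * g for v, g in zip(x, self.gamma)]
        # Only add the shift when a bias parameter was created.
        if self.beta is not None:
            y = [v + b for v, b in zip(y, self.beta)]
        return y
```

Dropping the bias removes one parameter vector and one addition per element. The normalization itself is unchanged.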
shard()
shard(devices)
Creates sharded views of this LayerNorm across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Gets the LayerNorm sharding strategy, or None if no strategy is set.
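This page does not list which `ShardingStrategy` values `LayerNorm` supports. The sketch below shows one common option for normalization layers, which is an assumption here and not taken from MAX. Because layer norm statistics are computed independently for each row over the last dimension, you can replicate gamma and beta on every device and split the rows of a batch across devices without changing the result. The function names and device strings are hypothetical, and "devices" are simulated in a single Python process.

```python
import math


def layer_norm_row(row, gamma, beta, eps=1e-5):
    """Normalize a single row over its last dimension."""
    n = len(row)
    mean = sum(row) / n
    var = sum((v - mean) ** 2 for v in row) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(row, gamma, beta)]


def sharded_layer_norm(batch, gamma, beta, devices, eps=1e-5):
    """Hypothetical: replicate gamma/beta per device, split rows round-robin."""
    # Every simulated device holds a full copy of the parameters (replication).
    replicas = {d: (list(gamma), list(beta)) for d in devices}
    # Assign row indices to devices round-robin.
    assignments = {d: [] for d in devices}
    for i in range(len(batch)):
        assignments[devices[i % len(devices)]].append(i)
    out = [None] * len(batch)
    for d, rows in assignments.items():
        g, b = replicas[d]
        for i in rows:
            # No cross-device communication is needed: each row is self-contained.
            out[i] = layer_norm_row(batch[i], g, b, eps)
    return out
```

Because no row's statistics depend on another row, the sharded output matches the unsharded output exactly. That independence is what makes replicated parameters a natural fit for this kind of layer.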