Python module

layer_norm

Layer Normalization layer.

ConstantLayerNorm

class max.nn.norm.layer_norm.ConstantLayerNorm(dims, device, dtype, eps=1e-05)

Layer normalization block with constant gamma and beta values.

Parameters:

dims

device (DeviceRef)

dtype (DType)

eps (float) – Defaults to 1e-05.

beta

beta: npt.NDArray[np.floating[Any]]

device

device: DeviceRef

dtype

dtype: DType

eps

eps: float = 1e-05

gamma

gamma: npt.NDArray[np.floating[Any]]
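
For readers unfamiliar with the operation, the standard layer normalization computation can be sketched in plain Python as below. This is an illustration of the textbook formula only, not MAX's implementation: the function name `layer_norm` and the list-based inputs are hypothetical, and it assumes `gamma` is the per-feature scale, `beta` the per-feature shift, and `eps` the small constant added to the variance for numerical stability.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-05):
    """Normalize one feature vector, then scale by gamma and shift by beta.

    Hypothetical reference helper; not part of the max.nn API.
    """
    n = len(x)
    mean = sum(x) / n
    # Biased (population) variance over the feature dimension.
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0] * len(x)  # constant scale, one value per feature
beta = [0.0] * len(x)   # constant shift, one value per feature
y = layer_norm(x, gamma, beta)
```

With a scale of one and a shift of zero, the output has approximately zero mean and unit variance over the feature dimension; other `gamma` and `beta` values rescale and recenter that result.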

LayerNorm

class max.nn.norm.layer_norm.LayerNorm(dims, devices, dtype, eps=1e-05, use_bias=True)

Layer normalization block.

Parameters:

dims

devices

dtype

eps – Defaults to 1e-05.

use_bias – Defaults to True.

shard()

shard(devices)

Creates sharded views of this LayerNorm across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LayerNorm instances, one for each device.

Return type:

Sequence[LayerNorm]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the LayerNorm sharding strategy.