Python class
GroupNorm
class max.experimental.nn.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
Bases: Module
Group normalization over the channel axis of the input.
The input is expected to have shape (batch, channels, ...) where
... is any number of trailing axes (typically spatial dimensions
for convolutional features). The channel axis is split into
num_groups groups and each group is normalized independently.
This is useful when the batch is too small for batch normalization to
produce stable statistics.
For example:
```python
from max.dtype import DType
from max.experimental.nn.norm import GroupNorm
from max.experimental.realization_context import (
    GraphRealizationContext,
    realization_context,
)
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, Graph, TensorType

# One input: a symbolic batch of 128-channel, 32x32 feature maps on the GPU.
graph = Graph(
    "gn",
    input_types=[
        TensorType(DType.float32, ("batch", 128, 32, 32), DeviceRef.GPU()),
    ],
)
ctx = GraphRealizationContext(graph)
with realization_context(ctx), ctx:
    x = Tensor.from_graph_value(graph.inputs[0])
    # 128 channels split into 32 groups of 4 channels each.
    norm = GroupNorm(num_groups=32, num_channels=128)
    y = norm(x)
    graph.output(y)
```
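The per-group computation can be sketched in NumPy as a reference for what the layer computes. This is not MAX's implementation. The function name `group_norm_reference` is hypothetical, and the sketch assumes the common formulation: each (sample, group) pair uses the biased (population) variance, and `eps` is added inside the square root. The learned scale and bias are left out here.

```python
import numpy as np


def group_norm_reference(x: np.ndarray, num_groups: int, eps: float = 1e-5) -> np.ndarray:
    """Hypothetical NumPy reference for group normalization over axis 1.

    Assumes biased variance and eps added inside the square root; the affine
    (weight/bias) step is omitted.
    """
    batch, channels = x.shape[0], x.shape[1]
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    # Split the channel axis: (batch, num_groups, channels // num_groups, ...).
    grouped = x.reshape(batch, num_groups, channels // num_groups, *x.shape[2:])
    # Each (sample, group) pair gets its own statistics, computed over the
    # group's channels and every trailing axis.
    axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)  # ddof=0: biased variance
    normalized = (grouped - mean) / np.sqrt(var + eps)
    # Restore the original (batch, channels, ...) layout.
    return normalized.reshape(x.shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))
y = group_norm_reference(x, num_groups=4)
# Each group (2 channels x 4 x 4 positions) now has mean ~0 and variance ~1.
print(y.reshape(2, 4, -1).mean(axis=-1))
```

Batch statistics never enter the computation. Each sample is normalized on its own, which is why the result does not depend on batch size.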
Parameters:
- num_groups (int): The number of groups to split the channel axis into. Must divide num_channels evenly.
- num_channels (int): The size of the channel axis of the input (axis 1).
- eps (float): A small positive constant added to the variance for numerical stability. Defaults to 1e-5.
- affine (bool): Whether to learn a per-channel scale and bias. When False, no parameters are created. Defaults to True.
Raises:
- ValueError: If num_channels is not divisible by num_groups.
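The divisibility requirement and the grouping it produces can be sketched in plain Python. This sketch assumes the common convention that each group is a contiguous block of `num_channels // num_groups` channels. `channel_groups` is a hypothetical helper, not part of MAX.

```python
def channel_groups(num_groups: int, num_channels: int) -> list[range]:
    """Hypothetical helper: the contiguous channel indices in each group."""
    if num_channels % num_groups != 0:
        # Mirrors the documented ValueError for a non-divisible channel count.
        raise ValueError(
            f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
        )
    size = num_channels // num_groups
    return [range(g * size, (g + 1) * size) for g in range(num_groups)]


# With the example's settings, each of the 32 groups holds 4 channels.
groups = channel_groups(num_groups=32, num_channels=128)
print(len(groups), list(groups[0]))  # 32 [0, 1, 2, 3]
```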
bias
The learned per-channel bias of shape [num_channels], or
None when affine is False.
eps
eps: float
The variance epsilon used for numerical stability.
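A group whose values are all identical has zero variance. Adding `eps` to that variance keeps the division well defined. The NumPy sketch below assumes the usual `(x - mean) / sqrt(var + eps)` formula and shows such a constant group.

```python
import numpy as np

eps = 1e-5  # the documented default
# A group whose values are all identical has zero variance.
group = np.full(16, 3.0, dtype=np.float32)

mean = group.mean()
var = group.var()  # exactly 0.0 for a constant group
# Assumed formula: (x - mean) / sqrt(var + eps). Without eps this would be
# 0 / 0, which yields nan; with eps every element normalizes to 0.
normalized = (group - mean) / np.sqrt(var + eps)
print(normalized)
```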
forward()
forward(x)
Returns x normalized within each channel group.
Parameters:
- x (Tensor): The input tensor of shape (batch, channels, ...). The size of the channel axis must equal num_channels.
Returns:
A tensor with the same shape and dtype as x.
Raises:
- ValueError: If x has fewer than 2 dimensions, or if its channel axis does not match num_channels.
Return type:
Tensor
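The two `ValueError` conditions documented for `forward()` amount to a shape check like the one below. This plain-Python sketch works on a concrete shape tuple. `check_group_norm_input` is hypothetical, and the real layer may perform the check differently, for example on symbolic dimensions such as `"batch"`.

```python
def check_group_norm_input(shape: tuple, num_channels: int) -> None:
    """Hypothetical check mirroring the documented forward() errors."""
    if len(shape) < 2:
        raise ValueError(f"expected at least 2 dimensions, got shape {shape}")
    if shape[1] != num_channels:
        raise ValueError(f"expected {num_channels} channels on axis 1, got {shape[1]}")


check_group_norm_input((8, 128, 32, 32), num_channels=128)  # OK: (batch, C, H, W)
check_group_norm_input((8, 128), num_channels=128)          # OK: no trailing axes
for bad in [(128,), (8, 64, 32, 32)]:
    try:
        check_group_norm_input(bad, num_channels=128)
    except ValueError as err:
        print(err)
```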
num_channels
num_channels: int
The size of the channel axis of the input.
num_groups
num_groups: int
The number of groups the channel axis is split into.
weight
The learned per-channel scale of shape [num_channels], or
None when affine is False.
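When affine is True, the per-channel scale and bias are applied to the normalized values. The NumPy sketch below assumes the standard `y = x_hat * weight + bias` formulation. The `[num_channels]` parameters are reshaped so they broadcast along axis 1. All variable names are illustrative.

```python
import numpy as np

num_channels = 4
# Stand-in for the normalized input, shape (batch, channels, H, W).
x_hat = np.ones((2, num_channels, 3, 3), dtype=np.float32)
weight = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # shape [num_channels]
bias = np.array([0.0, 0.5, -1.0, 2.0], dtype=np.float32)   # shape [num_channels]

# Reshape the [num_channels] parameters to (1, num_channels, 1, 1) so they
# broadcast over the batch and spatial axes.
param_shape = (1, num_channels) + (1,) * (x_hat.ndim - 2)
y = x_hat * weight.reshape(param_shape) + bias.reshape(param_shape)

# Since x_hat is all ones, every position in channel c equals weight[c] + bias[c].
print(y[0, :, 0, 0])
```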