Python class

GroupNorm

class max.experimental.nn.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)

Bases: Module

Group normalization over the channel axis of the input.

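The input is expected to have shape (batch, channels, ...), where ... is any number of trailing axes (typically the spatial dimensions of convolutional features). The channel axis is split into num_groups groups of num_channels / num_groups channels each, and each group is normalized independently for every sample in the batch. Because the mean and variance are computed per sample rather than across the batch, group normalization remains stable when the batch is too small for batch normalization to estimate reliable statistics.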
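To make "normalized within each group" concrete, the standard group normalization formulas are sketched below. They follow the usual textbook definition rather than the MAX source, and they assume the biased (population) variance. For sample n, S_{n,g} is the set of (channel, trailing position) pairs in group g, g(c) is the group that contains channel c, epsilon is eps, and gamma and beta are the per-channel weight and bias.

```latex
\begin{aligned}
\mu_{n,g} &= \frac{1}{|S_{n,g}|} \sum_{(c,s) \in S_{n,g}} x_{n,c,s}, \\
\sigma^2_{n,g} &= \frac{1}{|S_{n,g}|} \sum_{(c,s) \in S_{n,g}} \left(x_{n,c,s} - \mu_{n,g}\right)^2, \\
\hat{x}_{n,c,s} &= \frac{x_{n,c,s} - \mu_{n,g(c)}}{\sqrt{\sigma^2_{n,g(c)} + \epsilon}}, \\
y_{n,c,s} &= \gamma_c \, \hat{x}_{n,c,s} + \beta_c .
\end{aligned}
```

Each S_{n,g} holds num_channels / num_groups channels times the product of the trailing axis sizes. With affine=False the final scale and shift are skipped, so y is simply the normalized value.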

For example, the following builds a graph that applies GroupNorm to a float32 input with 128 channels:

```python
from max.dtype import DType
from max.experimental.nn.norm import GroupNorm
from max.experimental.realization_context import (
    GraphRealizationContext,
    realization_context,
)
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, Graph, TensorType

# One float32 input of shape (batch, 128, 32, 32) on a GPU; "batch" is a symbolic dimension.
graph = Graph(
    "gn",
    input_types=[
        TensorType(DType.float32, ("batch", 128, 32, 32), DeviceRef.GPU()),
    ],
)
ctx = GraphRealizationContext(graph)
with realization_context(ctx), ctx:
    x = Tensor.from_graph_value(graph.inputs[0])
    # 128 channels split into 32 groups of 4 channels each.
    norm = GroupNorm(num_groups=32, num_channels=128)
    y = norm(x)
    graph.output(y)
```

Parameters:

  • num_groups (int) – The number of groups to split the channel axis into. Must divide num_channels evenly.
  • num_channels (int) – The size of the channel axis of the input (axis 1).
  • eps (float) – A small positive constant added to the variance for numerical stability. Defaults to 1e-5.
  • affine (bool) – Whether to learn a per-channel scale and bias. When False, no parameters are created. Defaults to True.

Raises:

ValueError – If num_channels is not divisible by num_groups.
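For checking results outside a graph, here is a minimal NumPy reference for the normalization step. It is a sketch of the standard formulation, not MAX's implementation: the function name group_norm_reference is hypothetical, the per-channel affine step is omitted, and it assumes channels are grouped contiguously (channels 0 to C/G - 1 form group 0, and so on) and that the variance is the biased (population) variance. Like the class, it rejects a num_groups that does not divide the channel count.

```python
import numpy as np


def group_norm_reference(x, num_groups, eps=1e-5):
    """Hypothetical NumPy reference: normalize x of shape (batch, channels, ...) per group.

    Sketch of the standard formulation only; not MAX's implementation.
    """
    batch, channels = x.shape[0], x.shape[1]
    if channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")

    # Fold each group's channels and all trailing axes into a single axis
    # (assumes contiguous grouping of channels).
    grouped = x.reshape(batch, num_groups, -1)
    mean = grouped.mean(axis=-1, keepdims=True)
    var = grouped.var(axis=-1, keepdims=True)  # biased (population) variance
    normalized = (grouped - mean) / np.sqrt(var + eps)
    # Restore the original (batch, channels, ...) layout.
    return normalized.reshape(x.shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4)).astype(np.float32)
y = group_norm_reference(x, num_groups=4)  # 4 groups of 2 channels each
```

After normalization, each (sample, group) slice has a mean close to 0 and a variance close to 1.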

bias

bias: Tensor | None

The learned per-channel bias of shape [num_channels], or None when affine is False.

eps

eps: float

The variance epsilon used for numerical stability.
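The effect of eps is clearest on a group whose values are all equal: its variance is exactly zero, so the normalization would otherwise compute 0 / 0. The following NumPy illustration assumes the standard formula (it is not MAX code) and shows eps turning that case into zeros instead of NaN.

```python
import numpy as np

# A single group whose values are all identical, so its variance is exactly 0.
group = np.full(16, 3.0, dtype=np.float32)
mean = group.mean()
var = group.var()

eps = 1e-5  # the default documented for GroupNorm
with_eps = (group - mean) / np.sqrt(var + eps)  # 0 / sqrt(eps) -> 0

# Without eps the same expression is 0 / 0, which NumPy evaluates to NaN.
with np.errstate(invalid="ignore", divide="ignore"):
    without_eps = (group - mean) / np.sqrt(var)
```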

forward()

forward(x)

Returns x normalized within each channel group, then scaled by weight and shifted by bias when affine is True.

Parameters:

x (Tensor) – The input tensor of shape (batch, channels, ...). The size of the channel axis must equal num_channels.

Returns:

A tensor with the same shape and dtype as x.

Raises:

ValueError – If x has fewer than 2 dimensions, or if its channel axis does not match num_channels.

Return type:

Tensor
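The two input checks documented under Raises can be restated as a small plain-Python sketch. The helper check_group_norm_input is hypothetical: it works on a shape tuple rather than a MAX Tensor and is not the library's code, only a restatement of the documented conditions.

```python
def check_group_norm_input(shape, num_channels):
    """Hypothetical helper mirroring the ValueError conditions documented for forward()."""
    # The input must have at least a batch axis and a channel axis.
    if len(shape) < 2:
        raise ValueError(f"expected at least 2 dimensions, got shape {shape}")
    # Axis 1 is the channel axis and must match the module's num_channels.
    if shape[1] != num_channels:
        raise ValueError(
            f"expected channel axis of size {num_channels}, got {shape[1]}"
        )


check_group_norm_input((8, 128, 32, 32), num_channels=128)  # passes silently
```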

num_channels

num_channels: int

The size of the channel axis of the input.

num_groups

num_groups: int

The number of groups the channel axis is split into.
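With the values from the example above (num_channels=128, num_groups=32), each group holds 128 / 32 = 4 channels. The sketch below lists which channel indices land in each group. The helper channel_groups is hypothetical, and it assumes contiguous grouping, which is the usual convention but is not stated on this page.

```python
def channel_groups(num_channels, num_groups):
    """Hypothetical helper: channel indices per group, assuming contiguous grouping."""
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    # Group g covers channels [g * per_group, (g + 1) * per_group).
    return [list(range(g * per_group, (g + 1) * per_group)) for g in range(num_groups)]


groups = channel_groups(num_channels=128, num_groups=32)
# groups[0] == [0, 1, 2, 3], groups[1] == [4, 5, 6, 7], len(groups) == 32
```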

weight

weight: Tensor | None

The learned per-channel scale of shape [num_channels], or None when affine is False.
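The weight and bias have shape [num_channels] but are applied to an input of shape (batch, channels, ...), so they must be aligned with axis 1 before broadcasting. Here is a minimal NumPy sketch of that per-channel affine step, assuming the standard formulation. It is not MAX's implementation, and the arrays are illustrative stand-ins.

```python
import numpy as np

num_channels = 4
x_hat = np.ones((2, num_channels, 3, 3), dtype=np.float32)  # stand-in for the normalized input
weight = np.arange(num_channels, dtype=np.float32)           # shape [num_channels]: 0, 1, 2, 3
bias = np.full(num_channels, 0.5, dtype=np.float32)          # shape [num_channels]

# Reshape [num_channels] to (1, C, 1, 1) so the parameters align with axis 1
# and broadcast over the batch axis and every trailing axis.
param_shape = (1, num_channels) + (1,) * (x_hat.ndim - 2)
y = x_hat * weight.reshape(param_shape) + bias.reshape(param_shape)
# Every element in channel c now equals c + 0.5.
```

The reshape matters: broadcasting a bare [num_channels] vector aligns it with the last axis instead of the channel axis, which either fails (as it would here, where the last axis has size 3) or silently scales the wrong axis when the sizes happen to match.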