GroupNorm

class max.nn.module_v3.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)

Group normalization block.

Divides channels into groups and computes normalization stats per group. Follows the implementation pattern from PyTorch’s group_norm.

This implementation stores its parameters as Tensor rather than Weight; Tensor automatically matches the dtype of input tensors, eliminating the need for dtype workarounds.

Example:

from max.nn.module_v3 import GroupNorm
from max.experimental.tensor import Tensor

norm = GroupNorm(num_groups=32, num_channels=128)
x = Tensor.ones([1, 128, 32, 32])
result = norm(x)
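The per-group computation can be sketched in plain NumPy. This is an illustrative reference, not the MAX implementation; the function name group_norm_ref and its signature are hypothetical:

```python
import numpy as np

def group_norm_ref(x, num_groups, eps=1e-5, weight=None, bias=None):
    """Reference group normalization for an [N, C, *] input array."""
    n, c = x.shape[0], x.shape[1]
    # Flatten each group's channels and all trailing (spatial) positions
    # together, so mean/variance are shared within a group.
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(axis=-1, keepdims=True)
    var = grouped.var(axis=-1, keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)
    if weight is not None:
        # Affine parameters broadcast over the channel dimension only.
        shape = (1, c) + (1,) * (x.ndim - 2)
        normed = normed * weight.reshape(shape) + bias.reshape(shape)
    return normed

x = np.random.default_rng(0).normal(size=(1, 128, 32, 32))
y = group_norm_ref(x, num_groups=32)
```

After normalization, each of the 32 groups has mean approximately 0 and variance approximately 1, and the output shape matches the input.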

Initialize GroupNorm module.

Parameters:

  • num_groups (int) – Number of groups to separate the channels into
  • num_channels (int) – Number of input channels
  • eps (float) – Small constant added to denominator for numerical stability. Default: 1e-5
  • affine (bool) – If True, apply learnable affine transform parameters. Default: True
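Assuming this module follows PyTorch's group_norm convention (as the description above states), num_channels must be evenly divisible by num_groups, and each group then covers num_channels / num_groups channels:

```python
num_groups, num_channels = 32, 128

# Channels must split evenly across groups (a PyTorch group_norm
# convention; assumed here to carry over to this module).
assert num_channels % num_groups == 0

channels_per_group = num_channels // num_groups
```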

bias

bias: Tensor | None

The bias tensor with shape [num_channels] (None if affine=False).

eps

eps: float

Small constant added to denominator for numerical stability.
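The role of eps is easiest to see on a constant input: a group whose values are all equal has zero variance, so without eps the denominator would be zero. A minimal NumPy illustration:

```python
import numpy as np

# All values in this group are identical, so the variance is exactly 0.
group = np.full(16, 3.0)
eps = 1e-5

# With eps in the denominator the division stays well-defined and the
# normalized output is simply zero rather than NaN.
normed = (group - group.mean()) / np.sqrt(group.var() + eps)
```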

forward()

forward(x)

Apply group normalization to input tensor.

Parameters:

x (Tensor) – Input tensor of shape [N, C, *] where C is the number of channels

Returns:

Normalized tensor of same shape as input

Return type:

Tensor
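Because statistics are reduced over each group's channels together with every trailing dimension, the [N, C, *] signature admits inputs with any number of trailing dimensions. A NumPy sketch (again illustrative, not the MAX code) handles [N, C, L] and [N, C, H, W] with the same reduction:

```python
import numpy as np

def normalize(x, num_groups, eps=1e-5):
    # Flatten group channels plus all trailing dims into one axis,
    # normalize, then restore the original shape.
    n = x.shape[0]
    g = x.reshape(n, num_groups, -1)
    g = (g - g.mean(-1, keepdims=True)) / np.sqrt(g.var(-1, keepdims=True) + eps)
    return g.reshape(x.shape)

seq = np.random.default_rng(1).normal(size=(2, 8, 50))    # [N, C, L]
img = np.random.default_rng(2).normal(size=(2, 8, 4, 4))  # [N, C, H, W]
y_seq = normalize(seq, num_groups=4)
y_img = normalize(img, num_groups=4)
```

In both cases the output has the same shape as the input, matching the return contract above.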

num_channels

num_channels: int

Number of input channels.

num_groups

num_groups: int

Number of groups to separate the channels into.

weight

weight: Tensor | None

The weight tensor with shape [num_channels] (None if affine=False).
