Python class
GroupNorm
class max.nn.module_v3.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
Group normalization block.
Divides the channels into groups and computes the normalization statistics (mean and variance) separately for each group. Follows the implementation pattern of PyTorch's group_norm.
This implementation stores its parameters as Tensor rather than Weight; this automatically handles dtype matching with the input tensors and eliminates the need for dtype workarounds.
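The per-group statistics can be written out as follows. This is a sketch that assumes an input of shape [N, C, ...] whose trailing spatial dimensions are flattened into L positions, that C = num_channels is divisible by G = num_groups, and that the biased (population) variance is used, as in PyTorch's group_norm. Here γ is weight, β is bias, and ε is eps; with affine=False, take γ_c = 1 and β_c = 0.

$$
S_g = \left\{ g\tfrac{C}{G}, \dots, (g+1)\tfrac{C}{G} - 1 \right\}, \qquad g(c) = \left\lfloor \frac{c}{C/G} \right\rfloor
$$

$$
\mu_{n,g} = \frac{G}{C\,L} \sum_{c \in S_g} \sum_{l=1}^{L} x_{n,c,l}, \qquad
\sigma^2_{n,g} = \frac{G}{C\,L} \sum_{c \in S_g} \sum_{l=1}^{L} \left( x_{n,c,l} - \mu_{n,g} \right)^2
$$

$$
y_{n,c,l} = \frac{x_{n,c,l} - \mu_{n,g(c)}}{\sqrt{\sigma^2_{n,g(c)} + \epsilon}}\,\gamma_c + \beta_c
$$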
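Example:
```python
from max.nn.module_v3 import GroupNorm
from max.experimental.tensor import Tensor

# 128 channels split into 32 groups of 4 channels each.
norm = GroupNorm(num_groups=32, num_channels=128)
# Input layout: [batch, channels, height, width].
x = Tensor.ones([1, 128, 32, 32])
result = norm(x)
```
Initialize GroupNorm module.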
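Parameters:
num_groups (int): Number of groups to separate the channels into.
num_channels (int): Number of input channels.
eps (float, default 1e-05): Small constant added to the denominator for numerical stability.
affine (bool, default True): Whether the module has per-channel weight and bias tensors of shape [num_channels]. If False, weight and bias are None.

Attributes and methods:
bias
The bias tensor with shape [num_channels] (None if affine=False).
eps: float
Small constant added to the denominator for numerical stability.
forward(x)
Apply group normalization to the input tensor x.
num_channels: int
Number of input channels.
num_groups: int
Number of groups to separate the channels into.
weight
The weight tensor with shape [num_channels] (None if affine=False).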
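How num_groups, num_channels, eps, weight, and bias combine in forward can be illustrated with a small pure-Python reference. This is a hypothetical helper (group_norm_reference), not part of MAX and not its actual implementation. It assumes a [batch][channel][spatial] nested-list layout with spatial dimensions flattened, requires num_channels to be divisible by num_groups, and uses the biased variance of PyTorch's group_norm. Passing weight=None and bias=None mimics affine=False.

```python
import math
from typing import List, Optional

# Hypothetical layout: [batch][channel][spatial], spatial dims flattened.
Tensor3D = List[List[List[float]]]


def group_norm_reference(
    x: Tensor3D,
    num_groups: int,
    weight: Optional[List[float]] = None,
    bias: Optional[List[float]] = None,
    eps: float = 1e-5,
) -> Tensor3D:
    """Pure-Python group normalization sketch (not the MAX implementation)."""
    out: Tensor3D = []
    for sample in x:
        num_channels = len(sample)
        # Assumption of this sketch: channels split evenly into groups.
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        group_size = num_channels // num_groups
        normalized: List[List[float]] = []
        for g in range(num_groups):
            channels = sample[g * group_size:(g + 1) * group_size]
            # Pool every value in the group: all its channels and positions.
            values = [v for ch in channels for v in ch]
            mean = sum(values) / len(values)
            # Biased (population) variance, as in PyTorch's group_norm.
            var = sum((v - mean) ** 2 for v in values) / len(values)
            inv_std = 1.0 / math.sqrt(var + eps)
            for offset, ch in enumerate(channels):
                c = g * group_size + offset
                # Per-channel affine parameters; identity when None.
                scale = weight[c] if weight is not None else 1.0
                shift = bias[c] if bias is not None else 0.0
                normalized.append(
                    [(v - mean) * inv_std * scale + shift for v in ch]
                )
        out.append(normalized)
    return out
```

For a [1, 4, 2] input, group_norm_reference(x, num_groups=2) pools channels 0-1 and channels 2-3 separately. With an all-ones input like the one in the Example, every group has zero variance, so the sketch returns zeros before the affine shift is applied.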