
Python module

group_norm

Group Normalization implementation using the graph API.

GroupNorm

class max.nn.norm.group_norm.GroupNorm(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True, device: max.graph.type.DeviceRef = cpu:0)

Group normalization block.

Divides the input channels into groups and computes the normalization statistics (mean and variance) separately for each group. Follows the implementation pattern of PyTorch's group_norm.
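For reference, the per-group computation can be written out as below. This assumes the (N, C, *) layout and biased variance used by PyTorch's GroupNorm; the layout this class expects is not stated here. With C channels split into G groups, S_{n,g} is the set of all elements of sample n whose channel belongs to group g, and gamma, beta are the learnable affine parameters (present only when affine is True).

```latex
\begin{aligned}
\mu_{n,g}      &= \frac{1}{|S_{n,g}|} \sum_{i \in S_{n,g}} x_i \\
\sigma^2_{n,g} &= \frac{1}{|S_{n,g}|} \sum_{i \in S_{n,g}} \left(x_i - \mu_{n,g}\right)^2 \\
y_i            &= \frac{x_i - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}}\,\gamma_c + \beta_c,
                  \qquad i \in S_{n,g},\ c = \text{channel of } i
\end{aligned}
```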

  • Parameters:

    • num_groups – Number of groups to divide the channels into.
    • num_channels – Number of input channels.
    • eps – Small constant added to the variance in the denominator for numerical stability.
    • affine – If True, apply a learnable affine transform (scale and shift) to the normalized output.
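To show how the constructor arguments above interact, here is a minimal eager NumPy sketch. The class name GroupNormReference is hypothetical and this is not the MAX API: it does not build a graph and ignores device. It assumes an (N, C, *spatial) input layout, PyTorch-style initialization (scale of ones, shift of zeros), and that num_channels must be divisible by num_groups.

```python
import numpy as np


class GroupNormReference:
    """Hypothetical NumPy stand-in mirroring GroupNorm's constructor arguments.

    Not the MAX API: it runs eagerly on NumPy arrays instead of building a graph.
    """

    def __init__(self, num_groups: int, num_channels: int,
                 eps: float = 1e-5, affine: bool = True):
        # Assumption: every group must hold the same number of channels.
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        # Assumed PyTorch-style initialization: identity scale and zero shift.
        self.weight = np.ones(num_channels) if affine else None
        self.bias = np.zeros(num_channels) if affine else None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x is assumed to be laid out as (N, C, *spatial).
        n, c = x.shape[:2]
        if c != self.num_channels:
            raise ValueError(f"expected {self.num_channels} channels, got {c}")
        # Fold each group's channels and spatial positions into one axis.
        grouped = x.reshape(n, self.num_groups, -1)
        mean = grouped.mean(axis=-1, keepdims=True)
        var = grouped.var(axis=-1, keepdims=True)  # biased variance, as in PyTorch
        normed = ((grouped - mean) / np.sqrt(var + self.eps)).reshape(x.shape)
        if self.affine:
            # Broadcast the per-channel parameters over batch and spatial axes.
            shape = (1, c) + (1,) * (x.ndim - 2)
            normed = normed * self.weight.reshape(shape) + self.bias.reshape(shape)
        return normed
```

For example, calling GroupNormReference(num_groups=3, num_channels=6) on an array of shape (2, 6, 4, 4) normalizes each sample's three groups of two channels independently, so every group ends up with mean close to 0 and variance close to 1 before the affine step.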