
Mojo function

group_norm_reshape

def group_norm_reshape[dtype: DType, rank: Int](shape: IndexList[rank, element_type=shape.element_type], buf: TileTensor[dtype, Storage=buf.Storage, address_space=buf.address_space, linear_idx_type=buf.linear_idx_type, element_size=buf.element_size], channels_per_group: Int, spatial: Int) -> TileTensor[dtype, Layout[*?, *?], buf.origin, address_space=buf.address_space]

Reshapes an input buffer for group normalization by flattening every dimension except the group dimension, so that each row holds the elements of one group. Returns a 2D buffer of shape (num_groups * N, group_size), where group_size = channels_per_group * spatial.
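The shape arithmetic described above can be sketched in plain Python. This is an illustration only, not the Mojo implementation. It assumes that N is the batch size, that the input is laid out channels-first as (N, C, spatial dims...), and that the data is contiguous in row-major order. None of this is stated on this page, and the helper names `group_norm_2d_shape` and `reshape_flat` are hypothetical.

```python
from math import prod


def group_norm_2d_shape(shape, channels_per_group):
    """Compute the 2D shape (num_groups * N, group_size).

    Assumption (not confirmed by the source): `shape` is channels-first,
    i.e. (N, C, *spatial_dims), with N the batch size.
    """
    n, c = shape[0], shape[1]
    # Product of the trailing spatial dims; prod(()) is 1 when there are none.
    spatial = prod(shape[2:])
    if c % channels_per_group != 0:
        raise ValueError("C must be divisible by channels_per_group")
    num_groups = c // channels_per_group
    group_size = channels_per_group * spatial
    return (num_groups * n, group_size)


def reshape_flat(flat, rows, cols):
    """View a contiguous row-major list as `rows` rows of `cols` elements."""
    if len(flat) != rows * cols:
        raise ValueError("element count does not match the target shape")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


# Example: batch of 2, 6 channels, 4x4 spatial, 3 channels per group.
rows, cols = group_norm_2d_shape((2, 6, 4, 4), channels_per_group=3)
grouped = reshape_flat(list(range(2 * 6 * 4 * 4)), rows, cols)
```

Under these assumptions the example yields 4 rows (2 groups for each of 2 batch items) of 48 elements (3 channels times 16 spatial positions), so a per-row mean and variance corresponds to per-group statistics.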

Returns:

TileTensor[dtype, Layout[*?, *?], buf.origin, address_space=buf.address_space]