
Mojo function

layer_norm_cpu

def layer_norm_cpu[dtype: DType, //, input_fn: def[width: Int, alignment: Int](Int, Int) capturing -> SIMD[dtype, width], gamma_fn: def[width: Int, rank: Int, alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None](num_rows: Int, num_cols: Int, beta: TileTensor[dtype, address_space=beta.address_space, linear_idx_type=beta.linear_idx_type, element_size=beta.element_size], epsilon: Scalar[dtype])

Computes layernorm(elementwise_fn(x)) across the last dimension of x, where layernorm is defined as $\frac{x - \mathrm{mean}(x)}{\sqrt{\mathrm{var}(x) + \epsilon}} \cdot \gamma + \beta$, with $\gamma$ produced by `gamma_fn`, $\beta$ taken from `beta`, and $\epsilon$ given by `epsilon`.
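As a point of reference, the formula can be written out for a single row in plain, unvectorized Mojo. This is a minimal sketch, not the library kernel: `layer_norm_row` is a hypothetical name, plain `List[Float32]` values stand in for the `input_fn` and `gamma_fn` callbacks and the `beta` tensor, and the variance is assumed to be the biased (divide-by-N) form commonly used in layer normalization. The sketch makes three passes over the row: one for the mean, one for the variance, and one to normalize.

```mojo
from math import sqrt


# Hypothetical single-row reference (not the library kernel): plain lists
# stand in for input_fn, gamma_fn, and the beta tensor.
def layer_norm_row(
    x: List[Float32],
    gamma: List[Float32],
    beta: List[Float32],
    eps: Float32,
) -> List[Float32]:
    var n = len(x)

    # Pass 1: mean of the row.
    var mean: Float32 = 0.0
    for i in range(n):
        mean += x[i]
    mean /= Float32(n)

    # Pass 2: biased (divide-by-n) variance of the row; an assumption here.
    var sq_sum: Float32 = 0.0
    for i in range(n):
        var d = x[i] - mean
        sq_sum += d * d
    var variance = sq_sum / Float32(n)

    # Pass 3: normalize, scale by gamma, and shift by beta.
    var inv_std = Float32(1.0) / sqrt(variance + eps)
    var result = List[Float32]()
    for i in range(n):
        result.append((x[i] - mean) * inv_std * gamma[i] + beta[i])
    return result^


def main():
    # One row [1, 2, 3, 4] with gamma = 1 and beta = 0.
    var x = List[Float32]()
    var gamma = List[Float32]()
    var beta = List[Float32]()
    for i in range(4):
        x.append(Float32(i + 1))
        gamma.append(Float32(1.0))
        beta.append(Float32(0.0))
    var y = layer_norm_row(x, gamma, beta, Float32(1e-5))
    for i in range(len(y)):
        print(y[i])
```

With gamma = 1 and beta = 0 the outputs have mean close to 0 and variance close to 1, which is a quick sanity check for any layer norm implementation.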

Currently performs 3 passes over the input data. This can be reduced to 2 by fusing the add, mean, and variance loops using Welford's algorithm.
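The Welford-based fusion can be sketched for a single row: the first pass updates the running mean and the running sum of squared deviations (`m2`) together, and the second pass writes the output. This only illustrates the technique; `layer_norm_row_welford` is a hypothetical name, the row is again a plain `List[Float32]`, and it is not a description of how `layer_norm_cpu` is or will be implemented. Besides saving a pass, Welford's update avoids the cancellation error of computing the variance from separate sums of x and x².

```mojo
from math import sqrt


# Hypothetical two-pass, single-row layer norm using Welford's algorithm.
def layer_norm_row_welford(
    x: List[Float32],
    gamma: List[Float32],
    beta: List[Float32],
    eps: Float32,
) -> List[Float32]:
    var n = len(x)

    # Pass 1: Welford's online update. After k elements, `mean` is their
    # mean and `m2` is the sum of squared deviations from that mean.
    var mean: Float32 = 0.0
    var m2: Float32 = 0.0
    for i in range(n):
        var count = Float32(i + 1)
        var delta = x[i] - mean          # deviation from the old mean
        mean += delta / count            # updated mean
        m2 += delta * (x[i] - mean)      # uses old and new mean
    var variance = m2 / Float32(n)       # biased (divide-by-n) variance

    # Pass 2: normalize, scale by gamma, and shift by beta.
    var inv_std = Float32(1.0) / sqrt(variance + eps)
    var result = List[Float32]()
    for i in range(n):
        result.append((x[i] - mean) * inv_std * gamma[i] + beta[i])
    return result^


def main():
    # One row [1, 2, 3, 4]: Welford gives mean 2.5 and m2 5.0 (variance 1.25).
    var x = List[Float32]()
    var gamma = List[Float32]()
    var beta = List[Float32]()
    for i in range(4):
        x.append(Float32(i + 1))
        gamma.append(Float32(1.0))
        beta.append(Float32(0.0))
    var y = layer_norm_row_welford(x, gamma, beta, Float32(1e-5))
    for i in range(len(y)):
        print(y[i])
```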

Parameters:

  • dtype (DType): The element dtype of the x and output buffers.
  • input_fn (def[width: Int, alignment: Int](Int, Int) capturing -> SIMD[dtype, width]): Function called to generate an input value.
  • gamma_fn (def[width: Int, rank: Int, alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): Function called to generate a gamma value.
  • output_fn (def[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None): Function called to store the output value.

Args:

def layer_norm_cpu[dtype: DType, rank: Int, //, input_fn: def[width: Int, rank: Int, alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width], gamma_fn: def[width: Int, rank: Int, alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None](shape: IndexList[rank], beta: TileTensor[dtype, address_space=beta.address_space, linear_idx_type=beta.linear_idx_type, element_size=beta.element_size], epsilon: Scalar[dtype], ctx: Optional[DeviceContext] = None)
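This overload takes an N-dimensional `shape` instead of `num_rows` and `num_cols`. Because normalization runs over the last dimension, a natural reading (an assumption, not confirmed by this page) is that the last dimension plays the role of `num_cols` and the product of all leading dimensions plays the role of `num_rows`. The hypothetical helper `rows_for_shape` below computes that mapping, with `List[Int]` standing in for `IndexList[rank]`.

```mojo
# Hypothetical helper: number of independent rows normalized when reducing
# over the last dimension, i.e. the product of all leading dimensions.
# List[Int] stands in for IndexList[rank].
def rows_for_shape(shape: List[Int]) -> Int:
    var rows = 1
    for i in range(len(shape) - 1):
        rows *= shape[i]
    return rows


def main():
    # Shape (2, 3, 8): six rows of eight columns each.
    var shape = List[Int]()
    shape.append(2)
    shape.append(3)
    shape.append(8)
    var cols = shape[len(shape) - 1]
    var rows = rows_for_shape(shape)
    print(rows, cols)  # prints "6 8"
```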