
reduce (function)

reduce[reduce_fn: fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2]](src: Buffer[type, size, address_space], init: SIMD[type, 1]) -> SIMD[$4, 1]

Computes a custom reduction of buffer elements.

Parameters:

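  • reduce_fn (fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2]): The lambda implementing the reduction. Its first argument and its return value share a dtype (the accumulator), and its second argument holds input elements of the same SIMD width.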

Args:

  • src (Buffer[type, size, address_space]): The input buffer.
  • init (SIMD[type, 1]): The initial value of the accumulator.

Returns:

The computed reduction value.
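To make the accumulation concrete, here is a minimal Python reference model of this overload. It is not the Mojo API. The names reduce_buffer, add, and maximum are hypothetical, and Python lists stand in for SIMD vectors. The chunked strategy is also an assumption and not necessarily how the library works: full SIMD-width chunks are combined first, then the vector lanes and any leftover elements are folded into init one at a time. The model assumes reduce_fn is associative and commutative (as add and max are), so the combination order does not change the result.

```python
from typing import Callable, List, Sequence

Vec = List[float]
# Hypothetical stand-in for reduce_fn: combines an accumulator vector with an
# input vector of the same width, lane by lane.
ReduceFn = Callable[[Vec, Vec], Vec]


def reduce_buffer(reduce_fn: ReduceFn, src: Sequence[float], init: float, width: int = 4) -> float:
    """Reference model of reduce(src, init) over a 1-D buffer (illustrative only)."""
    n_full = len(src) - len(src) % width  # elements covered by whole chunks
    acc = [init]  # scalar accumulator kept as a width-1 vector
    if n_full:
        # Vector phase: combine the full width-sized chunks into one vector.
        vec_acc = list(src[0:width])
        for start in range(width, n_full, width):
            vec_acc = reduce_fn(vec_acc, list(src[start:start + width]))
        # Fold each lane of the vector into the scalar accumulator.
        for lane in vec_acc:
            acc = reduce_fn(acc, [lane])
    # Tail phase: leftover elements that do not fill a whole chunk.
    for x in src[n_full:]:
        acc = reduce_fn(acc, [x])
    return acc[0]


def add(a: Vec, b: Vec) -> Vec:
    # Lane-wise sum.
    return [x + y for x, y in zip(a, b)]


def maximum(a: Vec, b: Vec) -> Vec:
    # Lane-wise maximum.
    return [max(x, y) for x, y in zip(a, b)]
```

For example, reduce_buffer(add, list(range(1, 11)), 0) returns 55. In this model init is applied exactly once, whatever the chunk width.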

reduce[map_fn: fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2], reduce_fn: fn[DType, Int](SIMD[$0, $1], /) -> SIMD[$0, 1], reduce_axis: Int](src: NDBuffer[type, rank, shape, address_space], dst: NDBuffer[type, rank, shape, address_space], init: SIMD[type, 1])

Reduces the NDBuffer src across reduce_axis and stores the result in the NDBuffer dst.

First, src is reshaped into a 3-D tensor. Without loss of generality, its three axes are referred to as [H, W, C]: W is the axis to reduce across, H packs all axes before the reduce axis, and C packs all axes after it. That is, a tensor with dims [D1, D2, ..., Di, ..., Dn] reduced across axis i is packed into a 3-D tensor with dims [H, W, C], where H = prod(D1, ..., Di-1), W = Di, and C = prod(Di+1, ..., Dn). An empty product is 1, so H = 1 when reducing the first axis and C = 1 when reducing the last.
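The packing arithmetic can be sketched in Python. The helper name pack_hwc is hypothetical. It takes a 0-based reduce_axis, as Python indexing does, while the text above numbers the dims from 1.

```python
from math import prod
from typing import Sequence, Tuple


def pack_hwc(shape: Sequence[int], reduce_axis: int) -> Tuple[int, int, int]:
    """Collapse an n-D shape into (H, W, C) around a 0-based reduce_axis."""
    if not 0 <= reduce_axis < len(shape):
        raise ValueError("reduce_axis out of range")
    h = prod(shape[:reduce_axis])      # axes before the reduce axis (1 if none)
    w = shape[reduce_axis]             # the axis being reduced
    c = prod(shape[reduce_axis + 1:])  # axes after the reduce axis (1 if none)
    return h, w, c
```

For a shape of (2, 3, 4, 5), pack_hwc gives (2, 3, 20) for axis 1, (1, 2, 60) for axis 0, and (24, 5, 1) for axis 3.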

Parameters:

  • map_fn (fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2]): A mapping function used to combine (accumulate) two chunks of input data element-wise. For example, given two 8 x float32 vectors, it reduces them to a single 8 x float32 vector.
  • reduce_fn (fn[DType, Int](SIMD[$0, $1], /) -> SIMD[$0, 1]): A reduction function used to reduce a vector to a scalar, for example an 8 x float32 vector to a single float32 value.
  • reduce_axis (Int): The axis to reduce across.
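The split between the two functions can be illustrated with a small Python sketch. Here the map_fn is a lane-wise maximum and the reduce_fn is a horizontal maximum. Lists of eight floats stand in for 8 x float32 SIMD vectors, and the function bodies are illustrative choices rather than anything the library prescribes.

```python
from typing import List

Vec = List[float]


def map_fn(a: Vec, b: Vec) -> Vec:
    """Combine two equal-width chunks lane by lane (here: element-wise max)."""
    if len(a) != len(b):
        raise ValueError("chunks must have the same width")
    return [max(x, y) for x, y in zip(a, b)]


def reduce_fn(v: Vec) -> float:
    """Reduce one vector to a scalar (here: the largest lane)."""
    return max(v)


# Two 8-lane chunks, standing in for two 8 x float32 vectors.
chunk_a = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
chunk_b = [2.0, 7.0, 1.0, 8.0, 2.0, 8.0, 1.0, 8.0]

combined = map_fn(chunk_a, chunk_b)  # still 8 lanes
scalar = reduce_fn(combined)         # a single value
```

Here combined is [3.0, 7.0, 4.0, 8.0, 5.0, 9.0, 2.0, 8.0] and scalar is 9.0.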

Args:

  • src (NDBuffer[type, rank, shape, address_space]): The input buffer.
  • dst (NDBuffer[type, rank, shape, address_space]): The output buffer.
  • init (SIMD[type, 1]): The initial value of the accumulator.
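The following Python model combines these pieces and reduces a contiguous row-major buffer across one axis using the [H, W, C] packing described above. It is a sketch under stated assumptions, not the Mojo implementation:

- reduce_axis_model, map_add, and reduce_add are hypothetical names.
- A flat src list stands in for the NDBuffer.
- The H * C results are returned as a list rather than written into dst.
- The order in which map_fn, reduce_fn, and init are applied is assumed. It only matters for operations that are not associative and commutative.

```python
from math import prod
from typing import Callable, List, Sequence

Vec = List[float]


def map_add(a: Vec, b: Vec) -> Vec:
    # map_fn stand-in: lane-wise combine of two equal-width vectors.
    return [x + y for x, y in zip(a, b)]


def reduce_add(v: Vec) -> float:
    # reduce_fn stand-in: horizontal reduction of a vector to a scalar.
    return sum(v)


def reduce_axis_model(
    map_fn: Callable[[Vec, Vec], Vec],
    reduce_fn: Callable[[Vec], float],
    src: Sequence[float],
    shape: Sequence[int],
    reduce_axis: int,  # 0-based in this sketch
    init: float,
    width: int = 4,
) -> List[float]:
    """Reduce a row-major flat buffer along reduce_axis; returns H * C values."""
    h_dim = prod(shape[:reduce_axis])
    w_dim = shape[reduce_axis]
    c_dim = prod(shape[reduce_axis + 1:])
    dst: List[float] = []
    for h in range(h_dim):
        for c in range(c_dim):
            # Assuming a contiguous row-major layout, element (h, w, c) of the
            # packed view sits at flat index (h * W + w) * C + c.
            line = [src[(h * w_dim + w) * c_dim + c] for w in range(w_dim)]
            acc = init
            n_full = w_dim - w_dim % width
            if n_full:
                # Accumulate whole chunks with map_fn, then collapse with reduce_fn.
                vec = line[:width]
                for start in range(width, n_full, width):
                    vec = map_fn(vec, line[start:start + width])
                acc = map_fn([acc], [reduce_fn(vec)])[0]
            # Fold leftover elements into the accumulator one at a time.
            for x in line[n_full:]:
                acc = map_fn([acc], [x])[0]
            dst.append(acc)
    return dst
```

For the 2 x 3 buffer [[1, 2, 3], [4, 5, 6]] stored as [1, 2, 3, 4, 5, 6], reducing axis 1 with init 0 gives [6, 15], and reducing axis 0 gives [5, 7, 9].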