
function

map_reduce

map_reduce[simd_width: Int, size: Dim, type: DType, acc_type: DType, input_gen_fn: fn[DType, Int](Int, /) capturing -> SIMD[$0, $1], reduce_vec_to_vec_fn: fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2], reduce_vec_to_scalar_fn: fn[DType, Int](SIMD[$0, $1], /) -> SIMD[$0, 1]](dst: Buffer[type, size, 0], init: SIMD[acc_type, 1]) -> SIMD[$3, 1]

Stores the results of calling input_gen_fn in dst and simultaneously reduces them using custom reduction functions.
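The combined "store and reduce" behavior can be modeled outside Mojo. The following Python sketch is a conceptual model only, not the Mojo implementation. `map_reduce_model` is a hypothetical name, plain lists stand in for SIMD vectors, and the vector width is passed to `input_gen_fn` as an ordinary argument instead of a compile-time parameter. Two further points are assumptions: the sketch broadcasts `init` into every lane of the vector accumulator, and it handles leftover elements (when the buffer size is not a multiple of `simd_width`) one at a time. The real implementation may differ, for example by unrolling the loop.

```python
from typing import Callable, List

Vec = List[float]  # a plain list stands in for a SIMD vector


def map_reduce_model(
    dst: List[float],
    simd_width: int,
    init: float,
    input_gen_fn: Callable[[int, int], Vec],          # (width, index) -> `width` values
    reduce_vec_to_vec_fn: Callable[[Vec, Vec], Vec],  # (accumulator, input) -> accumulator
    reduce_vec_to_scalar_fn: Callable[[Vec], float],  # vector -> scalar
) -> float:
    size = len(dst)
    vector_end = size - size % simd_width  # end of the region covered by full vectors
    # Assumption: the initial value is broadcast into every lane of the accumulator.
    acc_vec = [init] * simd_width
    for i in range(0, vector_end, simd_width):
        val = input_gen_fn(simd_width, i)
        dst[i:i + simd_width] = val                   # "map": store the generated values
        acc_vec = reduce_vec_to_vec_fn(acc_vec, val)  # accumulate lane by lane
    acc = reduce_vec_to_scalar_fn(acc_vec)            # collapse the lanes into one value
    # Leftover elements that do not fill a whole vector are processed one at a time.
    for i in range(vector_end, size):
        val = input_gen_fn(1, i)
        dst[i] = val[0]
        acc = reduce_vec_to_vec_fn([acc], val)[0]
    return acc


# Usage: generate the values 0..9, store them in dst, and sum them.
dst = [0.0] * 10
total = map_reduce_model(
    dst,
    simd_width=4,
    init=0.0,
    input_gen_fn=lambda width, i: [float(i + k) for k in range(width)],
    reduce_vec_to_vec_fn=lambda acc, v: [a + b for a, b in zip(acc, v)],
    reduce_vec_to_scalar_fn=sum,
)
```

After the call, `dst` holds `0.0` through `9.0` and `total` is `45.0`. Because `init` is broadcast into every lane in this model, it should be an identity element of the reduction (such as `0` for a sum). Otherwise it would be counted once per lane.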

Parameters:

  • simd_width (Int): The vector width for the computation.
  • size (Dim): The buffer size.
  • type (DType): The buffer elements dtype.
  • acc_type (DType): The dtype of the reduction accumulator.
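  • input_gen_fn (fn[DType, Int](Int, /) capturing -> SIMD[$0, $1]): A function that generates the values to store and reduce. It takes the index of the first element to generate and returns a SIMD vector with the requested dtype and width.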
  • reduce_vec_to_vec_fn (fn[DType, DType, Int](SIMD[$0, $2], SIMD[$1, $2], /) capturing -> SIMD[$0, $2]): A vector-to-vector reduction function. It combines (accumulates) two vectors of the same width into one, and the result has the dtype and width of the first argument. For example, given two 8 x float32 vectors of elements, it reduces them into a single 8 x float32 vector.
  • reduce_vec_to_scalar_fn (fn[DType, Int](SIMD[$0, $1], /) -> SIMD[$0, 1]): A vector-to-scalar reduction function. It reduces a vector to a single scalar. For example, given an 8 x float32 vector, it reduces it to one float32 value.
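The roles of the three callbacks can be illustrated with a small Python sketch that computes a maximum instead of a sum, to show that the reduction is customizable. This is a conceptual model, not Mojo code. The names `input_gen`, `reduce_vec_to_vec`, `reduce_vec_to_scalar`, and `data` are hypothetical. 8-element lists stand in for 8 x float32 vectors, and Python floats stand in for float32.

```python
from typing import List

Vec = List[float]  # stand-in for an 8 x float32 SIMD vector

# Hypothetical source data that the input generator reads from.
data = [1.0, 9.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        8.0, 2.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]


def input_gen(width: int, index: int) -> Vec:
    # Produce `width` consecutive values starting at `index`.
    return data[index:index + width]


def reduce_vec_to_vec(acc: Vec, val: Vec) -> Vec:
    # Lane-wise combine: lane k of the result depends only on lane k of each input.
    return [max(a, b) for a, b in zip(acc, val)]


def reduce_vec_to_scalar(vec: Vec) -> float:
    # Horizontal reduction: collapse all lanes into one value.
    return max(vec)


a = input_gen(8, 0)
b = input_gen(8, 8)
combined = reduce_vec_to_vec(a, b)        # two 8-lane vectors -> one 8-lane vector
result = reduce_vec_to_scalar(combined)   # one 8-lane vector -> one scalar
```

Here `combined` is `[8.0, 9.0, 6.0, 5.0, 5.0, 6.0, 7.0, 8.0]` and `result` is `9.0`. Keeping the vector-to-vector step lane-wise is what lets the accumulation stay in SIMD registers. The more expensive horizontal reduction then only needs to run at the end.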

Args:

  • dst (Buffer[type, size, 0]): The output buffer.
  • init (SIMD[acc_type, 1]): The initial value of the accumulator.

Returns:

The computed reduction value.