
Mojo function

row_reduce

row_reduce[BLOCK_SIZE: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int, rank: Int, accum_type: DType = get_accum_type[dtype]()](mut row_coords: IndexList[rank], axis: Int, init: Scalar[dtype], row_size: Int) -> Scalar[accum_type]

Reduces a single row along the given axis using block-level cooperative reduction. Delegates to the multi-reduction row_reduce overload with num_reductions=1.
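The delegation can be pictured with a small plain-Python model. Both functions here are hypothetical stand-ins: row_reduce_multi plays the multi-reduction overload and row_reduce_single plays this one. They ignore threads and SIMD entirely. The only point is that the single-reduction form wraps init in a one-element tuple and unwraps the one-element result.

```python
from typing import Callable, Sequence, Tuple

# Hypothetical stand-in for the multi-reduction overload (not the Mojo API):
# folds every reduction over the row in a single pass.
def row_reduce_multi(
    row: Sequence[float],
    reduce_fns: Sequence[Callable[[float, float], float]],
    init: Tuple[float, ...],
) -> Tuple[float, ...]:
    acc = list(init)
    for x in row:
        for i, fn in enumerate(reduce_fns):
            acc[i] = fn(acc[i], x)
    return tuple(acc)

# Hypothetical stand-in for the single-reduction overload.
def row_reduce_single(
    row: Sequence[float],
    reduce_fn: Callable[[float, float], float],
    init: float,
) -> float:
    # num_reductions=1: wrap init as a 1-tuple, then unwrap the 1-tuple result.
    return row_reduce_multi(row, [reduce_fn], (init,))[0]

print(row_reduce_single([1.0, 2.0, 3.0], lambda a, b: a + b, 0.0))  # 6.0
```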

Parameters:

  • BLOCK_SIZE (Int): The number of threads per block.
  • input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The capturing function that loads a SIMD[dtype, width] vector of input elements at the given coordinates.
  • reduce_fn (fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function.
  • dtype (DType): The data type of the input elements.
  • simd_width (Int): The SIMD vector width.
  • rank (Int): The tensor rank.
  • accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype](), which may select a wider type than dtype for accumulation.
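A widened accumulator matters because adding many small values in a narrow type can stall. This stdlib-only Python sketch emulates float16 accumulation by rounding through the struct module's half-precision "e" format and compares it with accumulating in a wider float. The float16/float64 pair is illustrative only; the sketch does not assume which type get_accum_type[dtype]() actually returns for any dtype.

```python
import struct

def to_half(x: float) -> float:
    # Round a Python float to IEEE 754 binary16 (round-half-to-even) and back.
    return struct.unpack("<e", struct.pack("<e", x))[0]

def sum_in_half(values):
    # Narrow accumulator: every partial sum is rounded to float16.
    acc = 0.0
    for v in values:
        acc = to_half(acc + to_half(v))
    return acc

def sum_widened(values):
    # Wide accumulator: inputs are float16, but the running sum stays a
    # Python float (binary64).
    acc = 0.0
    for v in values:
        acc += to_half(v)
    return acc

ones = [1.0] * 4096
print(sum_in_half(ones))  # 2048.0
print(sum_widened(ones))  # 4096.0
```

Once the float16 running sum reaches 2048, adding 1.0 gives 2049, which is exactly halfway between the neighboring float16 values 2048 and 2050 and rounds to the even one, 2048. The narrow accumulator therefore stops growing.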

Args:

  • row_coords (IndexList): The N-dimensional coordinates identifying the row.
  • axis (Int): The axis along which to reduce.
  • init (Scalar): The identity value for the reduction.
  • row_size (Int): The number of elements in the row.
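One way to picture how row_coords, axis, and row_size identify a row is that every coordinate except the one at axis stays fixed, while the coordinate at axis ranges from 0 to row_size - 1. The Python sketch below assumes that traversal. Here, load is a hypothetical stand-in for input_fn that reads a single element rather than a SIMD vector. The sketch also copies row_coords instead of modifying it, even though the Mojo argument is declared mut.

```python
from typing import List, Sequence

def load(tensor, coords: Sequence[int]) -> float:
    # Hypothetical stand-in for input_fn: read one element at N-D coordinates
    # from a nested-list tensor.
    value = tensor
    for c in coords:
        value = value[c]
    return value

def gather_row(tensor, row_coords: List[int], axis: int, row_size: int) -> List[float]:
    # Walk the row by overwriting the coordinate at `axis` in a copy of
    # row_coords; the other coordinates stay fixed and select the row.
    coords = list(row_coords)
    row = []
    for i in range(row_size):
        coords[axis] = i
        row.append(load(tensor, coords))
    return row

# A 2 x 3 tensor.
t = [[1, 2, 3],
     [4, 5, 6]]
print(gather_row(t, [1, 0], axis=1, row_size=3))  # [4, 5, 6]
print(gather_row(t, [0, 2], axis=0, row_size=2))  # [3, 6]
```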

Returns:

Scalar: The reduced result for the row, as a Scalar[accum_type].

row_reduce[BLOCK_SIZE: Int, num_reductions: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], reduce_fn: fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int, rank: Int, accum_type: DType = get_accum_type[dtype]()](mut row_coords: IndexList[rank], axis: Int, init: StaticTuple[Scalar[dtype], num_reductions], row_size: Int) -> StaticTuple[Scalar[accum_type], num_reductions]

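Reduces a row along the given axis, computing num_reductions fused reductions in a single pass. The implementation combines three pieces: threads cooperatively read SIMD-width chunks of the row, the per-thread partial results are combined across the block with a block_reduce, and any tail elements that do not fill a full SIMD vector are handled with scalar reads.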
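To make these three pieces concrete, here is a plain-Python simulation of one thread block reducing a single row with one reduction. It is a model, not the Mojo implementation. Several details are assumptions chosen for illustration: the chunk-to-thread assignment (thread t takes chunks t, t + BLOCK_SIZE, and so on), the pairwise tree standing in for block_reduce, and folding the scalar tail into each thread's partial before the block reduction. Reordering any of these steps gives the same result only for associative reduce_fn.

```python
from typing import Callable, List, Sequence

def simulate_row_reduce(
    row: Sequence[float],
    reduce_fn: Callable[[float, float], float],
    init: float,
    block_size: int,
    simd_width: int,
) -> float:
    row_size = len(row)
    # The largest prefix of the row that splits into whole SIMD-width chunks.
    vec_end = (row_size // simd_width) * simd_width
    num_chunks = vec_end // simd_width

    partials: List[float] = []
    for t in range(block_size):
        # Cooperative SIMD reads: thread t loads chunks t, t + block_size, ...
        vec = [init] * simd_width  # per-thread SIMD accumulator
        for chunk in range(t, num_chunks, block_size):
            base = chunk * simd_width
            for lane in range(simd_width):
                vec[lane] = reduce_fn(vec[lane], row[base + lane])
        # Horizontal reduction of the SIMD accumulator to one scalar.
        acc = vec[0]
        for lane in range(1, simd_width):
            acc = reduce_fn(acc, vec[lane])
        # Scalar tail: elements past vec_end, strided across threads.
        for idx in range(vec_end + t, row_size, block_size):
            acc = reduce_fn(acc, row[idx])
        partials.append(acc)

    # Block reduction: pairwise tree over the per-thread partials.
    stride = 1
    while stride < block_size:
        for t in range(0, block_size, 2 * stride):
            if t + stride < block_size:
                partials[t] = reduce_fn(partials[t], partials[t + stride])
        stride *= 2
    return partials[0]

data = [float(x) for x in range(1, 38)]  # 37 elements: 9 chunks of 4 plus 1 tail element
print(simulate_row_reduce(data, lambda a, b: a + b, 0.0, block_size=4, simd_width=4))  # 703.0
print(simulate_row_reduce(data, max, float("-inf"), block_size=3, simd_width=4))        # 37.0
```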

Parameters:

  • BLOCK_SIZE (Int): The number of threads per block.
  • num_reductions (Int): The number of fused reductions to perform.
  • input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The capturing function that loads a SIMD[dtype, width] vector of input elements at the given coordinates.
  • reduce_fn (fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function parameterized by reduction index.
  • dtype (DType): The data type of the input elements.
  • simd_width (Int): The SIMD vector width.
  • rank (Int): The tensor rank.
  • accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype](), which may select a wider type than dtype for accumulation.
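Because reduce_fn receives reduction_idx as a parameter, a single function can implement a different operator for each fused reduction. The sketch below models this in Python with a hypothetical reduce_fn that computes a sum (index 0) and a max (index 1) in one pass over the row. In Mojo, reduction_idx is a compile-time parameter, whereas here it is an ordinary runtime argument.

```python
import math
from typing import Sequence, Tuple

def reduce_fn(reduction_idx: int, a: float, b: float) -> float:
    # Hypothetical fused reduce_fn: the reduction index selects the operator.
    if reduction_idx == 0:
        return a + b      # running sum
    if reduction_idx == 1:
        return max(a, b)  # running max
    raise ValueError(f"unknown reduction_idx {reduction_idx}")

def fused_row_reduce(row: Sequence[float], init: Tuple[float, ...]) -> Tuple[float, ...]:
    acc = list(init)
    for x in row:                  # one pass over the row...
        for i in range(len(acc)):  # ...updates every fused reduction
            acc[i] = reduce_fn(i, acc[i], x)
    return tuple(acc)

print(fused_row_reduce([3.0, -1.0, 4.0, 1.5], (0.0, -math.inf)))  # (7.5, 4.0)
```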

Args:

  • row_coords (IndexList): The N-dimensional coordinates identifying the row.
  • axis (Int): The axis along which to reduce.
  • init (StaticTuple): The identity value for each reduction.
  • row_size (Int): The number of elements in the row.
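Why init should be a true identity value rather than an arbitrary starting value: the following Python sketch assumes that each thread seeds its own partial accumulator with init before the partials are combined. This is an assumption about the implementation and is not stated above. Under that assumption, init enters the result once per thread. The identity for addition (0) therefore leaves the answer unchanged, while any other value is counted BLOCK_SIZE times.

```python
from typing import Callable, Sequence

def split_reduce(
    row: Sequence[float],
    reduce_fn: Callable[[float, float], float],
    init: float,
    block_size: int,
) -> float:
    # Each simulated thread starts from `init` and folds a strided slice of
    # the row; the per-thread partials are then folded together.
    partials = []
    for t in range(block_size):
        acc = init
        for x in row[t::block_size]:
            acc = reduce_fn(acc, x)
        partials.append(acc)
    result = partials[0]
    for p in partials[1:]:
        result = reduce_fn(result, p)
    return result

def add(a: float, b: float) -> float:
    return a + b

row = [1.0, 2.0, 3.0]
print(split_reduce(row, add, 0.0, block_size=4))   # 6.0: 0 is the identity for +
print(split_reduce(row, add, 10.0, block_size=4))  # 46.0: 10 is added once per thread
```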

Returns:

StaticTuple: The reduced results, one Scalar[accum_type] per fused reduction.
