Mojo function
row_reduce
row_reduce[BLOCK_SIZE: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int, rank: Int, accum_type: DType = get_accum_type[dtype]()](mut row_coords: IndexList[rank], axis: Int, init: Scalar[dtype], row_size: Int) -> Scalar[accum_type]
Reduces a single row along the given axis using block-level cooperative reduction. Delegates to the multi-reduction row_reduce overload with num_reductions=1.
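The delegation can be pictured with a plain-Python model. This is a sketch under stated assumptions, not the Mojo implementation. `row_reduce_multi` and `row_reduce_single` are hypothetical names: the first is a serial stand-in for the multi-reduction overload, and a Python list plays the role of `StaticTuple`.

```python
from typing import Callable, List

# Hypothetical serial stand-in for the multi-reduction overload:
# applies reduce_fn(acc, x, reduction_idx) once per fused reduction.
def row_reduce_multi(row: List[float],
                     reduce_fn: Callable[[float, float, int], float],
                     init: List[float]) -> List[float]:
    acc = list(init)
    for x in row:
        for i in range(len(acc)):
            acc[i] = reduce_fn(acc[i], x, i)
    return acc

# Single-reduction wrapper: wraps init in a one-element tuple,
# ignores reduction_idx, and unwraps the single result.
def row_reduce_single(row: List[float],
                      reduce_fn: Callable[[float, float], float],
                      init: float) -> float:
    results = row_reduce_multi(row, lambda a, b, _idx: reduce_fn(a, b), [init])
    return results[0]

print(row_reduce_single([1.0, 2.0, 3.0, 4.0], lambda a, b: a + b, 0.0))  # 10.0
```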
Parameters:
- BLOCK_SIZE (Int): The number of threads per block.
- input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda that loads input elements.
- reduce_fn (fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function.
- dtype (DType): The data type of the input elements.
- simd_width (Int): The SIMD vector width.
- rank (Int): The tensor rank.
- accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype](), a widened accumulation type for dtype.
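Widening the accumulator matters because a narrow accumulator stops absorbing small increments once its value grows large. The sketch below emulates this with Python's `struct` half- and single-precision packing: it sums 3000 ones with a float16 accumulator and with a float32 one. It assumes, without confirmation from this page, that get_accum_type maps float16 inputs to a float32 accumulator; `to_f16`, `to_f32`, and `sum_with_accumulator` are hypothetical helpers.

```python
import struct

def to_f16(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_f32(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 single-precision value.
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sum_with_accumulator(values, round_fn):
    # Round the running sum to the accumulator's precision after every add.
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc

ones = [1.0] * 3000
print(sum_with_accumulator(ones, to_f16))  # 2048.0: 2048 + 1 rounds back to 2048
print(sum_with_accumulator(ones, to_f32))  # 3000.0
```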
Args:
- row_coords (IndexList): The ND coordinates identifying the row.
- axis (Int): The axis along which to reduce.
- init (Scalar): The identity value for the reduction.
- row_size (Int): The number of elements in the row.
Returns:
Scalar[accum_type]: The reduced result for the row.
row_reduce[BLOCK_SIZE: Int, num_reductions: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], reduce_fn: fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int, rank: Int, accum_type: DType = get_accum_type[dtype]()](mut row_coords: IndexList[rank], axis: Int, init: StaticTuple[Scalar[dtype], num_reductions], row_size: Int) -> StaticTuple[Scalar[accum_type], num_reductions]
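Reduces a row along the given axis, computing multiple fused reductions in a single pass. Threads cooperatively read SIMD-width chunks of the row, a block_reduce combines the per-thread partial results, and any remaining tail elements are handled with scalar reads.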
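The three phases can be modeled with a sequential Python simulation. This is a sketch, not the Mojo source: `simulate_row_reduce` is a hypothetical name, it assumes reduce_fn is associative and commutative with init as its identity, and the thread indexing and the order of tail handling after the block reduction are assumptions that follow the description above. The list `coords` is omitted; elements are read directly by their position along the axis.

```python
from typing import Callable, List

def simulate_row_reduce(row: List[float], block_size: int, simd_width: int,
                        reduce_fn: Callable[[float, float], float],
                        init: float) -> float:
    row_size = len(row)
    num_tail = row_size % simd_width
    rounded = row_size - num_tail  # elements covered by full SIMD chunks

    # Phase 1: cooperative SIMD-width reads. Thread t reads the chunk at
    # offset + t * simd_width; the block advances by block_size * simd_width.
    thread_accum = [[init] * simd_width for _ in range(block_size)]
    for offset in range(0, rounded, block_size * simd_width):
        for t in range(block_size):
            start = offset + t * simd_width
            if start >= rounded:
                break  # this thread and all later ones are past the end
            lanes = thread_accum[t]
            for lane in range(simd_width):
                lanes[lane] = reduce_fn(lanes[lane], row[start + lane])

    # Phase 2: block reduction. Collapse each thread's lanes to a scalar,
    # then combine across threads (sequentially here, for clarity).
    result = init
    for lanes in thread_accum:
        per_thread = init
        for v in lanes:
            per_thread = reduce_fn(per_thread, v)
        result = reduce_fn(result, per_thread)

    # Phase 3: scalar tail. Fold in the row_size % simd_width leftovers.
    for i in range(rounded, row_size):
        result = reduce_fn(result, row[i])
    return result

row = [float(i) for i in range(37)]  # 36 elements in chunks, 1 in the tail
print(simulate_row_reduce(row, block_size=4, simd_width=4,
                          reduce_fn=lambda a, b: a + b, init=0.0))  # 666.0
```

With 37 elements and a SIMD width of 4, the first 36 elements are read as chunks by the four simulated threads and the last one is folded in by the scalar tail.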
Parameters:
- BLOCK_SIZE (Int): The number of threads per block.
- num_reductions (Int): The number of fused reductions to perform.
- input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda that loads input elements.
- reduce_fn (fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function, parameterized by the reduction index.
- dtype (DType): The data type of the input elements.
- simd_width (Int): The SIMD vector width.
- rank (Int): The tensor rank.
- accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype](), a widened accumulation type for dtype.
Args:
- row_coords (IndexList): The ND coordinates identifying the row.
- axis (Int): The axis along which to reduce.
- init (StaticTuple): The identity values, one per reduction.
- row_size (Int): The number of elements in the row.
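A reduce_fn parameterized by reduction_idx lets one pass over the row feed several accumulators, each paired with its own identity in init. The Python sketch below fuses a maximum (identity negative infinity) and a sum (identity 0). The names `max_and_sum`, `INIT`, and `fused_reduce` are hypothetical; in Mojo, reduction_idx is a compile-time parameter, whereas here it is an ordinary int argument.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical reduce_fn for two fused reductions.
def max_and_sum(a: float, b: float, reduction_idx: int) -> float:
    if reduction_idx == 0:
        return max(a, b)  # reduction 0: running maximum
    return a + b          # reduction 1: running sum

# Identity values, one per reduction, in the role of `init`.
INIT: Tuple[float, float] = (-math.inf, 0.0)

def fused_reduce(row: List[float],
                 reduce_fn: Callable[[float, float, int], float],
                 init: Tuple[float, ...]) -> Tuple[float, ...]:
    # Each element is read once and applied to every accumulator.
    acc = list(init)
    for x in row:
        for i in range(len(acc)):
            acc[i] = reduce_fn(acc[i], x, i)
    return tuple(acc)

print(fused_reduce([2.0, -5.0, 9.0, 1.0], max_and_sum, INIT))  # (9.0, 7.0)
```

Using the true identity for each reduction means threads or lanes that read no elements leave the combined result unchanged.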
Returns:
StaticTuple[Scalar[accum_type], num_reductions]: The reduced results, one per fused reduction.