Mojo function

reduce_kernel

reduce_kernel[rank: Int, axis: Int, num_reductions: Int, BLOCK_SIZE: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], dtype: DType, simd_width: Int, accum_type: DType = get_accum_type[dtype]()](shape: IndexList[rank], init: StaticTuple[Scalar[dtype], num_reductions])

GPU kernel that reduces a tensor along the given axis. Each block reduces one row at a time using row_reduce and writes the result via output_fn; a grid-stride loop handles cases where there are more rows than blocks.

Parameters:

  • rank (Int): The tensor rank.
  • axis (Int): The axis along which to reduce.
  • num_reductions (Int): The number of fused reductions to perform.
  • BLOCK_SIZE (Int): The number of threads per block.
  • input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
  • output_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
  • reduce_fn (fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
  • dtype (DType): The data type of the elements.
  • simd_width (Int): The SIMD vector width.
  • accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype]().

Args:

  • shape (IndexList): The shape of the input tensor.
  • init (StaticTuple): The identity values for each reduction.
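To illustrate what the kernel computes, the sketch below is a CPU analogue in NumPy, not the Mojo implementation: it reduces a tensor along `axis`, applying `num_reductions` fused binary reduction functions in a single pass, with each accumulator seeded by its identity value from `init`. The function name `reduce_along_axis` and the `reduce_fns` list are hypothetical names chosen to mirror the kernel's parameters.

```python
import numpy as np

def reduce_along_axis(x, axis, reduce_fns, init):
    # Output shape drops the reduced axis, as the kernel's output_fn index does.
    out_shape = x.shape[:axis] + x.shape[axis + 1:]
    # One accumulator tensor per fused reduction, seeded with its identity
    # (the role of `init` in reduce_kernel).
    accums = [np.full(out_shape, ident, dtype=x.dtype) for ident in init]
    # Walk the reduced axis once; each step folds one slice into every
    # accumulator, mirroring how a block folds a row element by element.
    for i in range(x.shape[axis]):
        row = np.take(x, i, axis=axis)
        for r, fn in enumerate(reduce_fns):
            accums[r] = fn(accums[r], row)
    return accums

x = np.arange(6, dtype=np.float32).reshape(2, 3)
# Two fused reductions along axis 1: sum (identity 0) and max (identity -inf).
sums, maxes = reduce_along_axis(
    x, axis=1,
    reduce_fns=[np.add, np.maximum],
    init=[0.0, -np.inf],
)
```

Fusing multiple reductions this way lets one pass over the data (one kernel launch, one read of each element) produce several results, which is why the kernel takes a StaticTuple of identities rather than a single scalar.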