Mojo function

block_reduce

block_reduce[BLOCK_SIZE: Int, reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int](val: SIMD[dtype, simd_width], init: Scalar[dtype]) -> Scalar[dtype]

Performs a block-level reduction of a single SIMD value across all threads in a GPU thread block using warp-level primitives and shared memory.

Parameters:

  • BLOCK_SIZE (Int): The number of threads per block.
  • reduce_fn (fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function.
  • dtype (DType): The data type of the elements.
  • simd_width (Int): The SIMD vector width.

Args:

  • val (SIMD): The per-thread SIMD value to reduce.
  • init (Scalar): The identity value for the reduction (for example, 0 for a sum or 1 for a product).

Returns:

Scalar: The reduced scalar result (valid on thread 0).
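The two-stage scheme described above (a warp-level reduction, then a combine step through shared memory) can be illustrated with a host-side simulation. The sketch below is a hypothetical plain-Mojo model and does not call `block_reduce` or any GPU API. It makes several assumptions: a 32-lane warp, a 128-thread block, `simd_width = 1`, and a sum as the reduction. It models each warp shuffle as a shift-down tree over a `List`. Padding the unused second-stage lanes with `init` is an assumption about how such reductions are typically structured, not a description of this function's internals.

```mojo
# Hypothetical CPU-side simulation of a two-stage block sum reduction.
# It does not call block_reduce or any GPU API.

fn main():
    var block_size = 128  # assumed BLOCK_SIZE
    var warp_size = 32    # assumed warp width
    var num_warps = block_size // warp_size

    # Per-thread inputs: thread t contributes t + 1.
    var vals = List[Float32]()
    for t in range(block_size):
        vals.append(Float32(t + 1))

    # Stage 1: each warp runs a shift-down tree over its own lanes.
    # Each step mirrors one shuffle: lane i adds the value held by lane i + offset.
    var partials = List[Float32]()
    for w in range(num_warps):
        var base = w * warp_size
        var offset = warp_size // 2
        while offset > 0:
            for lane in range(offset):
                vals[base + lane] = vals[base + lane] + vals[base + lane + offset]
            offset //= 2
        # Lane 0 of each warp writes its partial result to "shared memory".
        partials.append(vals[base])

    # Stage 2: one warp reduces the per-warp partials.
    # Lanes with no partial are filled with the identity value (init).
    var init: Float32 = 0.0
    while len(partials) < warp_size:
        partials.append(init)
    var step = warp_size // 2
    while step > 0:
        for lane in range(step):
            partials[lane] = partials[lane] + partials[lane + step]
        step //= 2

    # Only "thread 0" (index 0) holds the full block result.
    print(partials[0])
```

With inputs 1 through 128, the printed value should equal 8256 (the sum 1 + 2 + ... + 128). The padding step shows why `init` must be the identity: any other value would leak into the final result.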

block_reduce[BLOCK_SIZE: Int, num_reductions: Int, reduce_fn: fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], dtype: DType, simd_width: Int](val: StaticTuple[SIMD[dtype, simd_width], num_reductions], init: StaticTuple[Scalar[dtype], num_reductions]) -> StaticTuple[Scalar[dtype], num_reductions]

Performs multiple fused block-level reductions, one per SIMD value, across all threads in a GPU thread block using warp shuffles and shared memory.

Parameters:

  • BLOCK_SIZE (Int): The number of threads per block.
  • num_reductions (Int): The number of fused reductions to perform.
  • reduce_fn (fn[dtype: DType, width: Int, reduction_idx: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]): The binary reduction function, parameterized by reduction index.
  • dtype (DType): The data type of the elements.
  • simd_width (Int): The SIMD vector width.

Args:

  • val (StaticTuple): The per-thread SIMD values to reduce, one per reduction.
  • init (StaticTuple): The identity value for each reduction.

Returns:

StaticTuple: The reduced scalar results (valid on thread 0).
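A fused reduction lets several reductions share the same pass through the warp and shared-memory stages. The hypothetical sketch below models this in plain Mojo on the host and does not call `block_reduce`. In it, `combine[reduction_idx]` stands in for a `reduce_fn` that is parameterized by reduction index: index 0 is a sum and index 1 is a max, and each tree step updates both reductions. It assumes a 32-lane warp, a 64-thread block and `simd_width = 1`. It also uses `-1.0e30` as a stand-in for negative infinity as the max identity.

```mojo
# Hypothetical CPU simulation of a fused two-way block reduction
# (index 0 = sum, index 1 = max). It does not call block_reduce.

fn combine[reduction_idx: Int](a: Float32, b: Float32) -> Float32:
    # Selects the binary operator at compile time, like a reduce_fn
    # parameterized by reduction_idx.
    @parameter
    if reduction_idx == 0:
        return a + b
    else:
        if a > b:
            return a
        return b


fn main():
    var block_size = 64  # assumed BLOCK_SIZE
    var warp_size = 32   # assumed warp width
    var num_warps = block_size // warp_size

    # One identity per reduction: 0 for sum, a large negative
    # sentinel (standing in for -inf) for max.
    var init_sum: Float32 = 0.0
    var init_max: Float32 = -1.0e30

    # Both reductions start from the same per-thread input: thread t holds t.
    var sums = List[Float32]()
    var maxes = List[Float32]()
    for t in range(block_size):
        sums.append(Float32(t))
        maxes.append(Float32(t))

    # Stage 1: one shift-down tree per warp; each step updates both reductions.
    var partial_sums = List[Float32]()
    var partial_maxes = List[Float32]()
    for w in range(num_warps):
        var base = w * warp_size
        var offset = warp_size // 2
        while offset > 0:
            for lane in range(offset):
                var i = base + lane
                sums[i] = combine[0](sums[i], sums[i + offset])
                maxes[i] = combine[1](maxes[i], maxes[i + offset])
            offset //= 2
        partial_sums.append(sums[base])
        partial_maxes.append(maxes[base])

    # Stage 2: pad the per-warp partials to a full warp with each identity,
    # then reduce them with a second tree.
    while len(partial_sums) < warp_size:
        partial_sums.append(init_sum)
        partial_maxes.append(init_max)
    var step = warp_size // 2
    while step > 0:
        for lane in range(step):
            partial_sums[lane] = combine[0](partial_sums[lane], partial_sums[lane + step])
            partial_maxes[lane] = combine[1](partial_maxes[lane], partial_maxes[lane + step])
        step //= 2

    # Index 0 plays the role of thread 0, which holds both results.
    print(partial_sums[0], partial_maxes[0])
```

With thread t holding t, the two printed values should be 2016 (0 + 1 + ... + 63) and 63. Each reduction needs its own identity: padding the max reduction with 0 would happen to work here only because every input is non-negative.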
