Mojo function

reduce_launch

reduce_launch[
    num_reductions: Int,
    input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width],
    output_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None,
    reduce_fn: fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width],
    rank: Int,
    dtype: DType
](
    shape: IndexList[rank],
    axis: Int,
    init: StaticTuple[Scalar[dtype], num_reductions],
    ctx: DeviceContext
)

Selects and launches the appropriate GPU reduction kernel based on the tensor shape, the reduction axis, and the device saturation level. It dispatches to saturated_reduce_kernel when the axis is non-contiguous and there are enough rows, to small_reduce_kernel when rows are smaller than the warp size, and to reduce_kernel otherwise.
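
The dispatch order described above can be pictured with a small standalone sketch. It is not the library's implementation: pick_kernel, saturation_rows, and row_length are hypothetical names, the saturation threshold is made up, and reading "rows smaller than the warp size" as the number of elements reduced per row is an assumption. Syntax may need adjusting for your Mojo version.

```mojo
# Illustrative only: mirrors the order of checks in the description.
# `saturation_rows` and the interpretation of `row_length` are assumptions,
# not values or logic taken from the library.
fn pick_kernel(
    axis_is_contiguous: Bool,
    num_rows: Int,
    row_length: Int,
    saturation_rows: Int,
    warp_size: Int,
) -> String:
    # Non-contiguous axis with enough rows to keep the device busy.
    if not axis_is_contiguous and num_rows >= saturation_rows:
        return "saturated_reduce_kernel"
    # Each row is shorter than one warp.
    if row_length < warp_size:
        return "small_reduce_kernel"
    # General case.
    return "reduce_kernel"


def main():
    # Hypothetical numbers: 32-lane warps, 4096 rows assumed to saturate.
    print(pick_kernel(False, 8192, 256, 4096, 32))
    print(pick_kernel(True, 8192, 16, 4096, 32))
    print(pick_kernel(True, 8192, 256, 4096, 32))
```

With these inputs the three calls take the first, second, and third branch respectively, so each kernel name is selected once.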

Parameters:

  • num_reductions (Int): The number of fused reductions to perform.
  • input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
  • output_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
  • reduce_fn (fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
  • rank (Int): The tensor rank.
  • dtype (DType): The data type of the elements.
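
As a rough illustration of how reduction_idx lets one reduce_fn serve several fused reductions, the sketch below combines a sum (index 0) and a maximum (index 1). The name fused_reduce is hypothetical, and the function is written as a plain top-level fn so it can run on its own; a function passed to reduce_launch would have to match the capturing signature listed above.

```mojo
# Hypothetical fused reduction: index 0 sums, any other index takes the maximum.
fn fused_reduce[
    ty: DType, width: Int, reduction_idx: Int
](lhs: SIMD[ty, width], rhs: SIMD[ty, width]) -> SIMD[ty, width]:
    # The branch is resolved at compile time from the reduction index.
    @parameter
    if reduction_idx == 0:
        return lhs + rhs
    else:
        return max(lhs, rhs)


def main():
    var a = SIMD[DType.float32, 4](1, 2, 3, 4)
    var b = SIMD[DType.float32, 4](4, 3, 2, 1)
    # Elementwise sum of a and b.
    print(fused_reduce[DType.float32, 4, 0](a, b))
    # Elementwise maximum of a and b.
    print(fused_reduce[DType.float32, 4, 1](a, b))
```

For a pairing like this, the matching init values would presumably be 0 for the sum and the smallest representable value for the maximum, since init holds the identity of each reduction.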

Args:

  • shape (IndexList): The shape of the input tensor.
  • axis (Int): The axis along which to reduce.
  • init (StaticTuple): The identity values for each reduction.
  • ctx (DeviceContext): The device context for GPU execution.

Raises:

If the GPU kernel launch fails.