Mojo function
reduce_launch
reduce_launch[num_reductions: Int, input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], rank: Int, dtype: DType](shape: IndexList[rank], axis: Int, init: StaticTuple[Scalar[dtype], num_reductions], ctx: DeviceContext)
Selects and launches the appropriate GPU reduction kernel based on the tensor shape, reduction axis, and device saturation level. Dispatches to saturated_reduce_kernel when the reduction axis is non-contiguous and there are enough rows to saturate the device, to small_reduce_kernel when rows are smaller than the warp size, and to reduce_kernel otherwise.
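The dispatch rule above can be sketched in plain Python. This is an illustrative model only, not the library's code: the warp size of 32 and the row-count threshold for "saturating" the device are assumed values, and pick_kernel is a hypothetical helper.

```python
WARP_SIZE = 32  # assumed warp width; actual value is device-dependent

def pick_kernel(shape, axis, min_saturating_rows=128):
    """Model of the three-way kernel dispatch described above.
    min_saturating_rows is an illustrative threshold, not from the source."""
    rank = len(shape)
    row_size = shape[axis]          # elements folded per output element
    num_rows = 1                    # number of independent output rows
    for i, d in enumerate(shape):
        if i != axis:
            num_rows *= d
    axis_is_inner = (axis == rank - 1)  # innermost axis is the contiguous one
    if not axis_is_inner and num_rows >= min_saturating_rows:
        return "saturated_reduce_kernel"
    if row_size < WARP_SIZE:
        return "small_reduce_kernel"
    return "reduce_kernel"
```

For example, reducing axis 0 of a (16, 4096) tensor leaves 4096 independent rows along a non-contiguous axis, which under these assumed thresholds selects the saturated kernel.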
Parameters:
- num_reductions (Int): The number of fused reductions to perform.
- input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
- output_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
- reduce_fn (fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
- rank (Int): The tensor rank.
- dtype (DType): The data type of the elements.
Args:
- shape (IndexList): The shape of the input tensor.
- axis (Int): The axis along which to reduce.
- init (StaticTuple): The identity values for each reduction.
- ctx (DeviceContext): The device context for GPU execution.
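The result these arguments describe can be captured by a small reference model in plain Python, independent of the GPU execution path: for every output coordinate (the input coordinate with the reduced axis dropped), each of the num_reductions accumulators starts at its identity value from init and is folded with its reduce_fn over the reduced axis. The function and argument names below are illustrative, not part of the library's API.

```python
from itertools import product

def reduce_ref(load, shape, axis, init, reduce_fns):
    """Reference semantics of a fused axis reduction.
    load(idx) plays the role of input_fn; init holds one identity
    value per reduction; reduce_fns holds one binary fold per reduction."""
    out_shape = tuple(d for i, d in enumerate(shape) if i != axis)
    out = {}
    for out_idx in product(*(range(d) for d in out_shape)):
        accs = list(init)  # start each accumulator at its identity
        for k in range(shape[axis]):
            # Re-insert the reduced coordinate to form the full input index.
            idx = out_idx[:axis] + (k,) + out_idx[axis:]
            x = load(idx)
            for r, f in enumerate(reduce_fns):
                accs[r] = f(accs[r], x)
        out[out_idx] = tuple(accs)
    return out

# Usage: fuse a sum and a max over axis 1 of a 2x3 input.
data = [[1, 5, 2], [4, 0, 3]]
res = reduce_ref(lambda i: data[i[0]][i[1]], (2, 3), 1,
                 (0, float("-inf")), [lambda a, b: a + b, max])
# res[(0,)] is (8, 5); res[(1,)] is (7, 4)
```

Fusing both reductions in one pass, as the num_reductions parameter allows, means the input is loaded once and fed to every accumulator, rather than once per reduction.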
Raises:
If the GPU kernel launch fails.