Mojo function
saturated_reduce_kernel
```mojo
saturated_reduce_kernel[
    rank: Int,
    axis: Int,
    num_reductions: Int,
    BLOCK_SIZE: Int,
    input_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width],
    output_fn: fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None,
    reduce_fn: fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width],
    dtype: DType,
    simd_width: Int,
    accum_type: DType = get_accum_type[dtype](),
](shape: IndexList[rank], init: StaticTuple[Scalar[dtype], num_reductions])
```
GPU kernel for reductions when the device is saturated with enough rows. Each thread independently reduces an entire row using SIMD packing, avoiding shared-memory synchronization entirely. Used when reducing along a non-contiguous axis.
Parameters:

- rank (Int): The tensor rank.
- axis (Int): The axis along which to reduce.
- num_reductions (Int): The number of fused reductions to perform.
- BLOCK_SIZE (Int): The number of threads per block.
- input_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
- output_fn (fn[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
- reduce_fn (fn[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
- dtype (DType): The data type of the elements.
- simd_width (Int): The SIMD vector width.
- accum_type (DType): The accumulator data type. Defaults to get_accum_type[dtype]().
Args:

- shape (IndexList): The shape of the input tensor.
- init (StaticTuple): The identity values for each reduction.
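The reduction strategy described above (one thread per row, SIMD-width partial accumulators, no shared-memory synchronization) can be sketched in NumPy. This is an illustrative model, not the kernel itself: the function name, the `simd_width` lane simulation, and the host-side loop standing in for GPU threads are all assumptions for exposition.

```python
import numpy as np

def saturated_reduce(x, axis, reduce_fn, init, simd_width=4):
    """Sketch of the saturated-reduction strategy: move the reduced axis
    last, then let each simulated 'thread' fold one whole row on its own,
    accumulating simd_width partial lanes before a horizontal reduction.
    (Hypothetical helper; the real kernel runs one GPU thread per row.)"""
    moved = np.moveaxis(x, axis, -1)
    rows = moved.reshape(-1, moved.shape[-1])
    out = np.empty(rows.shape[0], dtype=x.dtype)
    for t, row in enumerate(rows):          # one "thread" per row
        # lanes play the role of the SIMD accumulator, seeded with the
        # reduction's identity value (the `init` argument).
        lanes = np.full(simd_width, init, dtype=x.dtype)
        n = (len(row) // simd_width) * simd_width
        for i in range(0, n, simd_width):   # vectorized main loop
            lanes = reduce_fn(lanes, row[i:i + simd_width])
        acc = lanes[0]
        for lane in lanes[1:]:              # horizontal lane reduction
            acc = reduce_fn(acc, lane)
        for v in row[n:]:                   # scalar tail elements
            acc = reduce_fn(acc, v)
        out[t] = acc
    return out.reshape(moved.shape[:-1])
```

Because each row is reduced independently, no inter-thread synchronization is needed; the scheme pays off only when there are enough rows to saturate the device, which is why the kernel is selected in that regime.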