Mojo function
twophase_reduce_kernel
twophase_reduce_kernel[rank: Int, axis: Int, num_reductions: Int, BLOCK_SIZE: Int, input_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], dtype: DType, simd_width: Int, accum_type: DType = get_accum_type[dtype]()](shape: IndexList[rank], init: StaticTuple[Scalar[dtype], num_reductions], partials: UnsafePointer[Scalar[accum_type], MutAnyOrigin], counters: UnsafePointer[Int32, MutAnyOrigin], blocks_per_row: Int)
GPU kernel for reductions when there are too few rows to saturate the device with one block per row. It assigns multiple blocks to each row and uses a two-phase approach. In phase one, each block reduces its chunk of the row with a cooperative block-level reduction and writes the result to the partials buffer. In phase two, the last block to finish for a row (detected with a per-row atomic counter) combines the partial results of all blocks assigned to that row to produce the row's final result.
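The two-phase scheme can be sketched as a sequential Python simulation. This is an illustrative host-side model, not the Mojo kernel. It makes several assumptions. First, grid_dim.x = num_rows * blocks_per_row, with blocks laid out row by row. Second, each row is split into equal contiguous chunks; the kernel's actual chunking is not documented here. Third, only a single reduction is performed. The block-level reduction and the atomic counter are replaced with plain Python, and the helper names block_reduce and twophase_reduce are hypothetical.

```python
import random


def block_reduce(values, init, reduce_fn):
    """Sequential stand-in for a cooperative block-level reduction."""
    acc = init
    for v in values:
        acc = reduce_fn(acc, v)
    return acc


def twophase_reduce(rows, blocks_per_row, init, reduce_fn, seed=0):
    """Hypothetical simulation of the two-phase, last-block-finishes scheme."""
    num_rows = len(rows)
    grid_dim_x = num_rows * blocks_per_row  # assumed launch layout
    partials = [init] * grid_dim_x          # like `partials`, one reduction
    counters = [0] * num_rows               # like `counters`, zero-initialized
    out = [None] * num_rows

    # Blocks may finish in any order; shuffle to mimic that.
    order = list(range(grid_dim_x))
    random.Random(seed).shuffle(order)

    for block_id in order:
        row, chunk = divmod(block_id, blocks_per_row)
        row_len = len(rows[row])
        chunk_len = -(-row_len // blocks_per_row)  # ceiling division
        start = chunk * chunk_len
        stop = min(start + chunk_len, row_len)     # may yield an empty chunk

        # Phase 1: reduce this block's chunk and publish the partial result.
        partials[block_id] = block_reduce(rows[row][start:stop], init, reduce_fn)

        # Stand-in for an atomic fetch-add: `prev` is the counter's old value.
        prev = counters[row]
        counters[row] += 1

        # Phase 2: only the block that observes prev == blocks_per_row - 1
        # is last for this row, so it combines every partial of the row.
        if prev == blocks_per_row - 1:
            base = row * blocks_per_row
            out[row] = block_reduce(partials[base:base + blocks_per_row], init, reduce_fn)
    return out


rows = [list(range(1, 11)), [5, -3, 7]]
print(twophase_reduce(rows, 4, 0, lambda a, b: a + b))  # [55, 9]
print(twophase_reduce(rows, 4, float("-inf"), max))      # [10, 7]
```

The shuffled order shows why the counter is needed: no block knows in advance whether it will finish last, but exactly one block per row sees the final counter value. On a real GPU, this pattern typically also requires a memory fence between writing the partial and incrementing the counter, so that the last block observes every other block's partial.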
Parameters:
- rank (Int): The tensor rank.
- axis (Int): The axis along which to reduce.
- num_reductions (Int): The number of fused reductions to perform.
- BLOCK_SIZE (Int): The number of threads per block.
- input_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
- output_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
- reduce_fn (def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
- dtype (DType): The data type of the elements.
- simd_width (Int): The SIMD vector width.
- accum_type (DType): The accumulator data type.
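The num_reductions, reduce_fn, and init parameters work together: reduce_fn receives a reduction_idx that selects which of the fused reductions to apply, and init supplies one identity value per reduction. The following Python sketch illustrates that pairing, assuming a hypothetical fused sum and max. In the Mojo signature, reduction_idx is a compile-time parameter; here it is an ordinary runtime argument.

```python
import math

# Hypothetical fused pair: reduction 0 is a sum, reduction 1 is a max.
INIT = (0.0, -math.inf)  # one identity per reduction, like `init`


def reduce_fn(reduction_idx, a, b):
    """Binary combine step selected by reduction_idx."""
    if reduction_idx == 0:
        return a + b
    if reduction_idx == 1:
        return max(a, b)
    raise ValueError(f"unknown reduction_idx: {reduction_idx}")


def fused_reduce(values, init=INIT):
    """One pass over the input updates every accumulator."""
    acc = list(init)
    for v in values:
        for i in range(len(acc)):
            acc[i] = reduce_fn(i, acc[i], v)
    return tuple(acc)  # analogous to the StaticTuple passed to output_fn


print(fused_reduce([3.0, -1.0, 4.0]))  # (6.0, 4.0)
```

Because each identity leaves its reduction unchanged, blocks with empty chunks can safely contribute their initial values as partials.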
Args:
- shape (IndexList): The shape of the input tensor.
- init (StaticTuple): The identity values for each reduction.
- partials (UnsafePointer): Global memory buffer for per-block partial results. Size: grid_dim.x * num_reductions elements of accum_type.
- counters (UnsafePointer): Global memory buffer for per-row atomic completion counters. Size: num_rows elements of int32, zero-initialized.
- blocks_per_row (Int): The number of blocks assigned to each row.
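The buffer sizes follow from these arguments. This sketch uses a hypothetical host-side helper and assumes the launch uses grid_dim.x = num_rows * blocks_per_row, which is consistent with blocks_per_row blocks per row but is not stated explicitly above. The accum_bytes argument is the size of one accum_type element (for example, 4 for float32).

```python
def twophase_buffer_sizes(num_rows, blocks_per_row, num_reductions, accum_bytes):
    """Hypothetical host-side helper sizing the kernel's scratch buffers."""
    grid_dim_x = num_rows * blocks_per_row       # assumed launch layout
    partials_len = grid_dim_x * num_reductions   # elements of accum_type
    partials_bytes = partials_len * accum_bytes
    counters_len = num_rows                      # int32 counters, zeroed before launch
    return grid_dim_x, partials_len, partials_bytes, counters_len


# 4 rows, 32 blocks per row, 2 fused reductions, float32 accumulators (4 bytes).
print(twophase_buffer_sizes(4, 32, 2, 4))  # (128, 256, 1024, 4)
```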