
Mojo function

twophase_reduce_kernel

twophase_reduce_kernel[rank: Int, axis: Int, num_reductions: Int, BLOCK_SIZE: Int, input_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], dtype: DType, simd_width: Int, accum_type: DType = get_accum_type[dtype]()](shape: IndexList[rank], init: StaticTuple[Scalar[dtype], num_reductions], partials: UnsafePointer[Scalar[accum_type], MutAnyOrigin], counters: UnsafePointer[Int32, MutAnyOrigin], blocks_per_row: Int)

GPU kernel for reductions when there are too few rows to saturate the device at one block per row. Assigns multiple blocks per row and uses a two-phase approach: each block reduces a chunk via cooperative block-level reduction, then the last block to finish (detected via a per-row atomic counter) reduces all partial results for its row.
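The two-phase scheme can be illustrated with a small host-side simulation. This is a sketch in Python, not the Mojo kernel or its API. It assumes the reduction runs along the last axis of a 2D input, that each block handles a contiguous chunk of its row (ceil-divided), and that partial results are laid out as `row * blocks_per_row + block_in_row`. All three are illustrative assumptions. Blocks are processed in shuffled order to show that the "last block to finish" check on the per-row counter produces the same result regardless of completion order.

```python
import random
from typing import Callable, List


def twophase_reduce_sim(rows: List[List[float]], blocks_per_row: int,
                        reduce_fn: Callable[[float, float], float],
                        init: float, seed: int = 0) -> List[float]:
    """Simulate a two-phase per-row reduction (hypothetical helper, not a Mojo API)."""
    num_rows = len(rows)
    row_len = len(rows[0]) if rows else 0
    chunk = -(-row_len // blocks_per_row)  # ceil division: elements per block (assumed split)
    partials = [init] * (num_rows * blocks_per_row)  # one partial per block
    counters = [0] * num_rows  # zero-initialized, one completion counter per row
    out = [init] * num_rows

    # Each "block" is (row, block_in_row); shuffle to mimic arbitrary completion order.
    blocks = [(r, b) for r in range(num_rows) for b in range(blocks_per_row)]
    random.Random(seed).shuffle(blocks)

    for row, b in blocks:
        # Phase 1: this block reduces its own chunk of the row.
        acc = init
        for x in rows[row][b * chunk:(b + 1) * chunk]:
            acc = reduce_fn(acc, x)
        partials[row * blocks_per_row + b] = acc

        # Simulated atomic fetch-and-add: read the previous value, then increment.
        # (On a real GPU, a memory fence is typically needed between the partial
        # write and the increment so the last block sees every partial.)
        prev = counters[row]
        counters[row] += 1

        # Phase 2: only the block that observes the final count reduces the partials.
        if prev == blocks_per_row - 1:
            total = init
            for p in partials[row * blocks_per_row:(row + 1) * blocks_per_row]:
                total = reduce_fn(total, p)
            out[row] = total
    return out
```

For example, `twophase_reduce_sim([[1, 2, 3, 4, 5]], 3, lambda a, b: a + b, 0)` splits the row into chunks of two elements across three blocks and returns `[15]`. Only one block per row executes the second phase, so no extra synchronization across blocks is needed beyond the counter.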

Parameters:

  • rank (Int): The tensor rank.
  • axis (Int): The axis along which to reduce.
  • num_reductions (Int): The number of fused reductions to perform.
  • BLOCK_SIZE (Int): The number of threads per block.
  • input_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
  • output_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
  • reduce_fn (def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
  • dtype (DType): The data type of the elements.
  • simd_width (Int): The SIMD vector width.
  • accum_type (DType): The accumulator data type.
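Because `reduce_fn` receives `reduction_idx` as a parameter, a single pass over the input can carry `num_reductions` independent accumulators, each seeded from the matching entry of `init`. The Python sketch below illustrates that fused-reduction idea with two hypothetical reductions: index 0 computes a max and index 1 computes a sum. The names and the choice of operations are assumptions for illustration. They do not reflect any particular caller of this kernel.

```python
import math
from typing import Sequence, Tuple


def reduce_fn(reduction_idx: int, a: float, b: float) -> float:
    # Hypothetical dispatch on reduction_idx, mirroring the kernel's reduce_fn parameter.
    if reduction_idx == 0:
        return max(a, b)  # reduction 0: max
    return a + b          # reduction 1: sum


def fused_reduce(row: Sequence[float], init: Tuple[float, ...]) -> Tuple[float, ...]:
    """Apply len(init) reductions to `row` in a single pass (illustrative only)."""
    acc = list(init)  # one accumulator per reduction, seeded with its identity value
    for x in row:
        for r in range(len(acc)):
            acc[r] = reduce_fn(r, acc[r], x)
    return tuple(acc)
```

With identities `(-math.inf, 0.0)`, `fused_reduce([3.0, -1.0, 4.0, 1.0], (-math.inf, 0.0))` yields the max and the sum together, analogous to how `output_fn` receives a `StaticTuple` holding one result per reduction.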

Args:

  • shape (IndexList): The shape of the input tensor.
  • init (StaticTuple): The identity values for each reduction.
  • partials (UnsafePointer): Global memory buffer for per-block partial results. Size: grid_dim.x * num_reductions elements of accum_type.
  • counters (UnsafePointer): Global memory buffer for per-row atomic completion counters. Size: num_rows elements of Int32, which must be zero-initialized.
  • blocks_per_row (Int): The number of blocks assigned to each row.
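The buffer sizes above follow from the shape, the reduction axis, and `blocks_per_row`. The helper below is a hypothetical sketch for computing them on the host. It treats every index over the non-reduced dimensions as one row, and it assumes the launch uses `grid_dim.x = num_rows * blocks_per_row`. That grid layout is an assumption, since this page specifies only the sizes in terms of `grid_dim.x` and `num_rows`.

```python
from math import prod
from typing import Dict, Sequence


def twophase_buffer_sizes(shape: Sequence[int], axis: int,
                          num_reductions: int, blocks_per_row: int) -> Dict[str, int]:
    """Element counts for the partials and counters buffers (hypothetical helper)."""
    if not 0 <= axis < len(shape):
        raise ValueError("axis out of range")
    # Every combination of the non-reduced dimensions is one row.
    num_rows = prod(d for i, d in enumerate(shape) if i != axis)
    # Assumption: one grid block per (row, block_in_row) pair.
    grid_dim_x = num_rows * blocks_per_row
    return {
        "num_rows": num_rows,
        "grid_dim_x": grid_dim_x,
        "partials_elems": grid_dim_x * num_reductions,  # accum_type elements
        "counters_elems": num_rows,                      # Int32 elements, zeroed
    }
```

For a `(4, 8, 4096)` tensor reduced along axis 2 with two fused reductions and 16 blocks per row, this gives 32 rows, 512 blocks, 1024 partials, and 32 counters.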
