Mojo function
radix_sort_pairs_kernel
radix_sort_pairs_kernel[dtype: DType, out_idx_type: DType, current_bit: Int, ascending: Bool = False, BLOCK_SIZE: Int = 256, NUM_BITS_PER_PASS: Int = 4](input_keys_: UnsafePointer[Scalar[dtype]], output_keys_: UnsafePointer[Scalar[dtype]], input_key_ids_: UnsafePointer[Scalar[out_idx_type]], output_key_ids_: UnsafePointer[Scalar[out_idx_type]], num_keys: Int, skip_sort: UnsafePointer[Scalar[DType.bool]])
Radix sort kernel for (key, index) pairs. Sorts in descending order by default; set ascending=True to sort in ascending order.
Implementation based on: AMD. Introduction to GPU Radix Sort. GPUOpen, 2017. Available at: https://gpuopen.com/download/publications/Introduction_to_GPU_Radix_Sort.pdf.
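The core of each pass in the cited GPUOpen approach is a counting sort on one digit: build a per-digit histogram, turn it into bucket offsets with an exclusive prefix sum, then scatter every key (and here, its paired index) to its bucket's next free slot. The sketch below is a sequential Python model of one such pass, not the Mojo kernel. It assumes non-negative integer keys, and it produces descending order by reversing the bucket order; both of these are assumptions for illustration rather than a description of how radix_sort_pairs_kernel handles floats, signed types, or direction.

```python
def radix_pass_pairs(keys, ids, current_bit, num_bits_per_pass=4, ascending=False):
    """Sequential model of one stable radix pass over (key, id) pairs.

    Assumes non-negative integer keys; signed or floating-point keys would
    first need an order-preserving bit transform (not shown).
    """
    radix = 1 << num_bits_per_pass  # number of buckets, e.g. 16 for 4 bits
    mask = radix - 1

    def bucket(key):
        digit = (key >> current_bit) & mask
        # Assumed descending strategy: reverse the bucket order.
        return digit if ascending else mask - digit

    # 1. Histogram: count how many keys fall into each bucket.
    counts = [0] * radix
    for k in keys:
        counts[bucket(k)] += 1

    # 2. Exclusive prefix sum: first output slot of each bucket.
    offsets = [0] * radix
    running = 0
    for b in range(radix):
        offsets[b] = running
        running += counts[b]

    # 3. Stable scatter: each key moves together with its id.
    out_keys = [0] * len(keys)
    out_ids = [0] * len(ids)
    for k, i in zip(keys, ids):
        b = bucket(k)
        out_keys[offsets[b]] = k
        out_ids[offsets[b]] = i
        offsets[b] += 1
    return out_keys, out_ids


keys, ids = radix_pass_pairs([0x21, 0x13, 0x22, 0x11], [0, 1, 2, 3], current_bit=0)
keys, ids = radix_pass_pairs(keys, ids, current_bit=4)
print([hex(k) for k in keys], ids)  # ['0x22', '0x21', '0x13', '0x11'] [2, 0, 1, 3]
```

Because every pass is stable, running passes from the least significant digit upward yields a full sort, and each index stays attached to its key throughout.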
Parameters:
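- dtype (DType): Data type of the keys.
- out_idx_type (DType): Data type of the key indices (input_key_ids_ and output_key_ids_).
- current_bit (Int): Bit position at which this pass starts; the pass sorts on the NUM_BITS_PER_PASS bits beginning at current_bit.
- ascending (Bool): Whether to sort in ascending order. Defaults to False (descending).
- BLOCK_SIZE (Int): Block size, in threads. Defaults to 256.
- NUM_BITS_PER_PASS (Int): Number of key bits sorted in each pass. Defaults to 4.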
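Since current_bit is a compile-time parameter, a full sort presumably launches the kernel once per digit, advancing current_bit by NUM_BITS_PER_PASS each time. The hypothetical helpers below (not part of the Mojo API) show that schedule for a 32-bit key, assuming least-significant-digit-first ordering as in the cited GPUOpen paper, and how one pass's digit is extracted.

```python
def pass_schedule(key_bits, num_bits_per_pass=4):
    """Hypothetical helper: current_bit for each LSD pass, least significant digit first."""
    return list(range(0, key_bits, num_bits_per_pass))


def extract_digit(key, current_bit, num_bits_per_pass=4):
    """Hypothetical helper: bits [current_bit, current_bit + num_bits_per_pass) of a non-negative key."""
    return (key >> current_bit) & ((1 << num_bits_per_pass) - 1)


print(pass_schedule(32))              # [0, 4, 8, 12, 16, 20, 24, 28]
print(hex(extract_digit(0xABCD, 4)))  # 0xc
```

With the default NUM_BITS_PER_PASS of 4, each pass uses 16 buckets and a 32-bit key needs 8 passes; a larger NUM_BITS_PER_PASS means fewer passes but larger histograms.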
Args:
- input_keys_ (UnsafePointer): Input tensor values to sort.
- output_keys_ (UnsafePointer): Output tensor values, sorted in descending order by default.
- input_key_ids_ (UnsafePointer): Input tensor indices, paired with input_keys_.
- output_key_ids_ (UnsafePointer): Output tensor indices, reordered together with output_keys_ (descending order by default).
- num_keys (Int): Number of keys to sort per batch.
- skip_sort (UnsafePointer): Flag indicating whether sorting is skipped for this batch.
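The separate input and output buffers suggest a ping-pong scheme in which each pass's outputs become the next pass's inputs, while the index buffers record where each key came from. The host-side model below illustrates that flow for several batches. It is a sketch under stated assumptions, not the kernel's behavior: batches are assumed to be contiguous runs of num_keys keys, initial indices are assumed to be positions within the batch, a skipped batch is simply left in input order, and one_pass and sort_batches are hypothetical names that use Python's stable sorted() as a stand-in for a single kernel launch.

```python
def one_pass(keys, ids, current_bit, num_bits_per_pass, ascending):
    """Stand-in for one kernel launch: stable reorder of (key, id) pairs by one digit."""
    mask = (1 << num_bits_per_pass) - 1

    def bucket(pair):
        digit = (pair[0] >> current_bit) & mask
        return digit if ascending else mask - digit

    pairs = sorted(zip(keys, ids), key=bucket)  # Python's sorted() is stable
    return [k for k, _ in pairs], [i for _, i in pairs]


def sort_batches(flat_keys, num_keys, key_bits=8, num_bits_per_pass=4,
                 ascending=False, skip=None):
    """Sort each contiguous batch of num_keys non-negative keys, carrying indices.

    key_bits=8 keeps the example small; real keys would use the dtype's width.
    """
    num_batches = len(flat_keys) // num_keys
    skip = skip or [False] * num_batches  # models a per-batch skip_sort flag
    out_keys, out_ids = [], []
    for b in range(num_batches):
        keys = flat_keys[b * num_keys:(b + 1) * num_keys]
        ids = list(range(num_keys))  # assumed initial ids: positions in the batch
        if not skip[b]:
            # Ping-pong: each pass's output becomes the next pass's input.
            for current_bit in range(0, key_bits, num_bits_per_pass):
                keys, ids = one_pass(keys, ids, current_bit, num_bits_per_pass, ascending)
        out_keys += keys  # assumption: skipped batches stay in input order
        out_ids += ids
    return out_keys, out_ids


print(sort_batches([3, 200, 17, 17, 5, 9], num_keys=3))
# ([200, 17, 3, 17, 9, 5], [1, 2, 0, 0, 2, 1])
```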
