1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.util;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 
22 import java.util.Comparator;
23 import java.util.List;
24 
25 /**
26  * An implementation of the quick selection algorithm as described in
27  * http://en.wikipedia.org/wiki/Quickselect.
28  *
29  * @hide
30  */
31 public final class QuickSelect {
selectImpl(@onNull List<T> list, int left, int right, int k, @NonNull Comparator<? super T> comparator)32     private static <T> int selectImpl(@NonNull List<T> list, int left, int right, int k,
33             @NonNull Comparator<? super T> comparator) {
34         while (true) {
35             if (left == right) {
36                 return left;
37             }
38             final int pivotIndex = partition(list, left, right, (left + right) >> 1, comparator);
39             if (k == pivotIndex) {
40                 return k;
41             } else if (k < pivotIndex) {
42                 right = pivotIndex - 1;
43             } else {
44                 left = pivotIndex + 1;
45             }
46         }
47     }
48 
selectImpl(@onNull int[] array, int left, int right, int k)49     private static int selectImpl(@NonNull int[] array, int left, int right, int k) {
50         while (true) {
51             if (left == right) {
52                 return left;
53             }
54             final int pivotIndex = partition(array, left, right, (left + right) >> 1);
55             if (k == pivotIndex) {
56                 return k;
57             } else if (k < pivotIndex) {
58                 right = pivotIndex - 1;
59             } else {
60                 left = pivotIndex + 1;
61             }
62         }
63     }
64 
selectImpl(@onNull long[] array, int left, int right, int k)65     private static int selectImpl(@NonNull long[] array, int left, int right, int k) {
66         while (true) {
67             if (left == right) {
68                 return left;
69             }
70             final int pivotIndex = partition(array, left, right, (left + right) >> 1);
71             if (k == pivotIndex) {
72                 return k;
73             } else if (k < pivotIndex) {
74                 right = pivotIndex - 1;
75             } else {
76                 left = pivotIndex + 1;
77             }
78         }
79     }
80 
selectImpl(@onNull T[] array, int left, int right, int k, @NonNull Comparator<? super T> comparator)81     private static <T> int selectImpl(@NonNull T[] array, int left, int right, int k,
82             @NonNull Comparator<? super T> comparator) {
83         while (true) {
84             if (left == right) {
85                 return left;
86             }
87             final int pivotIndex = partition(array, left, right, (left + right) >> 1, comparator);
88             if (k == pivotIndex) {
89                 return k;
90             } else if (k < pivotIndex) {
91                 right = pivotIndex - 1;
92             } else {
93                 left = pivotIndex + 1;
94             }
95         }
96     }
97 
partition(@onNull List<T> list, int left, int right, int pivotIndex, @NonNull Comparator<? super T> comparator)98     private static <T> int partition(@NonNull List<T> list, int left, int right, int pivotIndex,
99             @NonNull Comparator<? super T> comparator) {
100         final T pivotValue = list.get(pivotIndex);
101         swap(list, right, pivotIndex);
102         int storeIndex = left;
103         for (int i = left; i < right; i++) {
104             if (comparator.compare(list.get(i), pivotValue) < 0) {
105                 swap(list, storeIndex, i);
106                 storeIndex++;
107             }
108         }
109         swap(list, right, storeIndex);
110         return storeIndex;
111     }
112 
partition(@onNull int[] array, int left, int right, int pivotIndex)113     private static int partition(@NonNull int[] array, int left, int right, int pivotIndex) {
114         final int pivotValue = array[pivotIndex];
115         swap(array, right, pivotIndex);
116         int storeIndex = left;
117         for (int i = left; i < right; i++) {
118             if (array[i] < pivotValue) {
119                 swap(array, storeIndex, i);
120                 storeIndex++;
121             }
122         }
123         swap(array, right, storeIndex);
124         return storeIndex;
125     }
126 
partition(@onNull long[] array, int left, int right, int pivotIndex)127     private static int partition(@NonNull long[] array, int left, int right, int pivotIndex) {
128         final long pivotValue = array[pivotIndex];
129         swap(array, right, pivotIndex);
130         int storeIndex = left;
131         for (int i = left; i < right; i++) {
132             if (array[i] < pivotValue) {
133                 swap(array, storeIndex, i);
134                 storeIndex++;
135             }
136         }
137         swap(array, right, storeIndex);
138         return storeIndex;
139     }
140 
partition(@onNull T[] array, int left, int right, int pivotIndex, @NonNull Comparator<? super T> comparator)141     private static <T> int partition(@NonNull T[] array, int left, int right, int pivotIndex,
142             @NonNull Comparator<? super T> comparator) {
143         final T pivotValue = array[pivotIndex];
144         swap(array, right, pivotIndex);
145         int storeIndex = left;
146         for (int i = left; i < right; i++) {
147             if (comparator.compare(array[i], pivotValue) < 0) {
148                 swap(array, storeIndex, i);
149                 storeIndex++;
150             }
151         }
152         swap(array, right, storeIndex);
153         return storeIndex;
154     }
155 
swap(@onNull List<T> list, int left, int right)156     private static <T> void swap(@NonNull List<T> list, int left, int right) {
157         final T tmp = list.get(left);
158         list.set(left, list.get(right));
159         list.set(right, tmp);
160     }
161 
swap(@onNull int[] array, int left, int right)162     private static void swap(@NonNull int[] array, int left, int right) {
163         final int tmp = array[left];
164         array[left] = array[right];
165         array[right] = tmp;
166     }
167 
swap(@onNull long[] array, int left, int right)168     private static void swap(@NonNull long[] array, int left, int right) {
169         final long tmp = array[left];
170         array[left] = array[right];
171         array[right] = tmp;
172     }
173 
swap(@onNull T[] array, int left, int right)174     private static <T> void swap(@NonNull T[] array, int left, int right) {
175         final T tmp = array[left];
176         array[left] = array[right];
177         array[right] = tmp;
178     }
179 
180     /**
181      * Return the kth(0-based) smallest element from the given unsorted list.
182      *
183      * @param list The input list, it <b>will</b> be modified by the algorithm here.
184      * @param start The start offset of the list, inclusive.
185      * @param length The length of the sub list to be searched in.
186      * @param k The 0-based index.
187      * @param comparator The comparator which knows how to compare the elements in the list.
188      * @return The kth smallest element from the given list,
189      *         or IllegalArgumentException will be thrown if not found.
190      */
191     @Nullable
select(@onNull List<T> list, int start, int length, int k, @NonNull Comparator<? super T> comparator)192     public static <T> T select(@NonNull List<T> list, int start, int length, int k,
193             @NonNull Comparator<? super T> comparator) {
194         if (list == null || start < 0 || length <= 0 || list.size() < start + length
195                 || k < 0 || length <= k) {
196             throw new IllegalArgumentException();
197         }
198         return list.get(selectImpl(list, start, start + length - 1, k + start, comparator));
199     }
200 
201     /**
202      * Return the kth(0-based) smallest element from the given unsorted array.
203      *
204      * @param array The input array, it <b>will</b> be modified by the algorithm here.
205      * @param start The start offset of the array, inclusive.
206      * @param length The length of the sub array to be searched in.
207      * @param k The 0-based index to search for.
208      * @return The kth smallest element from the given array,
209      *         or IllegalArgumentException will be thrown if not found.
210      */
select(@onNull int[] array, int start, int length, int k)211     public static int select(@NonNull int[] array, int start, int length, int k) {
212         if (array == null || start < 0 || length <= 0 || array.length < start + length
213                 || k < 0 || length <= k) {
214             throw new IllegalArgumentException();
215         }
216         return array[selectImpl(array, start, start + length - 1, k + start)];
217     }
218 
219     /**
220      * Return the kth(0-based) smallest element from the given unsorted array.
221      *
222      * @param array The input array, it <b>will</b> be modified by the algorithm here.
223      * @param start The start offset of the array, inclusive.
224      * @param length The length of the sub array to be searched in.
225      * @param k The 0-based index to search for.
226      * @return The kth smallest element from the given array,
227      *         or IllegalArgumentException will be thrown if not found.
228      */
select(@onNull long[] array, int start, int length, int k)229     public static long select(@NonNull long[] array, int start, int length, int k) {
230         if (array == null || start < 0 || length <= 0 || array.length < start + length
231                 || k < 0 || length <= k) {
232             throw new IllegalArgumentException();
233         }
234         return array[selectImpl(array, start, start + length - 1, k + start)];
235     }
236 
237     /**
238      * Return the kth(0-based) smallest element from the given unsorted array.
239      *
240      * @param array The input array, it <b>will</b> be modified by the algorithm here.
241      * @param start The start offset of the array, inclusive.
242      * @param length The length of the sub array to be searched in.
243      * @param k The 0-based index to search for.
244      * @param comparator The comparator which knows how to compare the elements in the list.
245      * @return The kth smallest element from the given array,
246      *         or IllegalArgumentException will be thrown if not found.
247      */
select(@onNull T[] array, int start, int length, int k, @NonNull Comparator<? super T> comparator)248     public static <T> T select(@NonNull T[] array, int start, int length, int k,
249             @NonNull Comparator<? super T> comparator) {
250         if (array == null || start < 0 || length <= 0 || array.length < start + length
251                 || k < 0 || length <= k) {
252             throw new IllegalArgumentException();
253         }
254         return array[selectImpl(array, start, start + length - 1, k + start, comparator)];
255     }
256 }
257