1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
17 #define TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
18
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
22
23 #include "tensorflow/lite/c/common.h"
24
25 namespace tflite {
26 namespace label_classify {
27 template <class T>
GetTopN(T * prediction,int32_t predictionSize,size_t numResults,float threshold,std::vector<std::pair<float,int32_t>> * topResults,TfLiteType inputType)28 void GetTopN(T* prediction, int32_t predictionSize, size_t numResults, float threshold,
29 std::vector<std::pair<float, int32_t>>* topResults, TfLiteType inputType)
30 {
31 // Will contain top N results in ascending order.
32 std::priority_queue<std::pair<float, int32_t>, std::vector<std::pair<float, int32_t>>,
33 std::greater<std::pair<float, int32_t>>>
34 topResultPQ;
35
36 const long count = predictionSize; // NOLINT(runtime/int32_t)
37 float value = 0.0;
38 float intNormalizedFactor = 256.0;
39 float uintNormalizedFactor = 255.0;
40 uint32_t offsetNumber = 128;
41
42 for (int32_t i = 0; i < count; ++i) {
43 switch (inputType) {
44 case kTfLiteFloat32:
45 value = prediction[i];
46 break;
47 case kTfLiteInt8:
48 value = (prediction[i] + offsetNumber) / intNormalizedFactor;
49 break;
50 case kTfLiteUInt8:
51 value = prediction[i] / uintNormalizedFactor;
52 break;
53 default:
54 break;
55 }
56
57 // Only add it if it beats the threshold and has a chance at being in the top N.
58 if (value < threshold) {
59 continue;
60 }
61
62 topResultPQ.push(std::pair<float, int32_t>(value, i));
63
64 // If at capacity, kick the smallest value out.
65 if (topResultPQ.size() > numResults) {
66 topResultPQ.pop();
67 }
68 }
69
70 // Copy to output vector and reverse into descending order.
71 while (!topResultPQ.empty()) {
72 topResults->push_back(topResultPQ.top());
73 topResultPQ.pop();
74 }
75
76 std::reverse(topResults->begin(), topResults->end());
77 }
78
79 // explicit instantiation so that we can use them otherwhere
80 template void GetTopN<float>(float*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
81 template void GetTopN<int8_t>(int8_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
82 template void GetTopN<uint8_t>(uint8_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
83 template void GetTopN<int64_t>(int64_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
84 } // namespace label_classify
85 } // namespace tflite
86
87 #endif // TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
88