/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
#define TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
18 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

#include "tensorflow/lite/c/common.h"
24 
25 namespace tflite {
26 namespace label_classify {
27 template <class T>
GetTopN(T * prediction,int32_t predictionSize,size_t numResults,float threshold,std::vector<std::pair<float,int32_t>> * topResults,TfLiteType inputType)28 void GetTopN(T* prediction, int32_t predictionSize, size_t numResults, float threshold,
29     std::vector<std::pair<float, int32_t>>* topResults, TfLiteType inputType)
30 {
31     // Will contain top N results in ascending order.
32     std::priority_queue<std::pair<float, int32_t>, std::vector<std::pair<float, int32_t>>,
33         std::greater<std::pair<float, int32_t>>>
34         topResultPQ;
35 
36     const long count = predictionSize; // NOLINT(runtime/int32_t)
37     float value = 0.0;
38     float intNormalizedFactor = 256.0;
39     float uintNormalizedFactor = 255.0;
40     uint32_t offsetNumber = 128;
41 
42     for (int32_t i = 0; i < count; ++i) {
43         switch (inputType) {
44             case kTfLiteFloat32:
45                 value = prediction[i];
46                 break;
47             case kTfLiteInt8:
48                 value = (prediction[i] + offsetNumber) / intNormalizedFactor;
49                 break;
50             case kTfLiteUInt8:
51                 value = prediction[i] / uintNormalizedFactor;
52                 break;
53             default:
54                 break;
55         }
56 
57         // Only add it if it beats the threshold and has a chance at being in the top N.
58         if (value < threshold) {
59             continue;
60         }
61 
62         topResultPQ.push(std::pair<float, int32_t>(value, i));
63 
64         // If at capacity, kick the smallest value out.
65         if (topResultPQ.size() > numResults) {
66             topResultPQ.pop();
67         }
68     }
69 
70     // Copy to output vector and reverse into descending order.
71     while (!topResultPQ.empty()) {
72         topResults->push_back(topResultPQ.top());
73         topResultPQ.pop();
74     }
75 
76     std::reverse(topResults->begin(), topResults->end());
77 }
78 
79 // explicit instantiation so that we can use them otherwhere
80 template void GetTopN<float>(float*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
81 template void GetTopN<int8_t>(int8_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
82 template void GetTopN<uint8_t>(uint8_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
83 template void GetTopN<int64_t>(int64_t*, int32_t, size_t, float, std::vector<std::pair<float, int32_t>>*, TfLiteType);
84 }  // namespace label_classify
85 }  // namespace tflite

#endif  // TENSORFLOW_LITE_EXAMPLES_LABEL_CLASSIFY_GET_TOP_N_H
88