1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "utils.h"
17
#include <sys/time.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <type_traits>
21
22 #include "tflite/tools/bitmap_helpers.h"
23 #include "tflite/tools/get_topn.h"
24 #include "tflite/tools/log.h"
25
26 namespace tflite {
27 namespace label_classify {
// Maximum number of elements printed per output tensor (passed as printSize to PrintData by PrintResult).
constexpr int32_t DATA_PRINT_NUM = 1000;
// NOTE(review): not referenced anywhere in this file.
constexpr int32_t DATA_EACHLINE_NUM = 1000;
// Number of microseconds in one second.
constexpr int32_t SECOND_TO_MICROSECOND_RATIO = 1000000;
// Axis index of the image width in the model's input dims (despite the name, it is the width axis).
constexpr uint8_t WEIGHT_DIMENSION = 2;
// Axis index of the image channels in the model's input dims.
constexpr uint8_t CHANNEL_DIMENSION = 3;

/**
 * Converts a timeval into a number of microseconds.
 *
 * The seconds field is converted to double BEFORE scaling: with a 32-bit time_t, multiplying
 * tv_sec by 1000000 in integer arithmetic overflows (undefined behaviour).
 *
 * @param t time value to convert.
 * @return t.tv_sec * 1e6 + t.tv_usec, as a double.
 */
double GetUs(struct timeval t)
{
    return static_cast<double>(t.tv_sec) * SECOND_TO_MICROSECOND_RATIO + static_cast<double>(t.tv_usec);
}
38
ReadLabelsFile(const string & fileName,std::vector<string> & result,size_t & foundLabelCount)39 TfLiteStatus ReadLabelsFile(const string& fileName, std::vector<string>& result, size_t& foundLabelCount)
40 {
41 std::ifstream file(fileName);
42 if (!file) {
43 LOG(ERROR) << "Labels file " << fileName << " not found";
44 return kTfLiteError;
45 }
46 result.clear();
47 string line;
48 while (std::getline(file, line)) {
49 result.push_back(line);
50 }
51 foundLabelCount = result.size();
52 const int32_t padding = 16;
53 while (result.size() % padding) {
54 result.emplace_back();
55 }
56
57 return kTfLiteOk;
58 }
59
GetInputNameAndShape(string & inputShapeString,std::map<string,std::vector<int>> & userInputShapes)60 void GetInputNameAndShape(string& inputShapeString, std::map<string, std::vector<int>>& userInputShapes)
61 {
62 if (inputShapeString == "") {
63 return;
64 }
65 size_t pos = inputShapeString.find_last_of(":");
66 string userInputName = inputShapeString.substr(0, pos);
67
68 string dimString = inputShapeString.substr(pos + 1);
69 size_t dimPos = dimString.find(",");
70 std::vector<int> inputDims;
71 while (dimPos != dimString.npos) {
72 inputDims.push_back(std::stoi(dimString.substr(0, dimPos)));
73 dimString = dimString.substr(dimPos + 1);
74 dimPos = dimString.find(",");
75 }
76 inputDims.push_back(std::stoi(dimString));
77 userInputShapes.insert(std::map<string, std::vector<int>>::value_type(userInputName, inputDims));
78 }
79
FilterDynamicInputs(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter,std::map<int,std::vector<int>> & neededInputShapes)80 TfLiteStatus FilterDynamicInputs(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter,
81 std::map<int, std::vector<int>>& neededInputShapes)
82 {
83 std::vector<int> inputIndexes = interpreter->inputs();
84 std::map<string, int> nameIndexs;
85 for (int i = 0; i < inputIndexes.size(); i++) {
86 LOG(INFO) << "input index: " << inputIndexes[i];
87 nameIndexs.insert(std::map<string, int>::value_type(interpreter->GetInputName(i), inputIndexes[i]));
88 }
89
90 if (settings.inputShape.find(":") == settings.inputShape.npos) {
91 LOG(ERROR) << "The format of input shapes string is not supported.";
92 return kTfLiteError;
93 }
94
95 // Get input names and shapes
96 std::map<string, std::vector<int>> userInputShapes;
97 string inputShapeString = settings.inputShape;
98 size_t pos = inputShapeString.find(";");
99 while (pos != inputShapeString.npos) {
100 GetInputNameAndShape(inputShapeString, userInputShapes);
101 inputShapeString = inputShapeString.substr(pos + 1);
102 pos = inputShapeString.find(";");
103 }
104 GetInputNameAndShape(inputShapeString, userInputShapes);
105
106 for (const auto& inputShape : userInputShapes) {
107 string inputName = inputShape.first;
108 auto findName = nameIndexs.find(inputName);
109 if (findName == nameIndexs.end()) {
110 LOG(ERROR) << "The input name is error: " << inputShape.first << ".";
111 return kTfLiteError;
112 } else {
113 neededInputShapes.insert(std::map<int, std::vector<int>>::value_type(findName->second, inputShape.second));
114 }
115 }
116
117 return kTfLiteOk;
118 }
119
/**
 * Prints the first min(printSize, dataSize) elements of data on one tab-separated line.
 *
 * Integral elements are widened to (unsigned) long long so that int8_t/uint8_t print as numbers
 * rather than characters, and large int32/int64 values print exactly. Casting them to float
 * would render e.g. 1234567 as "1.23457e+06". Floating-point elements are printed as they are.
 *
 * @param data      pointer to the elements; nothing is printed if it is null.
 * @param dataSize  number of valid elements at data.
 * @param printSize maximum number of elements to print.
 */
template <class T> void PrintData(T* data, int32_t dataSize, int32_t printSize)
{
    using IntegralPrintType = typename std::conditional<std::is_signed<T>::value, long long, unsigned long long>::type;
    using PrintType = typename std::conditional<std::is_integral<T>::value, IntegralPrintType, T>::type;

    if (data == nullptr) {
        return;
    }
    const int32_t count = std::min(printSize, dataSize);
    for (int32_t i = 0; i < count; ++i) {
        std::cout << static_cast<PrintType>(data[i]) << "\t";
    }
    std::cout << std::endl;
}
130
PrintResult(std::unique_ptr<tflite::Interpreter> & interpreter)131 void PrintResult(std::unique_ptr<tflite::Interpreter>& interpreter)
132 {
133 for (int32_t index = 0; index < interpreter->outputs().size(); ++index) {
134 int32_t output_index = interpreter->outputs()[index];
135 TfLiteIntArray* outputsDims = interpreter->tensor(output_index)->dims;
136 int32_t dimSize = outputsDims->size;
137 int32_t outputTensorSize = 1;
138 for (int32_t i = 0; i < dimSize; ++i) {
139 outputTensorSize *= outputsDims->data[i];
140 }
141
142 TfLiteTensor* outputTensor = interpreter->tensor(output_index);
143 switch (outputTensor->type) {
144 case kTfLiteFloat32:
145 PrintData<float>(interpreter->typed_output_tensor<float>(index), outputTensorSize, DATA_PRINT_NUM);
146 break;
147 case kTfLiteInt32:
148 PrintData<int32_t>(interpreter->typed_output_tensor<int32_t>(index), outputTensorSize, DATA_PRINT_NUM);
149 break;
150 case kTfLiteUInt8:
151 PrintData<uint8_t>(interpreter->typed_output_tensor<uint8_t>(index), outputTensorSize, DATA_PRINT_NUM);
152 break;
153 case kTfLiteInt8:
154 PrintData<int8_t>(interpreter->typed_output_tensor<int8_t>(index), outputTensorSize, DATA_PRINT_NUM);
155 break;
156 default:
157 LOG(ERROR) << "Unsupportted tensor datatype: " << outputTensor->type << "!";
158 return;
159 }
160 }
161 }
162
AnalysisResults(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter)163 void AnalysisResults(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter)
164 {
165 const float threshold = 0.001f;
166 std::vector<std::pair<float, int32_t>> topResults;
167
168 if (settings.printResult) {
169 LOG(INFO) << "Outputs Data:";
170 PrintResult(interpreter);
171 }
172
173 int32_t output = interpreter->outputs()[0];
174 TfLiteIntArray* outputDims = interpreter->tensor(output)->dims;
175 // assume output dims to be something like (1, 1, ... ,size)
176 auto outputSize = outputDims->data[outputDims->size - 1];
177
178 auto tfType = interpreter->tensor(output)->type;
179 switch (tfType) {
180 case kTfLiteFloat32:
181 GetTopN<float>(interpreter->typed_output_tensor<float>(0), outputSize, settings.numberOfResults, threshold,
182 &topResults, settings.inputType);
183 break;
184 case kTfLiteInt8:
185 GetTopN<int8_t>(interpreter->typed_output_tensor<int8_t>(0), outputSize, settings.numberOfResults,
186 threshold, &topResults, settings.inputType);
187 break;
188 case kTfLiteUInt8:
189 GetTopN<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), outputSize, settings.numberOfResults,
190 threshold, &topResults, settings.inputType);
191 break;
192 case kTfLiteInt64:
193 GetTopN<int64_t>(interpreter->typed_output_tensor<int64_t>(0), outputSize, settings.numberOfResults,
194 threshold, &topResults, settings.inputType);
195 break;
196 default:
197 LOG(ERROR) << "cannot handle output type " << tfType << " yet";
198 return;
199 }
200
201 std::vector<string> labels;
202 size_t labelCount;
203
204 if (ReadLabelsFile(settings.labelsFileName, labels, labelCount) != kTfLiteOk) {
205 return;
206 }
207 for (const auto& result : topResults) {
208 const float confidence = result.first;
209 const int32_t index = result.second;
210 LOG(INFO) << confidence << ": " << index << " " << labels[index];
211 }
212 }
213
ImportData(Settings & settings,std::vector<int> & imageSize,std::unique_ptr<tflite::Interpreter> & interpreter)214 void ImportData(Settings& settings, std::vector<int>& imageSize, std::unique_ptr<tflite::Interpreter>& interpreter)
215 {
216 ImageInfo inputImageInfo = {imageSize[0], imageSize[1], imageSize[2]};
217 std::vector<uint8_t> in;
218 ReadBmp(settings.inputBmpName, inputImageInfo, &settings, in);
219
220 int32_t input = interpreter->inputs()[0];
221 if (settings.verbose) {
222 LOG(INFO) << "input: " << input;
223 }
224
225 // get input dimension from the model.
226 TfLiteIntArray* dims = interpreter->tensor(input)->dims;
227 ImageInfo wantedimageInfo;
228 wantedimageInfo.height = dims->data[1];
229 wantedimageInfo.width = dims->data[WEIGHT_DIMENSION];
230 wantedimageInfo.channels = (dims->size > CHANNEL_DIMENSION) ? dims->data[CHANNEL_DIMENSION] : 1;
231
232 settings.inputType = interpreter->tensor(input)->type;
233 switch (settings.inputType) {
234 case kTfLiteFloat32:
235 Resize<float>(interpreter->typed_tensor<float>(input), in.data(), inputImageInfo, wantedimageInfo,
236 &settings);
237 break;
238 case kTfLiteInt8:
239 Resize<int8_t>(interpreter->typed_tensor<int8_t>(input), in.data(), inputImageInfo, wantedimageInfo,
240 &settings);
241 break;
242 case kTfLiteUInt8:
243 Resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(), inputImageInfo, wantedimageInfo,
244 &settings);
245 break;
246 case kTfLiteInt64:
247 Resize<int64_t>(interpreter->typed_tensor<int64_t>(input), in.data(), inputImageInfo, wantedimageInfo,
248 &settings);
249 break;
250 default:
251 LOG(ERROR) << "cannot handle input type " << settings.inputType << " yet";
252 return;
253 }
254 }
255
IsEqualShape(int tensorIndex,const std::vector<int> & dims,std::unique_ptr<tflite::Interpreter> & interpreter)256 bool IsEqualShape(int tensorIndex, const std::vector<int>& dims, std::unique_ptr<tflite::Interpreter>& interpreter)
257 {
258 TfLiteTensor* tensor = interpreter->tensor(tensorIndex);
259 for (int i = 0; i < tensor->dims->size; ++i) {
260 if (tensor->dims->data[i] != dims[i]) {
261 return false;
262 }
263 }
264 return true;
265 }
266 } // namespace label_classify
267 } // namespace tflite