/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16 #include "label_classify.h"
17
18 #include <fcntl.h>
19 #include <getopt.h>
20 #include <sys/time.h>
21 #include <sys/types.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24
25 #include <cstdarg>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <fstream>
29 #include <iomanip>
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <unordered_set>
34 #include <vector>
35
36 #include "tensorflow/lite/kernels/register.h"
37 #include "tensorflow/lite/optional_debug_tools.h"
38 #include "tensorflow/lite/string_util.h"
39 #include "tensorflow/lite/tools/command_line_flags.h"
40 #include "tensorflow/lite/tools/delegates/delegate_provider.h"
41
42 #include "log.h"
43 #include "utils.h"
44
45 namespace tflite {
46 namespace label_classify {
47 using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
48 using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
49 constexpr int BASE_NUMBER = 10;
50 constexpr int CONVERSION_RATE = 1000;
51 static struct option LONG_OPTIONS[] = {
52 {"help", no_argument, nullptr, 'h'},
53 {"use_nnrt", required_argument, nullptr, 'a'},
54 {"count", required_argument, nullptr, 'c'},
55 {"image", required_argument, nullptr, 'i'},
56 {"labels", required_argument, nullptr, 'l'},
57 {"tflite_model", required_argument, nullptr, 'm'},
58 {"num_results", required_argument, nullptr, 'n'},
59 {"input_mean", required_argument, nullptr, 'b'},
60 {"input_std", required_argument, nullptr, 's'},
61 {"verbose", required_argument, nullptr, 'v'},
62 {"warmup_nums", required_argument, nullptr, 'w'},
63 {"print_result", required_argument, nullptr, 'z'},
64 {"input_shape", required_argument, nullptr, 'p'},
65 {nullptr, 0, nullptr, 0},
66 };
67
68 class DelegateProviders {
69 public:
DelegateProviders()70 DelegateProviders() : m_delegateListUtil(¶ms)
71 {
72 m_delegateListUtil.AddAllDelegateParams(); // Add all registered delegate params to the contained 'params_'.
73 }
74
~DelegateProviders()75 ~DelegateProviders() {}
76
InitFromCmdlineArgs(int32_t * argc,const char ** argv)77 bool InitFromCmdlineArgs(int32_t* argc, const char** argv)
78 {
79 std::vector<tflite::Flag> flags;
80 m_delegateListUtil.AppendCmdlineFlags(&flags);
81
82 const bool parseResult = Flags::Parse(argc, argv, flags);
83 if (!parseResult) {
84 std::string usage = Flags::Usage(argv[0], flags);
85 LOG(ERROR) << usage;
86 }
87 return parseResult;
88 }
89
MergeSettingsIntoParams(const Settings & settings)90 void MergeSettingsIntoParams(const Settings& settings)
91 {
92 if (settings.accel) {
93 if (!params.HasParam("use_nnrt")) {
94 LOG(WARN) << "NNRT deleate execution provider isn't linked or NNRT "
95 << "delegate isn't supported on the platform!";
96 } else {
97 params.Set<bool>("use_nnrt", true);
98 }
99 }
100 }
101
CreateAllDelegates() const102 std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates() const
103 {
104 return m_delegateListUtil.CreateAllRankedDelegates();
105 }
106
107 private:
108 // Contain delegate-related parameters that are initialized from command-line flags.
109 tflite::tools::ToolParams params;
110
111 // A helper to create TfLite delegates.
112 ProvidedDelegateList m_delegateListUtil;
113 };
114
PrepareModel(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter,DelegateProviders & delegateProviders)115 void PrepareModel(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter,
116 DelegateProviders& delegateProviders)
117 {
118 const std::vector<int32_t> inputs = interpreter->inputs();
119 const std::vector<int32_t> outputs = interpreter->outputs();
120
121 if (settings.verbose) {
122 LOG(INFO) << "number of inputs: " << inputs.size();
123 LOG(INFO) << "number of outputs: " << outputs.size();
124 }
125
126 std::map<int, std::vector<int>> neededInputShapes;
127 if (settings.inputShape != "") {
128 if (FilterDynamicInputs(settings, interpreter, neededInputShapes) != kTfLiteOk) {
129 return;
130 }
131 }
132
133 delegateProviders.MergeSettingsIntoParams(settings);
134 auto delegates = delegateProviders.CreateAllDelegates();
135
136 for (auto& delegate : delegates) {
137 const auto delegateName = delegate.provider->GetName();
138 if (interpreter->ModifyGraphWithDelegate(std::move(delegate.delegate)) != kTfLiteOk) {
139 LOG(ERROR) << "Failed to apply " << delegateName << " delegate.";
140 return;
141 } else {
142 LOG(INFO) << "Applied " << delegateName << " delegate.";
143 }
144 }
145
146 if (settings.inputShape != "") {
147 for (const auto& inputShape : neededInputShapes) {
148 if (IsEqualShape(inputShape.first, inputShape.second, interpreter)) {
149 LOG(WARNING) << "The input shape is same as the model shape, not resize.";
150 continue;
151 }
152 if (interpreter->ResizeInputTensor(inputShape.first, inputShape.second) != kTfLiteOk) {
153 LOG(ERROR) << "Fail to resize index " << inputShape.first << ".";
154 return;
155 } else {
156 LOG(INFO) << "Susccess to resize index " << inputShape.first << ".";
157 }
158 }
159 }
160
161 if (interpreter->AllocateTensors() != kTfLiteOk) {
162 LOG(ERROR) << "Failed to allocate tensors!";
163 return;
164 }
165
166 if (settings.verbose) {
167 PrintInterpreterState(interpreter.get());
168 }
169 }
170
LogInterpreterParams(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter)171 void LogInterpreterParams(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter)
172 {
173 if (!interpreter) {
174 LOG(ERROR) << "Failed to construct interpreter";
175 return;
176 }
177
178 if (settings.verbose) {
179 LOG(INFO) << "tensors size: " << interpreter->tensors_size();
180 LOG(INFO) << "nodes size: " << interpreter->nodes_size();
181 LOG(INFO) << "inputs: " << interpreter->inputs().size();
182 LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);
183
184 size_t tSize = interpreter->tensors_size();
185 for (size_t i = 0; i < tSize; ++i) {
186 if (interpreter->tensor(i)->name) {
187 LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " << interpreter->tensor(i)->bytes <<
188 ", " << interpreter->tensor(i)->type << ", " << interpreter->tensor(i)->params.scale << ", " <<
189 interpreter->tensor(i)->params.zero_point;
190 }
191 }
192 }
193 }
194
InferenceModel(Settings & settings,DelegateProviders & delegateProviders)195 void InferenceModel(Settings& settings, DelegateProviders& delegateProviders)
196 {
197 if (!settings.modelName.c_str()) {
198 LOG(ERROR) << "no model file name";
199 return;
200 }
201 std::unique_ptr<tflite::FlatBufferModel> model;
202 std::unique_ptr<tflite::Interpreter> interpreter;
203 model = tflite::FlatBufferModel::BuildFromFile(settings.modelName.c_str());
204 if (!model) {
205 LOG(ERROR) << "Failed to mmap model " << settings.modelName;
206 return;
207 }
208
209 settings.model = model.get();
210 model->error_reporter();
211 tflite::ops::builtin::BuiltinOpResolver resolver;
212 tflite::InterpreterBuilder(*model, resolver)(&interpreter);
213 if (!interpreter) {
214 LOG(ERROR) << "Failed to construct interpreter, please check the model.";
215 return;
216 }
217
218 LogInterpreterParams(settings, interpreter);
219
220 // set settings input type
221 PrepareModel(settings, interpreter, delegateProviders);
222 std::vector<int> imageSize { 224, 224, 3};
223 ImportData(settings, imageSize, interpreter);
224
225 if (settings.loopCount > 0 && settings.numberOfWarmupRuns > 0) {
226 LOG(INFO) << "Warm-up for " << settings.numberOfWarmupRuns << " times";
227 for (int32_t i = 0; i < settings.numberOfWarmupRuns; ++i) {
228 if (interpreter->Invoke() != kTfLiteOk) {
229 LOG(ERROR) << "Failed to invoke tflite!";
230 return;
231 }
232 }
233 }
234
235 struct timeval startTime, stopTime;
236 LOG(INFO) << "Invoke for " << settings.loopCount << " times";
237 gettimeofday(&startTime, nullptr);
238 for (int32_t i = 0; i < settings.loopCount; ++i) {
239 if (interpreter->Invoke() != kTfLiteOk) {
240 LOG(ERROR) << "Failed to invoke tflite!";
241 return;
242 }
243 }
244
245 gettimeofday(&stopTime, nullptr);
246 LOG(INFO) << "invoked, average time: " <<
247 (GetUs(stopTime) - GetUs(startTime)) / (settings.loopCount * CONVERSION_RATE) << " ms";
248 AnalysisResults(settings, interpreter);
249 }
250
DisplayUsage()251 void DisplayUsage()
252 {
253 LOG(INFO) << "label_classify -m xxx.tflite -i xxx.bmp -l xxx.txt -c 1 -a 1\n"
254 << "\t--help, -h: show the usage of the demo\n"
255 << "\t--use_nnrt, -a: [0|1], 1 refers to use NNRT\n"
256 << "\t--input_mean, -b: input mean\n"
257 << "\t--count, -c: loop interpreter->Invoke() for certain times\n"
258 << "\t--image, -i: image_name.bmp\n"
259 << "\t--labels, -l: labels for the model\n"
260 << "\t--tflite_model, -m: modelName.tflite\n"
261 << "\t--num_results, -n: number of results to show\n"
262 << "\t--input_std, -s: input standard deviation\n"
263 << "\t--verbose, -v: [0|1] print more information\n"
264 << "\t--warmup_nums, -w: number of warmup runs\n"
265 << "\t--print_result, -z: flag to print results\n"
266 << "\t--input_shape, -p: Indicates the specified dynamic input node and the corresponding shape.\n";
267 }
268
InitSettings(int32_t argc,char ** argv,Settings & settings)269 int32_t InitSettings(int32_t argc, char** argv, Settings& settings)
270 {
271 // getopt_long stores the option index here.
272 int32_t optionIndex = 0;
273 while ((optionIndex = getopt_long(argc, argv, "a:b:c:h:i:l:m:n:p:s:v:w:z:", LONG_OPTIONS, nullptr)) != -1) {
274 switch (optionIndex) {
275 case 'a':
276 settings.accel = strtol(optarg, nullptr, BASE_NUMBER);
277 break;
278 case 'b':
279 settings.inputMean = strtod(optarg, nullptr);
280 break;
281 case 'c':
282 settings.loopCount = strtol(optarg, nullptr, BASE_NUMBER);
283 break;
284 case 'i':
285 settings.inputBmpName = optarg;
286 break;
287 case 'l':
288 settings.labelsFileName = optarg;
289 break;
290 case 'm':
291 settings.modelName = optarg;
292 break;
293 case 'n':
294 settings.numberOfResults = strtol(optarg, nullptr, BASE_NUMBER);
295 break;
296 case 'p':
297 settings.inputShape = optarg;
298 break;
299 case 's':
300 settings.inputStd = strtod(optarg, nullptr);
301 break;
302 case 'v':
303 settings.verbose = strtol(optarg, nullptr, BASE_NUMBER);
304 break;
305 case 'w':
306 settings.numberOfWarmupRuns = strtol(optarg, nullptr, BASE_NUMBER);
307 break;
308 case 'z':
309 settings.printResult = strtol(optarg, nullptr, BASE_NUMBER);
310 break;
311 case 'h':
312 case '?':
313 // getopt_long already printed an error message.
314 DisplayUsage();
315 return -1;
316 default:
317 return -1;
318 }
319 }
320
321 return 0;
322 }
323
Main(int32_t argc,char ** argv)324 int32_t Main(int32_t argc, char** argv)
325 {
326 if (argc <= 1) {
327 DisplayUsage();
328 return EXIT_FAILURE;
329 }
330
331 DelegateProviders delegateProviders;
332 bool parseResult = delegateProviders.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
333 if (!parseResult) {
334 return EXIT_FAILURE;
335 }
336
337 Settings settings;
338 if (InitSettings(argc, argv, settings) == -1) {
339 return EXIT_FAILURE;
340 };
341
342 InferenceModel(settings, delegateProviders);
343 return 0;
344 }
345 } // namespace label_classify
346 } // namespace tflite
347
main(int32_t argc,char ** argv)348 int32_t main(int32_t argc, char** argv)
349 {
350 return tflite::label_classify::Main(argc, argv);
351 }
352