1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "label_classify.h"
17 
18 #include <fcntl.h>
19 #include <getopt.h>
20 #include <sys/time.h>
21 #include <sys/types.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24 
25 #include <cstdarg>
26 #include <cstdio>
27 #include <cstdlib>
28 #include <fstream>
29 #include <iomanip>
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <unordered_set>
34 #include <vector>
35 
36 #include "tensorflow/lite/kernels/register.h"
37 #include "tensorflow/lite/optional_debug_tools.h"
38 #include "tensorflow/lite/string_util.h"
39 #include "tensorflow/lite/tools/command_line_flags.h"
40 #include "tensorflow/lite/tools/delegates/delegate_provider.h"
41 
42 #include "log.h"
43 #include "utils.h"
44 
45 namespace tflite {
46 namespace label_classify {
47 using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
48 using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
49 constexpr int BASE_NUMBER = 10;
50 constexpr int CONVERSION_RATE = 1000;
51 static struct option LONG_OPTIONS[] = {
52     {"help", no_argument, nullptr, 'h'},
53     {"use_nnrt", required_argument, nullptr, 'a'},
54     {"count", required_argument, nullptr, 'c'},
55     {"image", required_argument, nullptr, 'i'},
56     {"labels", required_argument, nullptr, 'l'},
57     {"tflite_model", required_argument, nullptr, 'm'},
58     {"num_results", required_argument, nullptr, 'n'},
59     {"input_mean", required_argument, nullptr, 'b'},
60     {"input_std", required_argument, nullptr, 's'},
61     {"verbose", required_argument, nullptr, 'v'},
62     {"warmup_nums", required_argument, nullptr, 'w'},
63     {"print_result", required_argument, nullptr, 'z'},
64     {"input_shape", required_argument, nullptr, 'p'},
65     {nullptr, 0, nullptr, 0},
66 };
67 
68 class DelegateProviders {
69 public:
DelegateProviders()70     DelegateProviders() : m_delegateListUtil(&params)
71     {
72         m_delegateListUtil.AddAllDelegateParams();  // Add all registered delegate params to the contained 'params_'.
73     }
74 
~DelegateProviders()75     ~DelegateProviders() {}
76 
InitFromCmdlineArgs(int32_t * argc,const char ** argv)77     bool InitFromCmdlineArgs(int32_t* argc, const char** argv)
78     {
79         std::vector<tflite::Flag> flags;
80         m_delegateListUtil.AppendCmdlineFlags(&flags);
81 
82         const bool parseResult = Flags::Parse(argc, argv, flags);
83         if (!parseResult) {
84             std::string usage = Flags::Usage(argv[0], flags);
85             LOG(ERROR) << usage;
86         }
87         return parseResult;
88     }
89 
MergeSettingsIntoParams(const Settings & settings)90     void MergeSettingsIntoParams(const Settings& settings)
91     {
92         if (settings.accel) {
93             if (!params.HasParam("use_nnrt")) {
94                 LOG(WARN) << "NNRT deleate execution provider isn't linked or NNRT "
95                           << "delegate isn't supported on the platform!";
96             } else {
97                 params.Set<bool>("use_nnrt", true);
98             }
99         }
100     }
101 
CreateAllDelegates() const102     std::vector<ProvidedDelegateList::ProvidedDelegate> CreateAllDelegates() const
103     {
104         return m_delegateListUtil.CreateAllRankedDelegates();
105     }
106 
107 private:
108     // Contain delegate-related parameters that are initialized from command-line flags.
109     tflite::tools::ToolParams params;
110 
111     // A helper to create TfLite delegates.
112     ProvidedDelegateList m_delegateListUtil;
113 };
114 
PrepareModel(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter,DelegateProviders & delegateProviders)115 void PrepareModel(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter,
116     DelegateProviders& delegateProviders)
117 {
118     const std::vector<int32_t> inputs = interpreter->inputs();
119     const std::vector<int32_t> outputs = interpreter->outputs();
120 
121     if (settings.verbose) {
122         LOG(INFO) << "number of inputs: " << inputs.size();
123         LOG(INFO) << "number of outputs: " << outputs.size();
124     }
125 
126     std::map<int, std::vector<int>> neededInputShapes;
127     if (settings.inputShape != "") {
128         if (FilterDynamicInputs(settings, interpreter, neededInputShapes) != kTfLiteOk) {
129             return;
130         }
131     }
132 
133     delegateProviders.MergeSettingsIntoParams(settings);
134     auto delegates = delegateProviders.CreateAllDelegates();
135 
136     for (auto& delegate : delegates) {
137         const auto delegateName = delegate.provider->GetName();
138         if (interpreter->ModifyGraphWithDelegate(std::move(delegate.delegate)) != kTfLiteOk) {
139             LOG(ERROR) << "Failed to apply " << delegateName << " delegate.";
140             return;
141         } else {
142             LOG(INFO) << "Applied " << delegateName << " delegate.";
143         }
144     }
145 
146     if (settings.inputShape != "") {
147         for (const auto& inputShape : neededInputShapes) {
148             if (IsEqualShape(inputShape.first, inputShape.second, interpreter)) {
149                 LOG(WARNING) << "The input shape is same as the model shape, not resize.";
150                 continue;
151             }
152             if (interpreter->ResizeInputTensor(inputShape.first, inputShape.second) != kTfLiteOk) {
153                 LOG(ERROR) << "Fail to resize index " << inputShape.first << ".";
154                 return;
155             } else {
156                 LOG(INFO) << "Susccess to resize index " << inputShape.first << ".";
157             }
158         }
159     }
160 
161     if (interpreter->AllocateTensors() != kTfLiteOk) {
162         LOG(ERROR) << "Failed to allocate tensors!";
163         return;
164     }
165 
166     if (settings.verbose) {
167         PrintInterpreterState(interpreter.get());
168     }
169 }
170 
LogInterpreterParams(Settings & settings,std::unique_ptr<tflite::Interpreter> & interpreter)171 void LogInterpreterParams(Settings& settings, std::unique_ptr<tflite::Interpreter>& interpreter)
172 {
173     if (!interpreter) {
174         LOG(ERROR) << "Failed to construct interpreter";
175         return;
176     }
177 
178     if (settings.verbose) {
179         LOG(INFO) << "tensors size: " << interpreter->tensors_size();
180         LOG(INFO) << "nodes size: " << interpreter->nodes_size();
181         LOG(INFO) << "inputs: " << interpreter->inputs().size();
182         LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);
183 
184         size_t tSize = interpreter->tensors_size();
185         for (size_t i = 0; i < tSize; ++i) {
186             if (interpreter->tensor(i)->name) {
187                 LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " << interpreter->tensor(i)->bytes <<
188                     ", " << interpreter->tensor(i)->type << ", " << interpreter->tensor(i)->params.scale << ", " <<
189                     interpreter->tensor(i)->params.zero_point;
190             }
191         }
192     }
193 }
194 
InferenceModel(Settings & settings,DelegateProviders & delegateProviders)195 void InferenceModel(Settings& settings, DelegateProviders& delegateProviders)
196 {
197     if (!settings.modelName.c_str()) {
198         LOG(ERROR) << "no model file name";
199         return;
200     }
201     std::unique_ptr<tflite::FlatBufferModel> model;
202     std::unique_ptr<tflite::Interpreter> interpreter;
203     model = tflite::FlatBufferModel::BuildFromFile(settings.modelName.c_str());
204     if (!model) {
205         LOG(ERROR) << "Failed to mmap model " << settings.modelName;
206         return;
207     }
208 
209     settings.model = model.get();
210     model->error_reporter();
211     tflite::ops::builtin::BuiltinOpResolver resolver;
212     tflite::InterpreterBuilder(*model, resolver)(&interpreter);
213     if (!interpreter) {
214         LOG(ERROR) << "Failed to construct interpreter, please check the model.";
215         return;
216     }
217 
218     LogInterpreterParams(settings, interpreter);
219 
220     // set settings input type
221     PrepareModel(settings, interpreter, delegateProviders);
222     std::vector<int> imageSize { 224, 224, 3};
223     ImportData(settings, imageSize, interpreter);
224 
225     if (settings.loopCount > 0 && settings.numberOfWarmupRuns > 0) {
226         LOG(INFO) << "Warm-up for " << settings.numberOfWarmupRuns << " times";
227         for (int32_t i = 0; i < settings.numberOfWarmupRuns; ++i) {
228             if (interpreter->Invoke() != kTfLiteOk) {
229                 LOG(ERROR) << "Failed to invoke tflite!";
230                 return;
231             }
232         }
233     }
234 
235     struct timeval startTime, stopTime;
236     LOG(INFO) << "Invoke for " << settings.loopCount << " times";
237     gettimeofday(&startTime, nullptr);
238     for (int32_t i = 0; i < settings.loopCount; ++i) {
239         if (interpreter->Invoke() != kTfLiteOk) {
240             LOG(ERROR) << "Failed to invoke tflite!";
241             return;
242         }
243     }
244 
245     gettimeofday(&stopTime, nullptr);
246     LOG(INFO) << "invoked, average time: " <<
247         (GetUs(stopTime) - GetUs(startTime)) / (settings.loopCount * CONVERSION_RATE) << " ms";
248     AnalysisResults(settings, interpreter);
249 }
250 
DisplayUsage()251 void DisplayUsage()
252 {
253     LOG(INFO) << "label_classify -m xxx.tflite -i xxx.bmp -l xxx.txt -c 1 -a 1\n"
254               << "\t--help,         -h: show the usage of the demo\n"
255               << "\t--use_nnrt,     -a: [0|1], 1 refers to use NNRT\n"
256               << "\t--input_mean,   -b: input mean\n"
257               << "\t--count,        -c: loop interpreter->Invoke() for certain times\n"
258               << "\t--image,        -i: image_name.bmp\n"
259               << "\t--labels,       -l: labels for the model\n"
260               << "\t--tflite_model, -m: modelName.tflite\n"
261               << "\t--num_results,  -n: number of results to show\n"
262               << "\t--input_std,    -s: input standard deviation\n"
263               << "\t--verbose,      -v: [0|1] print more information\n"
264               << "\t--warmup_nums,  -w: number of warmup runs\n"
265               << "\t--print_result, -z: flag to print results\n"
266               << "\t--input_shape,  -p: Indicates the specified dynamic input node and the corresponding shape.\n";
267 }
268 
InitSettings(int32_t argc,char ** argv,Settings & settings)269 int32_t InitSettings(int32_t argc, char** argv, Settings& settings)
270 {
271     // getopt_long stores the option index here.
272     int32_t optionIndex = 0;
273     while ((optionIndex = getopt_long(argc, argv, "a:b:c:h:i:l:m:n:p:s:v:w:z:", LONG_OPTIONS, nullptr)) != -1) {
274         switch (optionIndex) {
275             case 'a':
276                 settings.accel = strtol(optarg, nullptr, BASE_NUMBER);
277                 break;
278             case 'b':
279                 settings.inputMean = strtod(optarg, nullptr);
280                 break;
281             case 'c':
282                 settings.loopCount = strtol(optarg, nullptr, BASE_NUMBER);
283                 break;
284             case 'i':
285                 settings.inputBmpName = optarg;
286                 break;
287             case 'l':
288                 settings.labelsFileName = optarg;
289                 break;
290             case 'm':
291                 settings.modelName = optarg;
292                 break;
293             case 'n':
294                 settings.numberOfResults = strtol(optarg, nullptr, BASE_NUMBER);
295                 break;
296             case 'p':
297                 settings.inputShape = optarg;
298                 break;
299             case 's':
300                 settings.inputStd = strtod(optarg, nullptr);
301                 break;
302             case 'v':
303                 settings.verbose = strtol(optarg, nullptr, BASE_NUMBER);
304                 break;
305             case 'w':
306                 settings.numberOfWarmupRuns = strtol(optarg, nullptr, BASE_NUMBER);
307                 break;
308             case 'z':
309                 settings.printResult = strtol(optarg, nullptr, BASE_NUMBER);
310                 break;
311             case 'h':
312             case '?':
313                 // getopt_long already printed an error message.
314                 DisplayUsage();
315                 return -1;
316             default:
317                 return -1;
318         }
319     }
320 
321     return 0;
322 }
323 
Main(int32_t argc,char ** argv)324 int32_t Main(int32_t argc, char** argv)
325 {
326     if (argc <= 1) {
327         DisplayUsage();
328         return EXIT_FAILURE;
329     }
330 
331     DelegateProviders delegateProviders;
332     bool parseResult = delegateProviders.InitFromCmdlineArgs(&argc, const_cast<const char**>(argv));
333     if (!parseResult) {
334         return EXIT_FAILURE;
335     }
336 
337     Settings settings;
338     if (InitSettings(argc, argv, settings) == -1) {
339         return EXIT_FAILURE;
340     };
341 
342     InferenceModel(settings, delegateProviders);
343     return 0;
344 }
345 } // namespace label_classify
346 } // namespace tflite
347 
main(int32_t argc,char ** argv)348 int32_t main(int32_t argc, char** argv)
349 {
350     return tflite::label_classify::Main(argc, argv);
351 }
352