1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <string>
17 #include <utility>
18 
19 #include "tensorflow/lite/tools/delegates/delegate_provider.h"
20 
21 #include "nnrt_delegate.h"
22 #include "../nnrt/nnrt_implementation.h"
23 
24 namespace tflite {
25 namespace tools {
// Default worker-thread count. NOTE(review): appears unused within this file — confirm callers before removing.
constexpr int32_t DEFAULT_THREADS = 1;
// Default for "max_delegate_num"; per the flag help text, values <= 0 mean no limit on delegated partitions.
constexpr int32_t DEFAULT_DELEGATE_NUM = -1;
28 class NnrtDelegateProvider : public DelegateProvider {
29 public:
NnrtDelegateProvider()30     NnrtDelegateProvider()
31     {
32         default_params_.AddParam("use_nnrt", ToolParam::Create<bool>(false));
33         default_params_.AddParam("performance", ToolParam::Create<std::string>(""));
34         default_params_.AddParam("priority", ToolParam::Create<std::string>(""));
35         default_params_.AddParam("device", ToolParam::Create<std::string>(""));
36         default_params_.AddParam("cache_dir", ToolParam::Create<std::string>(""));
37         default_params_.AddParam("model_token", ToolParam::Create<std::string>(""));
38         default_params_.AddParam("max_delegate_num", ToolParam::Create<int32_t>(DEFAULT_DELEGATE_NUM));
39         default_params_.AddParam("enable_fp16", ToolParam::Create<bool>(false));
40         default_params_.AddParam("allow_dynamic_dimensions", ToolParam::Create<bool>(false));
41     }
42 
~NnrtDelegateProvider()43     ~NnrtDelegateProvider() {};
44 
45     std::vector<Flag> CreateFlags(ToolParams* param) const final;
46 
47     void LogParams(const ToolParams& params, bool verbose) const final;
48 
49     TfLiteDelegatePtr CreateTfLiteDelegate(const ToolParams& params) const final;
50 
51     std::pair<TfLiteDelegatePtr, int32_t> CreateRankedTfLiteDelegate(const ToolParams& params) const final;
52 
GetName() const53     std::string GetName() const final
54     {
55         return "NNRT";
56     }
57 };
58 
// Makes this provider discoverable by the TFLite tooling framework at
// static-initialization time (standard delegate-provider registration hook).
REGISTER_DELEGATE_PROVIDER(NnrtDelegateProvider);
60 
CreateFlags(ToolParams * params) const61 std::vector<Flag> NnrtDelegateProvider::CreateFlags(ToolParams* params) const
62 {
63     std::vector<Flag> flags = {
64         CreateFlag<int32_t>("max_delegate_num", params, "Delegate max num limit, max_delegate_num <= 0 means no limit"),
65         CreateFlag<bool>("enable_fp16", params, "Whether to Infer model with FP16."),
66         CreateFlag<bool>("allow_dynamic_dimensions", params,
67             "Whether to allow dynamic dimension sizes without re-compilation."),
68         CreateFlag<std::string>("performance", params,
69         "Execution performance for nnrt delegate. "
70         "choose within [low, medium, high, extreme, default]."),
71         CreateFlag<std::string>("priority", params,
72         "The model execution priority in nnrt, and it "
73         "choose within [default, low, medium, high]."),
74         CreateFlag<std::string>("device", params,
75         "The name of the nnrt accelerator to use, "
76         "choose within [cpu, gpu, apu, nnrt-reference], "
77         "nnrt-reference means chosen automatically by nnrt."),
78         CreateFlag<std::string>("cache_dir", params, "The directory of load and save cache for delegate"),
79         CreateFlag<std::string>("model_token", params, "The file_name of load and save cache for delegate"),
80     };
81     return flags;
82 }
83 
LogParams(const ToolParams & params,bool verbose) const84 void NnrtDelegateProvider::LogParams(const ToolParams& params, bool verbose) const
85 {
86     LOG_TOOL_PARAM(params, bool, "use_nnrt", "Use NNRT", verbose);
87     if (!params.Get<bool>("use_nnrt")) {
88         return; // no use nnrt, return.
89     }
90 
91     LOG_TOOL_PARAM(params, std::string, "performance", "NNRT execution performance", verbose);
92     LOG_TOOL_PARAM(params, std::string, "priority", "NNRT execution priority", verbose);
93     LOG_TOOL_PARAM(params, std::string, "device", "NNRT accelerator name", verbose);
94     LOG_TOOL_PARAM(params, std::string, "cache_dir", "NNRT model cache directory", verbose);
95     LOG_TOOL_PARAM(params, std::string, "model_token", "NNRT model cache filename", verbose);
96     LOG_TOOL_PARAM(params, int32_t, "max_delegate_num", "NNRT delegate max partition", verbose);
97     LOG_TOOL_PARAM(params, bool, "enable_fp16", "NNRT allow fp16 inference", verbose);
98     LOG_TOOL_PARAM(params, bool, "allow_dynamic_dimensions", "NNRT allow dynamic dimensions", verbose);
99 }
100 
GetExecutionPerformance(const ToolParams & params,NnrtDelegate::Options & options)101 TfLiteStatus GetExecutionPerformance(const ToolParams& params, NnrtDelegate::Options& options)
102 {
103     std::string stringExecutionPerformance = params.Get<std::string>("performance");
104     if (stringExecutionPerformance.empty()) {
105         return kTfLiteOk; // no set performance
106     }
107 
108     OH_NN_PerformanceMode executionPerformance = OH_NN_PERFORMANCE_NONE;
109     if (stringExecutionPerformance == "low") {
110         executionPerformance = OH_NN_PERFORMANCE_LOW;
111     } else if (stringExecutionPerformance == "medium") {
112         executionPerformance = OH_NN_PERFORMANCE_MEDIUM;
113     } else if (stringExecutionPerformance == "high") {
114         executionPerformance = OH_NN_PERFORMANCE_HIGH;
115     } else if (stringExecutionPerformance == "extreme") {
116         executionPerformance = OH_NN_PERFORMANCE_EXTREME;
117     } else if (stringExecutionPerformance == "default") {
118         executionPerformance = OH_NN_PERFORMANCE_NONE;
119     } else {
120         TFLITE_LOG(ERROR) << "The provided value is not a valid nnrt execution performance.";
121         return kTfLiteError;
122     }
123     options.executionPerformance = executionPerformance;
124 
125     return kTfLiteOk;
126 }
127 
GetExecutionPriority(const ToolParams & params,NnrtDelegate::Options & options)128 TfLiteStatus GetExecutionPriority(const ToolParams& params, NnrtDelegate::Options& options)
129 {
130     std::string stringExecutionPriority = params.Get<std::string>("priority");
131     if (stringExecutionPriority.empty()) {
132         return kTfLiteOk; // no set priority
133     }
134 
135     OH_NN_Priority executionPriority = OH_NN_PRIORITY_MEDIUM;
136     if (stringExecutionPriority == "low") {
137         executionPriority = OH_NN_PRIORITY_LOW;
138     } else if (stringExecutionPriority == "medium") {
139         executionPriority = OH_NN_PRIORITY_MEDIUM;
140     } else if (stringExecutionPriority == "high") {
141         executionPriority = OH_NN_PRIORITY_HIGH;
142     } else if (stringExecutionPriority == "default") {
143         executionPriority = OH_NN_PRIORITY_MEDIUM;
144     } else {
145         TFLITE_LOG(ERROR) << "The provided value is not a valid nnrt execution priority.";
146         return kTfLiteError;
147     }
148     options.executionPriority = executionPriority;
149 
150     return kTfLiteOk;
151 }
152 
MapParams(const ToolParams & params,NnrtDelegate::Options & options)153 TfLiteStatus MapParams(const ToolParams& params, NnrtDelegate::Options& options)
154 {
155     std::string acceleratorName = params.Get<std::string>("device");
156     if (!acceleratorName.empty()) {
157         options.acceleratorName = acceleratorName;
158     }
159 
160     if (params.GetParam("max_delegate_num") != nullptr) {
161         options.maxNumberDelegatedPartitions = params.Get<int32_t>("max_delegate_num");
162     }
163 
164     std::string cacheDir = params.Get<std::string>("cache_dir");
165     if (!cacheDir.empty()) {
166         options.cacheDir = cacheDir;
167     }
168 
169     std::string modelToken = params.Get<std::string>("model_token");
170     if (!modelToken.empty()) {
171         options.modelToken = modelToken;
172     }
173 
174     if (params.Get<bool>("enable_fp16")) {
175         options.enableFp16 = true;
176     }
177 
178     if (params.Get<bool>("allow_dynamic_dimensions")) {
179         options.allowDynamicDimensions = true;
180     }
181 
182     return kTfLiteOk;
183 }
184 
CreateTfLiteDelegate(const ToolParams & params) const185 TfLiteDelegatePtr NnrtDelegateProvider::CreateTfLiteDelegate(const ToolParams& params) const
186 {
187     TfLiteDelegatePtr delegate(nullptr, [](TfLiteDelegate*) {});
188     if (!params.Get<bool>("use_nnrt")) {
189         return delegate;
190     }
191 
192     NnrtDelegate::Options options;
193     TFLITE_TOOLS_CHECK(MapParams(params, options) == kTfLiteOk) << "Map params to NNRT Delegate options failed.";
194     TFLITE_TOOLS_CHECK(GetExecutionPerformance(params, options) == kTfLiteOk) <<
195         "Create TfLite NNRT Delegate failed.";
196     TFLITE_TOOLS_CHECK(GetExecutionPriority(params, options) == kTfLiteOk) << "Create TfLite NNRT Delegate failed.";
197 
198     const auto* nnrtImpl = NnrtImplementation();
199     if (!nnrtImpl->nnrtExists) {
200         TFLITE_LOG(WARN) << "NNRT acceleration is unsupported on this platform.";
201         return delegate;
202     }
203 
204     return TfLiteDelegatePtr(new (std::nothrow) NnrtDelegate(nnrtImpl, options),
205         [](TfLiteDelegate* delegate) { delete reinterpret_cast<NnrtDelegate*>(delegate); });
206 }
207 
CreateRankedTfLiteDelegate(const ToolParams & params) const208 std::pair<TfLiteDelegatePtr, int32_t> NnrtDelegateProvider::CreateRankedTfLiteDelegate(const ToolParams& params) const
209 {
210     auto ptr = CreateTfLiteDelegate(params);
211     LogParams(params, false);
212     return std::make_pair(std::move(ptr), params.GetPosition<bool>("use_nnrt"));
213 }
214 } // namespace tools
215 } // namespace tflite