1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "interfaces/kits/c/neural_network_runtime/neural_network_core.h"
17 
18 #include <string>
19 #include <securec.h>
20 #include <utility>
21 #include <unordered_map>
22 #include <future>
23 #include <thread>
24 
25 #include "common/log.h"
26 #include "executor.h"
27 #include "tensor.h"
28 #include "compilation.h"
29 #include "backend_manager.h"
30 #include "nnrt_client.h"
31 
32 using namespace OHOS::NeuralNetworkRuntime;
33 #define NNRT_API __attribute__((visibility("default")))
34 
OH_NNDevice_GetAllDevicesID(const size_t ** allDevicesID,uint32_t * deviceCount)35 NNRT_API OH_NN_ReturnCode OH_NNDevice_GetAllDevicesID(const size_t **allDevicesID, uint32_t *deviceCount)
36 {
37     if (allDevicesID == nullptr) {
38         LOGE("OH_NNDevice_GetAllDevicesID failed, passed nullptr to allDevicesID.");
39         return OH_NN_INVALID_PARAMETER;
40     }
41 
42     if ((*allDevicesID) != nullptr) {
43         LOGE("OH_NNDevice_GetAllDevicesID failed, *allDevicesID should be nullptr.");
44         return OH_NN_INVALID_PARAMETER;
45     }
46 
47     if (deviceCount == nullptr) {
48         LOGE("OH_NNDevice_GetAllDevicesID failed, passed nullptr to deviceCount.");
49         return OH_NN_INVALID_PARAMETER;
50     }
51 
52     BackendManager& backendManager = BackendManager::GetInstance();
53     const std::vector<size_t>& allDevices = backendManager.GetAllBackendsID();
54 
55     if (allDevices.empty()) {
56         LOGW("OH_NNDevice_GetAllDevicesID got no device.");
57         *allDevicesID = nullptr;
58         *deviceCount = 0;
59         return OH_NN_SUCCESS;
60     }
61 
62     *allDevicesID = allDevices.data();
63     // allDevices.size() will not exceed UINT32_MAX, it is safe to cast to uint32_t.
64     *deviceCount = static_cast<uint32_t>(allDevices.size());
65 
66     return OH_NN_SUCCESS;
67 }
68 
OH_NNDevice_GetName(size_t deviceID,const char ** name)69 NNRT_API OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name)
70 {
71     if (name == nullptr) {
72         LOGE("OH_NNDevice_GetName failed, passed nullptr to name.");
73         return OH_NN_INVALID_PARAMETER;
74     }
75 
76     if ((*name) != nullptr) {
77         LOGE("OH_NNDevice_GetName failed, *name should be nullptr.");
78         return OH_NN_INVALID_PARAMETER;
79     }
80 
81     BackendManager& backendManager = BackendManager::GetInstance();
82     const std::string& backendName = backendManager.GetBackendName(deviceID);
83     if (backendName.empty()) {
84         LOGE("OH_NNDevice_GetName failed, error happened when getting name of deviceID.");
85         *name = nullptr;
86         return OH_NN_FAILED;
87     }
88 
89     *name = backendName.data();
90     return OH_NN_SUCCESS;
91 }
92 
OH_NNDevice_GetType(size_t deviceID,OH_NN_DeviceType * deviceType)93 NNRT_API OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType* deviceType)
94 {
95     BackendManager& backendManager = BackendManager::GetInstance();
96     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
97     if (backend == nullptr) {
98         LOGE("OH_NNDevice_GetType failed, passed invalid deviceID.");
99         return OH_NN_INVALID_PARAMETER;
100     }
101 
102     if (deviceType == nullptr) {
103         LOGE("OH_NNDevice_GetType failed, passed nullptr to deviceType.");
104         return OH_NN_INVALID_PARAMETER;
105     }
106 
107     OH_NN_ReturnCode ret = backend->GetBackendType(*deviceType);
108     if (ret != OH_NN_SUCCESS) {
109         LOGE("OH_NNDevice_GetType failed.");
110         return ret;
111     }
112     return OH_NN_SUCCESS;
113 }
114 
OH_NNCompilation_Construct(const OH_NNModel * model)115 NNRT_API OH_NNCompilation *OH_NNCompilation_Construct(const OH_NNModel *model)
116 {
117     if (model == nullptr) {
118         LOGE("OH_NNCompilation_Construct failed, passed nullptr to model.");
119         return nullptr;
120     }
121 
122     Compilation *compilation = new (std::nothrow) Compilation();
123     if (compilation == nullptr) {
124         LOGE("OH_NNCompilation_Construct failed, please check whether it has enough memory.");
125         return nullptr;
126     }
127 
128     compilation->nnModel = const_cast<void*>(reinterpret_cast<const void*>(model));
129 
130     OH_NNCompilation* nnCompilation = reinterpret_cast<OH_NNCompilation*>(compilation);
131     return nnCompilation;
132 }
133 
OH_NNCompilation_ConstructWithOfflineModelFile(const char * modelPath)134 NNRT_API OH_NNCompilation *OH_NNCompilation_ConstructWithOfflineModelFile(const char *modelPath)
135 {
136     if (modelPath == nullptr) {
137         LOGE("OH_NNCompilation_ConstructWithOfflineModelFile failed, passed nullptr to modelPath.");
138         return nullptr;
139     }
140 
141     Compilation *compilation = new (std::nothrow) Compilation();
142     if (compilation == nullptr) {
143         LOGE("OH_NNCompilation_ConstructWithOfflineModelFile failed, please check whether it has enough memory.");
144         return nullptr;
145     }
146 
147     compilation->offlineModelPath = const_cast<char*>(modelPath);
148     OH_NNCompilation* nnCompilation = reinterpret_cast<OH_NNCompilation*>(compilation);
149 
150     return nnCompilation;
151 }
152 
OH_NNCompilation_ConstructWithOfflineModelBuffer(const void * modelBuffer,size_t modelSize)153 NNRT_API OH_NNCompilation *OH_NNCompilation_ConstructWithOfflineModelBuffer(const void *modelBuffer, size_t modelSize)
154 {
155     if (modelBuffer == nullptr) {
156         LOGE("OH_NNCompilation_ConstructWithOfflineModelBuffer failed, modelBuffer is nullptr.");
157         return nullptr;
158     }
159 
160     if (modelSize == static_cast<size_t>(0)) {
161         LOGE("OH_NNCompilation_ConstructWithOfflineModelBuffer failed, modelSize is 0.");
162         return nullptr;
163     }
164 
165     Compilation *compilation = new (std::nothrow) Compilation();
166     if (compilation == nullptr) {
167         LOGE("OH_NNCompilation_ConstructWithOfflineModelBuffer failed, please check whether it has enough memory.");
168         return nullptr;
169     }
170 
171     compilation->offlineModelBuffer.first = const_cast<void*>(modelBuffer);
172     compilation->offlineModelBuffer.second = modelSize;
173     OH_NNCompilation* nnCompilation = reinterpret_cast<OH_NNCompilation*>(compilation);
174 
175     return nnCompilation;
176 }
177 
OH_NNCompilation_ConstructForCache()178 NNRT_API OH_NNCompilation *OH_NNCompilation_ConstructForCache()
179 {
180     Compilation *compilation = new (std::nothrow) Compilation();
181     if (compilation == nullptr) {
182         LOGE("OH_NNCompilation_ConstructForCache failed, please check whether it has enough memory.");
183         return nullptr;
184     }
185 
186     OH_NNCompilation* nnCompilation = reinterpret_cast<OH_NNCompilation*>(compilation);
187     return nnCompilation;
188 }
189 
OH_NNCompilation_ExportCacheToBuffer(OH_NNCompilation * compilation,const void * buffer,size_t length,size_t * modelSize)190 NNRT_API OH_NN_ReturnCode OH_NNCompilation_ExportCacheToBuffer(OH_NNCompilation *compilation,
191                                                                const void *buffer,
192                                                                size_t length,
193                                                                size_t *modelSize)
194 {
195     if (compilation == nullptr) {
196         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, compilation is nullptr.");
197         return OH_NN_INVALID_PARAMETER;
198     }
199 
200     if (buffer == nullptr) {
201         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, buffer is nullptr.");
202         return OH_NN_INVALID_PARAMETER;
203     }
204 
205     if (length == static_cast<size_t>(0)) {
206         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, pass length equals to 0.");
207         return OH_NN_INVALID_PARAMETER;
208     }
209 
210     if (modelSize == nullptr) {
211         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, modelSize is nullptr.");
212         return OH_NN_INVALID_PARAMETER;
213     }
214 
215     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
216     if (compilationImpl->compiler == nullptr) {
217         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, should call OH_NNCompilation_Build before export cache.");
218         return OH_NN_INVALID_PARAMETER;
219     }
220 
221     OH_NN_ReturnCode ret = compilationImpl->compiler->SaveToCacheBuffer(buffer, length, modelSize);
222     if (ret != OH_NN_SUCCESS) {
223         LOGE("OH_NNCompilation_ExportCacheToBuffer failed, fail to save cache to buffer.");
224     }
225 
226     return ret;
227 }
228 
OH_NNCompilation_ImportCacheFromBuffer(OH_NNCompilation * compilation,const void * buffer,size_t modelSize)229 NNRT_API OH_NN_ReturnCode OH_NNCompilation_ImportCacheFromBuffer(OH_NNCompilation *compilation,
230                                                                  const void *buffer,
231                                                                  size_t modelSize)
232 {
233     if (compilation == nullptr) {
234         LOGE("OH_NNCompilation_ImportCacheFromBuffer failed, compilation is nullptr.");
235         return OH_NN_INVALID_PARAMETER;
236     }
237 
238     if (buffer == nullptr) {
239         LOGE("OH_NNCompilation_ImportCacheFromBuffer failed, buffer is nullptr.");
240         return OH_NN_INVALID_PARAMETER;
241     }
242 
243     if (modelSize == static_cast<size_t>(0)) {
244         LOGE("OH_NNCompilation_ImportCacheFromBuffer failed, modelSize is 0.");
245         return OH_NN_INVALID_PARAMETER;
246     }
247 
248     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
249     compilationImpl->offlineModelBuffer.first = const_cast<void*>(buffer);
250     compilationImpl->offlineModelBuffer.second = modelSize;
251 
252     return OH_NN_SUCCESS;
253 }
254 
OH_NNCompilation_AddExtensionConfig(OH_NNCompilation * compilation,const char * configName,const void * configValue,const size_t configValueSize)255 NNRT_API OH_NN_ReturnCode OH_NNCompilation_AddExtensionConfig(OH_NNCompilation *compilation,
256                                                               const char *configName,
257                                                               const void *configValue,
258                                                               const size_t configValueSize)
259 {
260     if (compilation == nullptr) {
261         LOGE("OH_NNCompilation_AddExtensionConfig failed, compilation is nullptr.");
262         return OH_NN_INVALID_PARAMETER;
263     }
264 
265     if (configName == nullptr) {
266         LOGE("OH_NNCompilation_AddExtensionConfig failed, configName is nullptr.");
267         return OH_NN_INVALID_PARAMETER;
268     }
269 
270     if (configValue == nullptr) {
271         LOGE("OH_NNCompilation_AddExtensionConfig failed, configValue is nullptr.");
272         return OH_NN_INVALID_PARAMETER;
273     }
274 
275     if (configValueSize == static_cast<size_t>(0)) {
276         LOGE("OH_NNCompilation_AddExtensionConfig failed, configValueSize is 0.");
277         return OH_NN_INVALID_PARAMETER;
278     }
279 
280     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
281 
282     std::string configNameStr = configName;
283     if (configNameStr.empty()) {
284         LOGE("OH_NNCompilation_AddExtensionConfig failed, configName is empty.");
285         return OH_NN_INVALID_PARAMETER;
286     }
287 
288     std::vector<char> configValueVec(configValueSize, '0');
289     void* configValueAddr = reinterpret_cast<void*>(configValueVec.data());
290     errno_t ret = memcpy_s(configValueAddr, configValueVec.size(), configValue, configValueSize);
291     if (ret != EOK) {
292         LOGE("OH_NNCompilation_AddExtensionConfig failed, copy config value failed.");
293         return OH_NN_FAILED;
294     }
295 
296     auto iter = compilationImpl->configs.find(configNameStr);
297     if (iter == compilationImpl->configs.end()) {
298         compilationImpl->configs.emplace(configNameStr, configValueVec);
299     } else {
300         iter->second.emplace_back('|');
301         iter->second.insert(iter->second.end(), configValueVec.begin(), configValueVec.end());
302     }
303 
304     return OH_NN_SUCCESS;
305 }
306 
OH_NNCompilation_SetDevice(OH_NNCompilation * compilation,size_t deviceID)307 NNRT_API OH_NN_ReturnCode OH_NNCompilation_SetDevice(OH_NNCompilation *compilation, size_t deviceID)
308 {
309     if (compilation == nullptr) {
310         LOGE("OH_NNCompilation_SetDevice failed, compilation is nullptr.");
311         return OH_NN_INVALID_PARAMETER;
312     }
313 
314     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
315     compilationImpl->backendID = deviceID;
316 
317     return OH_NN_SUCCESS;
318 }
319 
OH_NNCompilation_SetCache(OH_NNCompilation * compilation,const char * cachePath,uint32_t version)320 NNRT_API OH_NN_ReturnCode OH_NNCompilation_SetCache(OH_NNCompilation *compilation,
321                                                     const char *cachePath,
322                                                     uint32_t version)
323 {
324     if (compilation == nullptr) {
325         LOGE("OH_NNCompilation_SetCache failed, compilation is nullptr.");
326         return OH_NN_INVALID_PARAMETER;
327     }
328 
329     if (cachePath == nullptr) {
330         LOGE("OH_NNCompilation_SetCache failed, cachePath is nullptr.");
331         return OH_NN_INVALID_PARAMETER;
332     }
333 
334     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
335     compilationImpl->cachePath = const_cast<char*>(cachePath);
336     compilationImpl->cacheVersion = version;
337 
338     return OH_NN_SUCCESS;
339 }
340 
OH_NNCompilation_SetPerformanceMode(OH_NNCompilation * compilation,OH_NN_PerformanceMode performanceMode)341 NNRT_API OH_NN_ReturnCode OH_NNCompilation_SetPerformanceMode(OH_NNCompilation *compilation,
342                                                               OH_NN_PerformanceMode performanceMode)
343 {
344     if (compilation == nullptr) {
345         LOGE("OH_NNCompilation_SetPerformanceMode failed, compilation is nullptr.");
346         return OH_NN_INVALID_PARAMETER;
347     }
348 
349     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
350     compilationImpl->performance = performanceMode;
351 
352     if (compilationImpl->compiler != nullptr) {
353         OH_NN_ReturnCode ret = compilationImpl->compiler->SetPerformance(performanceMode);
354         if (ret != OH_NN_SUCCESS) {
355             LOGE("OH_NNCompilation_SetPerformanceMode failed.");
356             return ret;
357         }
358     }
359 
360     return OH_NN_SUCCESS;
361 }
362 
OH_NNCompilation_SetPriority(OH_NNCompilation * compilation,OH_NN_Priority priority)363 NNRT_API OH_NN_ReturnCode OH_NNCompilation_SetPriority(OH_NNCompilation *compilation, OH_NN_Priority priority)
364 {
365     if (compilation == nullptr) {
366         LOGE("OH_NNCompilation_SetPriority failed, compilation is nullptr.");
367         return OH_NN_INVALID_PARAMETER;
368     }
369 
370     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
371     compilationImpl->priority = priority;
372 
373     if (compilationImpl->compiler != nullptr) {
374         OH_NN_ReturnCode ret = compilationImpl->compiler->SetPriority(priority);
375         if (ret != OH_NN_SUCCESS) {
376             LOGE("OH_NNCompilation_SetPriority failed.");
377             return ret;
378         }
379     }
380 
381     return OH_NN_SUCCESS;
382 }
383 
OH_NNCompilation_EnableFloat16(OH_NNCompilation * compilation,bool enableFloat16)384 NNRT_API OH_NN_ReturnCode OH_NNCompilation_EnableFloat16(OH_NNCompilation *compilation, bool enableFloat16)
385 {
386     if (compilation == nullptr) {
387         LOGE("OH_NNCompilation_EnableFloat16 failed, compilation is nullptr.");
388         return OH_NN_INVALID_PARAMETER;
389     }
390 
391     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
392     compilationImpl->enableFp16 = enableFloat16;
393 
394     return OH_NN_SUCCESS;
395 }
396 
CreateCompiler(Compilation * compilation,Compiler ** compiler)397 OH_NN_ReturnCode CreateCompiler(Compilation* compilation, Compiler** compiler)
398 {
399     if (compilation == nullptr) {
400         LOGE("CreateCompiler failed, compilation is nullptr.");
401         return OH_NN_INVALID_PARAMETER;
402     }
403 
404     if (compiler == nullptr) {
405         LOGE("CreateCompiler failed, compiler is nullptr.");
406         return OH_NN_INVALID_PARAMETER;
407     }
408 
409     BackendManager& manager = BackendManager::GetInstance();
410     std::shared_ptr<Backend> backend = manager.GetBackend(compilation->backendID);
411     if (backend == nullptr) {
412         LOGE("CreateCompiler failed, fail to get backend %{public}zu.", compilation->backendID);
413         return OH_NN_FAILED;
414     }
415 
416     *compiler = backend->CreateCompiler(compilation);
417     if (*compiler == nullptr) {
418         LOGE("CreateCompiler failed, fail to create compiler.");
419         return OH_NN_FAILED;
420     }
421 
422     return OH_NN_SUCCESS;
423 }
424 
SetCompilationOptions(Compilation * compilation)425 OH_NN_ReturnCode SetCompilationOptions(Compilation* compilation)
426 {
427     if (compilation == nullptr) {
428         LOGE("SetCompilationOptions failed, compilation is nullptr.");
429         return OH_NN_INVALID_PARAMETER;
430     }
431 
432     if (compilation->compiler == nullptr) {
433         LOGE("SetCompilationOptions failed, compiler is nullptr.");
434         return OH_NN_INVALID_PARAMETER;
435     }
436 
437     OH_NN_ReturnCode ret = OH_NN_SUCCESS;
438     if (compilation->cachePath != nullptr) {
439         ret = compilation->compiler->SetCacheDir(compilation->cachePath, compilation->cacheVersion);
440         if (ret != OH_NN_SUCCESS) {
441             LOGE("SetCompilationOptions failed, fail to set cache dir.");
442             return ret;
443         }
444     }
445 
446     ret = compilation->compiler->SetEnableFp16(compilation->enableFp16);
447     if (ret != OH_NN_SUCCESS) {
448         LOGE("SetCompilationOptions failed, fail to set enable fp16.");
449         return ret;
450     }
451 
452     ret = compilation->compiler->SetPerformance(compilation->performance);
453     if (ret != OH_NN_SUCCESS) {
454         LOGE("SetCompilationOptions failed, fail to set performance.");
455         return ret;
456     }
457 
458     ret = compilation->compiler->SetPriority(compilation->priority);
459     if (ret != OH_NN_SUCCESS) {
460         LOGE("SetCompilationOptions failed, fail to set priority.");
461         return ret;
462     }
463 
464     ret = compilation->compiler->SetExtensionConfig(compilation->configs);
465     if ((ret != OH_NN_SUCCESS) && (ret != OH_NN_UNSUPPORTED)) {
466         LOGE("SetCompilationOptions failed, fail to set extenstion configs.");
467         return ret;
468     }
469 
470     ret = compilation->compiler->SetOptions(compilation->options);
471     if ((ret != OH_NN_SUCCESS) && (ret != OH_NN_UNSUPPORTED)) {
472         LOGE("SetCompilationOptions failed, fail to set extenstion options.");
473         return ret;
474     }
475 
476     return OH_NN_SUCCESS;
477 }
478 
CheckExceedRamLimit(const Compilation * compilation,bool & isExceedRamLimit)479 OH_NN_ReturnCode CheckExceedRamLimit(const Compilation* compilation, bool& isExceedRamLimit)
480 {
481     if (compilation == nullptr) {
482         LOGE("CheckExceedRamLimit failed, compilation is nullptr.");
483         return OH_NN_INVALID_PARAMETER;
484     }
485 
486     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
487     if (!nnrtService.IsServiceAvaliable()) {
488         LOGW("CheckExceedRamLimit failed, fail to get nnrt service, skip check exceed ram limit.");
489         return OH_NN_SUCCESS;
490     }
491 
492     if (nnrtService.CheckModelSizeFromBuffer == nullptr) {
493         LOGE("CheckExceedRamLimit failed, nnrtService CheckModelSizeFromBuffer func is nullptr.");
494         return OH_NN_INVALID_PARAMETER;
495     }
496 
497     if (nnrtService.CheckModelSizeFromModel == nullptr) {
498         LOGE("CheckExceedRamLimit failed, nnrtService CheckModelSizeFromModel func is nullptr.");
499         return OH_NN_INVALID_PARAMETER;
500     }
501 
502     if (nnrtService.CheckModelSizeFromPath == nullptr) {
503         LOGE("CheckExceedRamLimit failed, nnrtService CheckModelSizeFromPath func is nullptr.");
504         return OH_NN_INVALID_PARAMETER;
505     }
506 
507     int ret = static_cast<OH_NN_ReturnCode>(OH_NN_SUCCESS);
508     if (compilation->nnModel != nullptr) {
509         ret = nnrtService.CheckModelSizeFromModel(compilation->nnModel, isExceedRamLimit);
510     } else if (compilation->offlineModelPath != nullptr) {
511         ret = nnrtService.CheckModelSizeFromPath(compilation->offlineModelPath, isExceedRamLimit);
512     } else if (compilation->cachePath != nullptr) {
513         ret = nnrtService.CheckModelSizeFromPath(compilation->cachePath, isExceedRamLimit);
514     } else if ((compilation->offlineModelBuffer.first != nullptr) && \
515                (compilation->offlineModelBuffer.second != size_t(0))) {
516         ret = nnrtService.CheckModelSizeFromBuffer(
517             compilation->offlineModelBuffer.first, compilation->offlineModelBuffer.second, isExceedRamLimit);
518     } else if ((compilation->cacheBuffer.first != nullptr) && \
519                (compilation->cacheBuffer.second != size_t(0))) {
520         ret = nnrtService.CheckModelSizeFromBuffer(
521             compilation->cacheBuffer.first, compilation->cacheBuffer.second, isExceedRamLimit);
522     } else {
523         LOGE("CheckExceedRamLimit failed, no available model to check.");
524         return OH_NN_INVALID_PARAMETER;
525     }
526 
527     if (ret != static_cast<OH_NN_ReturnCode>(OH_NN_SUCCESS)) {
528         LOGE("CheckExceedRamLimit failed, some error happened when check if model exceed ram limit.");
529         return OH_NN_INVALID_PARAMETER;
530     }
531 
532     return OH_NN_SUCCESS;
533 }
534 
AuthenticateModel(const Compilation * compilation)535 OH_NN_ReturnCode AuthenticateModel(const Compilation* compilation)
536 {
537     bool isExceedRamLimit = false;
538     OH_NN_ReturnCode retCode = CheckExceedRamLimit(compilation, isExceedRamLimit);
539     if (retCode != OH_NN_SUCCESS) {
540         LOGE("AuthenticateModel failed, fail to check if model exceed ram limit.");
541         return retCode;
542     }
543 
544     if (!isExceedRamLimit) {
545         LOGI("Model accupy memory less then limit, no need authenticating.");
546         return OH_NN_SUCCESS; // If model ram is less than max limit, no need authenticating.
547     }
548 
549     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
550     if (!nnrtService.IsServiceAvaliable()) {
551         LOGW("AuthenticateModel failed, fail to get nnrt service, skip authenticating.");
552         return OH_NN_SUCCESS;
553     }
554 
555     if (nnrtService.IsSupportAuthentication == nullptr) {
556         LOGE("Authentication failed, nnrtService IsSupportAuthentication func is nullptr.");
557         return OH_NN_INVALID_PARAMETER;
558     }
559 
560     bool supportStat = false;
561     int ret = nnrtService.IsSupportAuthentication(&supportStat);
562     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
563         LOGE("Authentication failed, some error happened when judge if support authenticating.");
564         return static_cast<OH_NN_ReturnCode>(ret);
565     }
566 
567     if (!supportStat) {
568         LOGW("device not support authenticating, jumper over authenticating model.");
569         return OH_NN_SUCCESS;
570     }
571 
572     if (nnrtService.Authentication == nullptr) {
573         LOGE("Authentication failed, nnrtService Authentication func is nullptr.");
574         return OH_NN_INVALID_PARAMETER;
575     }
576     ret = nnrtService.Authentication(compilation->callingPid);
577     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
578         LOGE("Authentication failed, input model cannot run by npu.");
579         return static_cast<OH_NN_ReturnCode>(ret);
580     }
581 
582     return OH_NN_SUCCESS;
583 }
584 
Authentication(Compilation ** compilation)585 OH_NN_ReturnCode Authentication(Compilation** compilation)
586 {
587     if (compilation == nullptr) {
588         LOGE("Authentication failed, compilation is nullptr.");
589         return OH_NN_INVALID_PARAMETER;
590     }
591 
592     Compilation* compilationImpl = *compilation;
593     if (compilationImpl == nullptr) {
594         LOGE("Authentication failed, compilation implementation is nullptr.");
595         return OH_NN_INVALID_PARAMETER;
596     }
597 
598     auto iter = compilationImpl->configs.find("callingPid");
599     if (iter == compilationImpl->configs.end()) {
600         LOGE("missing 'callingPid' parameter in compilation configs.");
601     } else {
602         compilationImpl->callingPid = std::atoi((iter->second).data());
603     }
604 
605     const NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
606     if (!nnrtService.IsServiceAvaliable()) {
607         LOGW("Authentication failed, fail to get nnrt service, skip Authentication.");
608         return OH_NN_SUCCESS;
609     }
610 
611     OH_NN_ReturnCode ret = AuthenticateModel(compilationImpl);
612     if (ret != OH_NN_SUCCESS) {
613         LOGE("Authentication failed, fail to authenticate model.");
614         return ret;
615     }
616 
617     LOGI("Authentication success.");
618     return OH_NN_SUCCESS;
619 }
620 
621 namespace {
GetNnrtModelId(Compilation * compilationImpl,NNRtServiceApi & nnrtService)622 OH_NN_ReturnCode GetNnrtModelId(Compilation* compilationImpl, NNRtServiceApi& nnrtService)
623 {
624     std::string modelName;
625     OH_NN_ReturnCode retCode = compilationImpl->compiler->GetModelName(modelName);
626     if (retCode != OH_NN_SUCCESS) {
627         LOGW("GetModelName is failed.");
628     }
629 
630     if (compilationImpl->nnModel != nullptr) {
631         compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath,
632             modelName.c_str());
633         if (compilationImpl->nnrtModelID == 0) {
634             compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromModel(compilationImpl->nnModel);
635         }
636     } else if (compilationImpl->offlineModelPath != nullptr) {
637         compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromPath(compilationImpl->offlineModelPath);
638     } else if (compilationImpl->cachePath != nullptr) {
639         compilationImpl->nnrtModelID =
640             nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, modelName.c_str());
641     } else if ((compilationImpl->offlineModelBuffer.first != nullptr) && \
642                (compilationImpl->offlineModelBuffer.second != size_t(0))) {
643         compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer(
644             compilationImpl->offlineModelBuffer.first, compilationImpl->offlineModelBuffer.second);
645     } else if ((compilationImpl->cacheBuffer.first != nullptr) && \
646                (compilationImpl->cacheBuffer.second != size_t(0))) {
647         compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer(
648             compilationImpl->cacheBuffer.first, compilationImpl->cacheBuffer.second);
649     } else {
650         LOGE("GetModelId failed, no available model to set modelId, please check.");
651         return OH_NN_INVALID_PARAMETER;
652     }
653 
654     return OH_NN_SUCCESS;
655 }
656 }
657 
GetModelId(Compilation ** compilation)658 OH_NN_ReturnCode GetModelId(Compilation** compilation)
659 {
660     if (compilation == nullptr) {
661         LOGE("GetModelId failed, compilation is nullptr.");
662         return OH_NN_INVALID_PARAMETER;
663     }
664 
665     Compilation* compilationImpl = *compilation;
666     if (compilationImpl == nullptr) {
667         LOGE("GetModelId failed, compilation implementation is nullptr.");
668         return OH_NN_INVALID_PARAMETER;
669     }
670 
671     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
672     if (!nnrtService.IsServiceAvaliable()) {
673         LOGW("GetModelId failed, fail to get nnrt service, skip get modelId.");
674         return OH_NN_SUCCESS;
675     }
676 
677     if (nnrtService.GetNNRtModelIDFromPath == nullptr) {
678         LOGE("GetModelId failed, nnrtService GetNNRtModelIDFromPath func is nullptr.");
679         return OH_NN_INVALID_PARAMETER;
680     }
681 
682     if (nnrtService.GetNNRtModelIDFromBuffer == nullptr) {
683         LOGE("GetModelId failed, nnrtService GetNNRtModelIDFromBuffer func is nullptr.");
684         return OH_NN_INVALID_PARAMETER;
685     }
686 
687     if (nnrtService.GetNNRtModelIDFromModel == nullptr) {
688         LOGE("GetModelId failed, nnrtService GetNNRtModelIDFromModel func is nullptr.");
689         return OH_NN_INVALID_PARAMETER;
690     }
691 
692     auto ret = GetNnrtModelId(compilationImpl, nnrtService);
693     if (ret != OH_NN_SUCCESS) {
694         LOGE("GetNnrtModelId is failed.");
695         return ret;
696     }
697 
698     return OH_NN_SUCCESS;
699 }
700 
OH_NNCompilation_Build(OH_NNCompilation * compilation)701 NNRT_API OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation)
702 {
703     if (compilation == nullptr) {
704         LOGE("OH_NNCompilation_Build failed, compilation is nullptr.");
705         return OH_NN_INVALID_PARAMETER;
706     }
707 
708     Compilation* compilationImpl = reinterpret_cast<Compilation*>(compilation);
709 
710     if (((compilationImpl->nnModel != nullptr) && (compilationImpl->offlineModelPath != nullptr)) ||
711         ((compilationImpl->nnModel != nullptr) &&
712          ((compilationImpl->offlineModelBuffer.first != nullptr) ||
713           (compilationImpl->offlineModelBuffer.second != static_cast<size_t>(0)))) ||
714         ((compilationImpl->offlineModelPath != nullptr) &&
715          ((compilationImpl->offlineModelBuffer.first != nullptr) ||
716           (compilationImpl->offlineModelBuffer.second != static_cast<size_t>(0))))) {
717         LOGE("OH_NNCompilation_Build failed, find multi model to build compilation.");
718         return OH_NN_INVALID_PARAMETER;
719     }
720 
721     OH_NN_ReturnCode ret = OH_NN_SUCCESS;
722     if (compilationImpl->compiler != nullptr) {
723         LOGE("OH_NNCompilation_Build failed, the compiler in compilation is not nullptr, "
724              "please input a new compilation.");
725         return OH_NN_INVALID_PARAMETER;
726     }
727 
728     Compiler* compiler = nullptr;
729     ret = CreateCompiler(compilationImpl, &compiler);
730     if (ret != OH_NN_SUCCESS) {
731         LOGE("OH_NNCompilation_Build failed, fail to create compiler.");
732         return ret;
733     }
734     compilationImpl->compiler = compiler;
735 
736     ret = SetCompilationOptions(compilationImpl);
737     if (ret != OH_NN_SUCCESS) {
738         LOGE("OH_NNCompilation_Build failed, fail to create compiler.");
739         return ret;
740     }
741 
742     ret = Authentication(&compilationImpl);
743     if (ret != OH_NN_SUCCESS) {
744         LOGE("OH_NNCompilation_Build failed, fail to create compiler.");
745         return ret;
746     }
747 
748     bool isBuild = compilationImpl->compiler->IsBuild();
749     if (isBuild) {
750         LOGE("OH_NNCompilation_Build failed, compilation has been built, don't build again.");
751         return OH_NN_OPERATION_FORBIDDEN;
752     }
753 
754     ret = compilationImpl->compiler->Build();
755     if (ret != OH_NN_SUCCESS) {
756         LOGE("OH_NNCompilation_Build failed, fail to build compilation.");
757         return ret;
758     }
759 
760     ret = GetModelId(&compilationImpl);
761     if (ret != OH_NN_SUCCESS) {
762         LOGE("OH_NNCompilation_Build failed, fail to get modelId.");
763         return ret;
764     }
765 
766     return OH_NN_SUCCESS;
767 }
768 
OH_NNCompilation_Destroy(OH_NNCompilation ** compilation)769 NNRT_API void OH_NNCompilation_Destroy(OH_NNCompilation **compilation)
770 {
771     if (compilation == nullptr) {
772         LOGE("OH_NNCompilation_Destroy failed, compilation is nullptr.");
773         return;
774     }
775 
776     if (*compilation == nullptr) {
777         LOGE("OH_NNCompilation_Destroy failed, compilation is nullptr.");
778         return;
779     }
780 
781     Compilation* compilationImpl = reinterpret_cast<Compilation*>(*compilation);
782     if (compilationImpl->compiler != nullptr) {
783         BackendManager& manager = BackendManager::GetInstance();
784         std::shared_ptr<Backend> backend = manager.GetBackend(compilationImpl->backendID);
785         if (backend == nullptr) {
786             LOGE("OH_NNCompilation_Destroy failed, fail to get backend %{public}zu.", compilationImpl->backendID);
787             return;
788         }
789 
790         OH_NN_ReturnCode ret = backend->DestroyCompiler(compilationImpl->compiler);
791         if (ret != OH_NN_SUCCESS) {
792             LOGE("OH_NNCompilation_Destroy failed, fail to destroy compiler.");
793             return;
794         }
795     }
796 
797     delete compilationImpl;
798     *compilation = nullptr;
799 }
800 
OH_NNTensorDesc_Create()801 NNRT_API NN_TensorDesc *OH_NNTensorDesc_Create()
802 {
803     TensorDesc *tensorDescImpl = new (std::nothrow) TensorDesc();
804     if (tensorDescImpl == nullptr) {
805         LOGE("OH_NNTensorDesc_Create failed, failed to create tensor desc.");
806         return nullptr;
807     }
808 
809     NN_TensorDesc *tensorDesc = reinterpret_cast<NN_TensorDesc *>(tensorDescImpl);
810     return tensorDesc;
811 }
812 
OH_NNTensorDesc_Destroy(NN_TensorDesc ** tensorDesc)813 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_Destroy(NN_TensorDesc **tensorDesc)
814 {
815     if (tensorDesc == nullptr) {
816         LOGE("OH_NNTensorDesc_Destroy failed, tensorDesc is nullptr.");
817         return OH_NN_INVALID_PARAMETER;
818     }
819     if (*tensorDesc == nullptr) {
820         LOGE("OH_NNTensorDesc_Destroy failed, *tensorDesc is nullptr.");
821         return OH_NN_INVALID_PARAMETER;
822     }
823 
824     TensorDesc *tensorDescImpl = reinterpret_cast<TensorDesc *>(*tensorDesc);
825     delete tensorDescImpl;
826     *tensorDesc = nullptr;
827     return OH_NN_SUCCESS;
828 }
829 
OH_NNTensorDesc_SetName(NN_TensorDesc * tensorDesc,const char * name)830 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_SetName(NN_TensorDesc *tensorDesc, const char *name)
831 {
832     if (tensorDesc == nullptr) {
833         LOGE("OH_NNTensorDesc_SetName failed, tensorDesc is nullptr.");
834         return OH_NN_INVALID_PARAMETER;
835     }
836     if (name == nullptr) {
837         LOGE("OH_NNTensorDesc_SetName failed, name is nullptr.");
838         return OH_NN_INVALID_PARAMETER;
839     }
840 
841     TensorDesc *tensorDescImpl = reinterpret_cast<TensorDesc *>(tensorDesc);
842     return tensorDescImpl->SetName(name);
843 }
844 
OH_NNTensorDesc_GetName(const NN_TensorDesc * tensorDesc,const char ** name)845 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetName(const NN_TensorDesc *tensorDesc, const char **name)
846 {
847     if (tensorDesc == nullptr) {
848         LOGE("OH_NNTensorDesc_GetName failed, tensorDesc is nullptr.");
849         return OH_NN_INVALID_PARAMETER;
850     }
851     if (name == nullptr) {
852         LOGE("OH_NNTensorDesc_GetName failed, name is nullptr.");
853         return OH_NN_INVALID_PARAMETER;
854     }
855     if (*name != nullptr) {
856         LOGE("OH_NNTensorDesc_GetName failed, *name is not nullptr.");
857         return OH_NN_INVALID_PARAMETER;
858     }
859 
860     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
861     return tensorDescImpl->GetName(name);
862 }
863 
OH_NNTensorDesc_SetDataType(NN_TensorDesc * tensorDesc,OH_NN_DataType dataType)864 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_SetDataType(NN_TensorDesc *tensorDesc, OH_NN_DataType dataType)
865 {
866     if (tensorDesc == nullptr) {
867         LOGE("OH_NNTensorDesc_SetDataType failed, tensorDesc is nullptr.");
868         return OH_NN_INVALID_PARAMETER;
869     }
870 
871     TensorDesc *tensorDescImpl = reinterpret_cast<TensorDesc *>(tensorDesc);
872     return tensorDescImpl->SetDataType(dataType);
873 }
874 
OH_NNTensorDesc_GetDataType(const NN_TensorDesc * tensorDesc,OH_NN_DataType * dataType)875 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetDataType(const NN_TensorDesc *tensorDesc, OH_NN_DataType *dataType)
876 {
877     if (tensorDesc == nullptr) {
878         LOGE("OH_NNTensorDesc_GetDataType failed, tensorDesc is nullptr.");
879         return OH_NN_INVALID_PARAMETER;
880     }
881     if (dataType == nullptr) {
882         LOGE("OH_NNTensorDesc_GetDataType failed, dataType is nullptr.");
883         return OH_NN_INVALID_PARAMETER;
884     }
885 
886     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
887     return tensorDescImpl->GetDataType(dataType);
888 }
889 
OH_NNTensorDesc_SetShape(NN_TensorDesc * tensorDesc,const int32_t * shape,size_t shapeLength)890 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_SetShape(NN_TensorDesc *tensorDesc, const int32_t *shape, size_t shapeLength)
891 {
892     if (tensorDesc == nullptr) {
893         LOGE("OH_NNTensorDesc_SetShape failed, tensorDesc is nullptr.");
894         return OH_NN_INVALID_PARAMETER;
895     }
896     if (shape == nullptr) {
897         LOGE("OH_NNTensorDesc_SetShape failed, shape is nullptr.");
898         return OH_NN_INVALID_PARAMETER;
899     }
900     if (shapeLength == 0) {
901         LOGE("OH_NNTensorDesc_SetShape failed, shapeLength is 0.");
902         return OH_NN_INVALID_PARAMETER;
903     }
904     TensorDesc *tensorDescImpl = reinterpret_cast<TensorDesc *>(tensorDesc);
905     return tensorDescImpl->SetShape(shape, shapeLength);
906 }
907 
OH_NNTensorDesc_GetShape(const NN_TensorDesc * tensorDesc,int32_t ** shape,size_t * shapeLength)908 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetShape(const NN_TensorDesc *tensorDesc,
909                                                    int32_t **shape,
910                                                    size_t *shapeLength)
911 {
912     if (tensorDesc == nullptr) {
913         LOGE("OH_NNTensorDesc_GetShape failed, tensorDesc is nullptr.");
914         return OH_NN_INVALID_PARAMETER;
915     }
916     if (shape == nullptr) {
917         LOGE("OH_NNTensorDesc_GetShape failed, shape is nullptr.");
918         return OH_NN_INVALID_PARAMETER;
919     }
920     if (*shape != nullptr) {
921         LOGE("OH_NNTensorDesc_GetShape failed, *shape is not nullptr.");
922         return OH_NN_INVALID_PARAMETER;
923     }
924     if (shapeLength == nullptr) {
925         LOGE("OH_NNTensorDesc_GetShape failed, shapeLength is nullptr.");
926         return OH_NN_INVALID_PARAMETER;
927     }
928 
929     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
930     return tensorDescImpl->GetShape(shape, shapeLength);
931 }
932 
OH_NNTensorDesc_SetFormat(NN_TensorDesc * tensorDesc,OH_NN_Format format)933 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_SetFormat(NN_TensorDesc *tensorDesc, OH_NN_Format format)
934 {
935     if (tensorDesc == nullptr) {
936         LOGE("OH_NNTensorDesc_SetFormat failed, tensorDesc is nullptr.");
937         return OH_NN_INVALID_PARAMETER;
938     }
939 
940     TensorDesc *tensorDescImpl = reinterpret_cast<TensorDesc *>(tensorDesc);
941     return tensorDescImpl->SetFormat(format);
942 }
943 
OH_NNTensorDesc_GetFormat(const NN_TensorDesc * tensorDesc,OH_NN_Format * format)944 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetFormat(const NN_TensorDesc *tensorDesc, OH_NN_Format *format)
945 {
946     if (tensorDesc == nullptr) {
947         LOGE("OH_NNTensorDesc_GetFormat failed, tensorDesc is nullptr.");
948         return OH_NN_INVALID_PARAMETER;
949     }
950     if (format == nullptr) {
951         LOGE("OH_NNTensorDesc_GetFormat failed, format is nullptr.");
952         return OH_NN_INVALID_PARAMETER;
953     }
954 
955     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
956     return tensorDescImpl->GetFormat(format);
957 }
958 
OH_NNTensorDesc_GetElementCount(const NN_TensorDesc * tensorDesc,size_t * elementCount)959 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetElementCount(const NN_TensorDesc *tensorDesc, size_t *elementCount)
960 {
961     if (tensorDesc == nullptr) {
962         LOGE("OH_NNTensorDesc_GetElementCount failed, tensorDesc is nullptr.");
963         return OH_NN_INVALID_PARAMETER;
964     }
965     if (elementCount == nullptr) {
966         LOGE("OH_NNTensorDesc_GetElementCount failed, elementCount is nullptr.");
967         return OH_NN_INVALID_PARAMETER;
968     }
969 
970     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
971     return tensorDescImpl->GetElementNum(elementCount);
972 }
973 
OH_NNTensorDesc_GetByteSize(const NN_TensorDesc * tensorDesc,size_t * byteSize)974 NNRT_API OH_NN_ReturnCode OH_NNTensorDesc_GetByteSize(const NN_TensorDesc *tensorDesc, size_t *byteSize)
975 {
976     if (tensorDesc == nullptr) {
977         LOGE("OH_NNTensorDesc_GetByteSize failed, tensorDesc is nullptr.");
978         return OH_NN_INVALID_PARAMETER;
979     }
980     if (byteSize == nullptr) {
981         LOGE("OH_NNTensorDesc_GetByteSize failed, byteSize is nullptr.");
982         return OH_NN_INVALID_PARAMETER;
983     }
984 
985     const TensorDesc *tensorDescImpl = reinterpret_cast<const TensorDesc *>(tensorDesc);
986     return tensorDescImpl->GetByteSize(byteSize);
987 }
988 
OH_NNTensor_Create(size_t deviceID,NN_TensorDesc * tensorDesc)989 NNRT_API NN_Tensor* OH_NNTensor_Create(size_t deviceID, NN_TensorDesc *tensorDesc)
990 {
991     if (tensorDesc == nullptr) {
992         LOGE("OH_NNTensor_Create failed, tensorDesc is nullptr.");
993         return nullptr;
994     }
995 
996     BackendManager& backendManager = BackendManager::GetInstance();
997     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
998     if (backend == nullptr) {
999         LOGE("OH_NNTensor_Create failed, passed invalid backend name.");
1000         return nullptr;
1001     }
1002 
1003     TensorDesc* descImpl = reinterpret_cast<TensorDesc*>(tensorDesc);
1004     Tensor* tensorImpl = backend->CreateTensor(descImpl);
1005     if (tensorImpl == nullptr) {
1006         LOGE("OH_NNTensor_Create failed, failed to create tensor.");
1007         return nullptr;
1008     }
1009 
1010     OH_NN_ReturnCode ret = tensorImpl->CreateData();
1011     if (ret != OH_NN_SUCCESS) {
1012         LOGE("OH_NNTensor_Create failed, failed to create tensor.");
1013         backend->DestroyTensor(tensorImpl);
1014         return nullptr;
1015     }
1016 
1017     NN_Tensor* tensor = reinterpret_cast<NN_Tensor*>(tensorImpl);
1018     return tensor;
1019 }
1020 
OH_NNTensor_CreateWithSize(size_t deviceID,NN_TensorDesc * tensorDesc,size_t size)1021 NNRT_API NN_Tensor* OH_NNTensor_CreateWithSize(size_t deviceID, NN_TensorDesc *tensorDesc, size_t size)
1022 {
1023     if (tensorDesc == nullptr) {
1024         LOGE("OH_NNTensor_CreateWithSize failed, tensorDesc is nullptr.");
1025         return nullptr;
1026     }
1027 
1028     BackendManager& backendManager = BackendManager::GetInstance();
1029     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
1030     if (backend == nullptr) {
1031         LOGE("OH_NNTensor_CreateWithSize failed, passed invalid backend name.");
1032         return nullptr;
1033     }
1034 
1035     TensorDesc* descImpl = reinterpret_cast<TensorDesc*>(tensorDesc);
1036     Tensor* tensorImpl = backend->CreateTensor(descImpl);
1037     if (tensorImpl == nullptr) {
1038         LOGE("OH_NNTensor_CreateWithSize failed, failed to create tensor.");
1039         return nullptr;
1040     }
1041 
1042     OH_NN_ReturnCode ret = tensorImpl->CreateData(size);
1043     if (ret != OH_NN_SUCCESS) {
1044         LOGE("OH_NNTensor_CreateWithSize failed, failed to create tensor.");
1045         backend->DestroyTensor(tensorImpl);
1046         return nullptr;
1047     }
1048 
1049     NN_Tensor* tensor = reinterpret_cast<NN_Tensor*>(tensorImpl);
1050     return tensor;
1051 }
1052 
OH_NNTensor_CreateWithFd(size_t deviceID,NN_TensorDesc * tensorDesc,int fd,size_t size,size_t offset)1053 NNRT_API NN_Tensor* OH_NNTensor_CreateWithFd(size_t deviceID,
1054                                              NN_TensorDesc *tensorDesc,
1055                                              int fd, size_t size,
1056                                              size_t offset)
1057 {
1058     if (tensorDesc == nullptr) {
1059         LOGE("OH_NNTensor_CreateWithFd failed, tensorDesc is nullptr.");
1060         return nullptr;
1061     }
1062     if (fd < 0) {
1063         LOGE("OH_NNTensor_CreateWithFd failed, fd is less than zero.");
1064         return nullptr;
1065     }
1066     if (size == 0) {
1067         LOGE("OH_NNTensor_CreateWithFd failed, size is zero.");
1068         return nullptr;
1069     }
1070     if (size < offset) {
1071         LOGE("OH_NNTensor_CreateWithFd failed, size is smaller than offset.");
1072         return nullptr;
1073     }
1074     TensorDesc* descImpl = reinterpret_cast<TensorDesc*>(tensorDesc);
1075     size_t byteSize = 0;
1076     auto ret = descImpl->GetByteSize(&byteSize);
1077     if (ret != OH_NN_SUCCESS) {
1078         LOGE("NNTensor2_0::CreateData failed, failed to get byte size from tensorDesc.");
1079         return nullptr;
1080     }
1081     if ((size - offset) < byteSize) {
1082         LOGE("OH_NNTensor_CreateWithFd failed, size of fd is insufficient.");
1083         return nullptr;
1084     }
1085 
1086     BackendManager& backendManager = BackendManager::GetInstance();
1087     std::shared_ptr<Backend> backend = backendManager.GetBackend(deviceID);
1088     if (backend == nullptr) {
1089         LOGE("OH_NNTensor_CreateWithFd failed, passed invalid backend name.");
1090         return nullptr;
1091     }
1092 
1093     Tensor* tensorImpl = backend->CreateTensor(descImpl);
1094     if (tensorImpl == nullptr) {
1095         LOGE("OH_NNTensor_CreateWithFd failed, failed to create tensor.");
1096         return nullptr;
1097     }
1098 
1099     ret = tensorImpl->CreateData(fd, size, offset);
1100     if (ret != OH_NN_SUCCESS) {
1101         LOGE("OH_NNTensor_CreateWithFd failed, failed to create tensor.");
1102         backend->DestroyTensor(tensorImpl);
1103         return nullptr;
1104     }
1105 
1106     NN_Tensor* tensor = reinterpret_cast<NN_Tensor*>(tensorImpl);
1107     return tensor;
1108 }
1109 
OH_NNTensor_Destroy(NN_Tensor ** tensor)1110 NNRT_API OH_NN_ReturnCode OH_NNTensor_Destroy(NN_Tensor **tensor)
1111 {
1112     if (tensor == nullptr) {
1113         LOGE("OH_NNTensor_Destroy failed, tensor is nullptr.");
1114         return OH_NN_INVALID_PARAMETER;
1115     }
1116     if (*tensor == nullptr) {
1117         LOGE("OH_NNTensor_Destroy failed, *tensor is nullptr.");
1118         return OH_NN_INVALID_PARAMETER;
1119     }
1120 
1121     Tensor* tensorImpl = reinterpret_cast<Tensor*>(*tensor);
1122     size_t backendID = tensorImpl->GetBackendID();
1123     BackendManager& backendManager = BackendManager::GetInstance();
1124     std::shared_ptr<Backend> backend = backendManager.GetBackend(backendID);
1125     if (backend == nullptr) {
1126         LOGE("OH_NNTensor_Destroy failed, passed invalid backend name %{public}zu.", backendID);
1127         return OH_NN_NULL_PTR;
1128     }
1129 
1130     auto ret = backend->DestroyTensor(tensorImpl);
1131     if (ret != OH_NN_SUCCESS) {
1132         LOGE("OH_NNTensor_Destroy failed, failed to destroy tensor.");
1133         return ret;
1134     }
1135     *tensor = nullptr;
1136     return OH_NN_SUCCESS;
1137 }
1138 
OH_NNTensor_GetTensorDesc(const NN_Tensor * tensor)1139 NNRT_API NN_TensorDesc* OH_NNTensor_GetTensorDesc(const NN_Tensor *tensor)
1140 {
1141     if (tensor == nullptr) {
1142         LOGE("OH_NNTensor_GetTensorDesc failed, tensor is nullptr.");
1143         return nullptr;
1144     }
1145 
1146     const Tensor *tensorImpl = reinterpret_cast<const Tensor *>(tensor);
1147     auto tensorDescImpl = tensorImpl->GetTensorDesc();
1148     if (tensorDescImpl == nullptr) {
1149         LOGE("OH_NNTensor_GetTensorDesc failed, tensor desc is nullptr.");
1150         return nullptr;
1151     }
1152 
1153     NN_TensorDesc *tensorDesc = reinterpret_cast<NN_TensorDesc *>(tensorDescImpl);
1154     return tensorDesc;
1155 }
1156 
OH_NNTensor_GetDataBuffer(const NN_Tensor * tensor)1157 NNRT_API void* OH_NNTensor_GetDataBuffer(const NN_Tensor *tensor)
1158 {
1159     if (tensor == nullptr) {
1160         LOGE("OH_NNTensor_GetDataBuffer failed, tensor is nullptr.");
1161         return nullptr;
1162     }
1163 
1164     const Tensor *tensorImpl = reinterpret_cast<const Tensor *>(tensor);
1165     auto data = tensorImpl->GetData();
1166     if (data == nullptr) {
1167         LOGE("OH_NNTensor_GetDataBuffer failed, data is nullptr.");
1168         return nullptr;
1169     }
1170 
1171     return data;
1172 }
1173 
OH_NNTensor_GetSize(const NN_Tensor * tensor,size_t * size)1174 NNRT_API OH_NN_ReturnCode OH_NNTensor_GetSize(const NN_Tensor *tensor, size_t *size)
1175 {
1176     if (tensor == nullptr) {
1177         LOGE("OH_NNTensor_GetSize failed, tensor is nullptr.");
1178         return OH_NN_INVALID_PARAMETER;
1179     }
1180     if (size == nullptr) {
1181         LOGE("OH_NNTensor_GetSize failed, size is nullptr.");
1182         return OH_NN_INVALID_PARAMETER;
1183     }
1184 
1185     const Tensor *tensorImpl = reinterpret_cast<const Tensor *>(tensor);
1186     *size = tensorImpl->GetSize();
1187     return OH_NN_SUCCESS;
1188 }
1189 
OH_NNTensor_GetFd(const NN_Tensor * tensor,int * fd)1190 NNRT_API OH_NN_ReturnCode OH_NNTensor_GetFd(const NN_Tensor *tensor, int *fd)
1191 {
1192     if (tensor == nullptr) {
1193         LOGE("OH_NNTensor_GetFd failed, tensor is nullptr.");
1194         return OH_NN_INVALID_PARAMETER;
1195     }
1196     if (fd == nullptr) {
1197         LOGE("OH_NNTensor_GetFd failed, fd is nullptr.");
1198         return OH_NN_INVALID_PARAMETER;
1199     }
1200 
1201     const Tensor *tensorImpl = reinterpret_cast<const Tensor *>(tensor);
1202     *fd = tensorImpl->GetFd();
1203     return OH_NN_SUCCESS;
1204 }
1205 
OH_NNTensor_GetOffset(const NN_Tensor * tensor,size_t * offset)1206 NNRT_API OH_NN_ReturnCode OH_NNTensor_GetOffset(const NN_Tensor *tensor, size_t *offset)
1207 {
1208     if (tensor == nullptr) {
1209         LOGE("OH_NNTensor_GetOffset failed, tensor is nullptr.");
1210         return OH_NN_INVALID_PARAMETER;
1211     }
1212     if (offset == nullptr) {
1213         LOGE("OH_NNTensor_GetOffset failed, offset is nullptr.");
1214         return OH_NN_INVALID_PARAMETER;
1215     }
1216 
1217     const Tensor *tensorImpl = reinterpret_cast<const Tensor *>(tensor);
1218     *offset = tensorImpl->GetOffset();
1219     return OH_NN_SUCCESS;
1220 }
1221 
Scheduling(Compilation ** compilation)1222 OH_NN_ReturnCode Scheduling(Compilation** compilation)
1223 {
1224     if (compilation == nullptr) {
1225         LOGE("Scheduling failed, compilation is nullptr.");
1226         return OH_NN_INVALID_PARAMETER;
1227     }
1228 
1229     Compilation* compilationImpl = *compilation;
1230     if (compilationImpl == nullptr) {
1231         LOGE("Scheduling failed, compilation implementation is nullptr.");
1232         return OH_NN_INVALID_PARAMETER;
1233     }
1234 
1235     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
1236     if (!nnrtService.IsServiceAvaliable()) {
1237         LOGW("Scheduling failed, fail to get nnrt service, skip schedule.");
1238         return OH_NN_SUCCESS;
1239     }
1240 
1241     if (nnrtService.IsSupportScheduling == nullptr) {
1242         LOGE("Scheduling failed, nnrtService IsSupportScheduling func is nullptr.");
1243         return OH_NN_INVALID_PARAMETER;
1244     }
1245 
1246     std::string cachePath = "";
1247     if (compilationImpl->cachePath != nullptr) {
1248         cachePath = compilationImpl->cachePath;
1249     }
1250 
1251     bool supportStat = false;
1252     int ret = nnrtService.IsSupportScheduling(&supportStat);
1253     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
1254         LOGE("Scheduling failed, some error happened when judge if support scheduling.");
1255         return static_cast<OH_NN_ReturnCode>(ret);
1256     }
1257     if (!supportStat) {
1258         LOGW("device not support scheduling, jumper over scheduling.");
1259         return OH_NN_SUCCESS;
1260     }
1261 
1262     if (nnrtService.Scheduling == nullptr) {
1263         LOGE("Scheduling failed, nnrtService IsSupportScheduling func is nullptr.");
1264         return OH_NN_INVALID_PARAMETER;
1265     }
1266 
1267     bool needModelLatency = false;
1268     ret = nnrtService.Scheduling(compilationImpl->hiaiModelId, &needModelLatency, cachePath.c_str());
1269     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
1270         LOGE("Scheduling failed, some error happened when scheduling.");
1271         return static_cast<OH_NN_ReturnCode>(ret);
1272     }
1273 
1274     compilationImpl->isNeedModelLatency = needModelLatency;
1275 
1276     LOGI("Scheduling success.");
1277     return OH_NN_SUCCESS;
1278 }
1279 
SetModelId(const Compilation * compilation)1280 OH_NN_ReturnCode SetModelId(const Compilation* compilation)
1281 {
1282     if (compilation == nullptr) {
1283         LOGE("SetModelId failed, compilation is nullptr.");
1284         return OH_NN_INVALID_PARAMETER;
1285     }
1286 
1287     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
1288     if (!nnrtService.IsServiceAvaliable()) {
1289         LOGW("SetModelId failed, fail to get nnrt service, skip set modelId.");
1290         return OH_NN_SUCCESS;
1291     }
1292 
1293     if (nnrtService.SetModelID == nullptr) {
1294         LOGE("SetModelId failed, nnrtService SetModelID func is nullptr.");
1295         return OH_NN_INVALID_PARAMETER;
1296     }
1297 
1298     int ret = nnrtService.SetModelID(
1299         compilation->callingPid, compilation->hiaiModelId, compilation->nnrtModelID);
1300     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
1301         LOGE("SetModelId failed, fail to set modelId.");
1302         return static_cast<OH_NN_ReturnCode>(ret);
1303     }
1304 
1305     return OH_NN_SUCCESS;
1306 }
1307 
ExecutorPrepare(Executor ** executor,Compilation ** compilation)1308 OH_NN_ReturnCode ExecutorPrepare(Executor** executor, Compilation** compilation)
1309 {
1310     if (executor == nullptr) {
1311         LOGE("ExecutorPrepare failed, executor is nullptr.");
1312         return OH_NN_INVALID_PARAMETER;
1313     }
1314 
1315     if (compilation == nullptr) {
1316         LOGE("ExecutorPrepare failed, compilation is nullptr.");
1317         return OH_NN_INVALID_PARAMETER;
1318     }
1319 
1320     Executor* executorImpl = *executor;
1321     if (executorImpl == nullptr) {
1322         LOGE("ExecutorPrepare failed, executor implementation is nullptr.");
1323         return OH_NN_INVALID_PARAMETER;
1324     }
1325 
1326     Compilation* compilationImpl = *compilation;
1327     if (compilationImpl == nullptr) {
1328         LOGE("ExecutorPrepare failed, compilation implementation is nullptr.");
1329         return OH_NN_INVALID_PARAMETER;
1330     }
1331 
1332     OH_NN_ReturnCode ret = SetModelId(compilationImpl);
1333     if (ret != OH_NN_SUCCESS) {
1334         LOGE("ExecutorPrepare failed, fail to set modelId.");
1335         return ret;
1336     }
1337 
1338     LOGD("ExecutorPrepare parameter, callingPid: %{public}d, hiaiModelId: %{public}u, nnrtModelId: %{public}zu.",
1339          compilationImpl->callingPid, compilationImpl->hiaiModelId, compilationImpl->nnrtModelID);
1340 
1341     ret = Scheduling(&compilationImpl);
1342     if (ret != OH_NN_SUCCESS) {
1343         LOGE("ExecutorPrepare failed, failed to create executor.");
1344         return ret;
1345     }
1346 
1347     std::unordered_map<std::string, std::vector<char>> configMap;
1348     std::string callingPidStr = std::to_string(compilationImpl->callingPid);
1349     std::vector<char> vecCallingPid(callingPidStr.begin(), callingPidStr.end());
1350     vecCallingPid.emplace_back('\0');
1351     configMap["callingPid"] = vecCallingPid;
1352 
1353     std::string hiaiModelIdStr = std::to_string(compilationImpl->hiaiModelId);
1354     std::vector<char> vechiaiModelId(hiaiModelIdStr.begin(), hiaiModelIdStr.end());
1355     vechiaiModelId.emplace_back('\0');
1356     configMap["hiaiModelId"] = vechiaiModelId;
1357 
1358     std::vector<char> vecNeedLatency = { static_cast<char>(compilationImpl->isNeedModelLatency) };
1359     configMap["isNeedModelLatency"] = vecNeedLatency;
1360 
1361     executorImpl->SetExtensionConfig(configMap);
1362     if (ret != OH_NN_SUCCESS) {
1363         LOGE("ExecutorPrepare failed, failed to set config to executor.");
1364         return ret;
1365     }
1366 
1367     return OH_NN_SUCCESS;
1368 }
1369 
OH_NNExecutor_Construct(OH_NNCompilation * compilation)1370 NNRT_API OH_NNExecutor *OH_NNExecutor_Construct(OH_NNCompilation *compilation)
1371 {
1372     if (compilation == nullptr) {
1373         LOGE("OH_NNExecutor_Construct failed, compilation is nullptr.");
1374         return nullptr;
1375     }
1376 
1377     Compilation *compilationImpl = reinterpret_cast<Compilation *>(compilation);
1378     BackendManager& backendManager = BackendManager::GetInstance();
1379     std::shared_ptr<Backend> backend = backendManager.GetBackend(compilationImpl->backendID);
1380     if (backend == nullptr) {
1381         LOGE("OH_NNExecutor_Construct failed, failed to get backend of %{public}zu.", compilationImpl->backendID);
1382         return nullptr;
1383     }
1384 
1385     Executor* executorImpl = backend->CreateExecutor(compilationImpl);
1386     if (executorImpl == nullptr) {
1387         LOGE("OH_NNExecutor_Construct failed, failed to create executor.");
1388         return nullptr;
1389     }
1390 
1391     OH_NN_ReturnCode ret = executorImpl->GetModelID(compilationImpl->hiaiModelId);
1392     if (ret != OH_NN_SUCCESS) {
1393         LOGE("OH_NNExecutor_Construct failed, failed to get hiai modelId.");
1394         OH_NNExecutor_Destroy(reinterpret_cast<OH_NNExecutor **>(&executorImpl));
1395         return nullptr;
1396     }
1397 
1398     ret = ExecutorPrepare(&executorImpl, &compilationImpl);
1399     if (ret != OH_NN_SUCCESS) {
1400         LOGE("OH_NNExecutor_Construct failed, failed to prepare executor.");
1401         OH_NNExecutor_Destroy(reinterpret_cast<OH_NNExecutor **>(&executorImpl));
1402         return nullptr;
1403     }
1404 
1405     OH_NNExecutor *executor = reinterpret_cast<OH_NNExecutor *>(executorImpl);
1406     return executor;
1407 }
1408 
Unload(const ExecutorConfig * config)1409 OH_NN_ReturnCode Unload(const ExecutorConfig* config)
1410 {
1411     if (config == nullptr) {
1412         LOGE("Unload failed, config is nullptr.");
1413         return OH_NN_INVALID_PARAMETER;
1414     }
1415 
1416     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
1417     if (!nnrtService.IsServiceAvaliable()) {
1418         LOGW("Unload failed, fail to get nnrt service, skip unload.");
1419         return OH_NN_SUCCESS;
1420     }
1421 
1422     if (nnrtService.Unload == nullptr) {
1423         LOGE("Unload failed, nnrtService Unload func is nullptr.");
1424         return OH_NN_INVALID_PARAMETER;
1425     }
1426 
1427     int ret = nnrtService.Unload(config->hiaiModelId);
1428     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
1429         LOGE("Unload failed, some error happen when unload hiaiModelId.");
1430         return static_cast<OH_NN_ReturnCode>(ret);
1431     }
1432 
1433     LOGI("Unload success.");
1434     return OH_NN_SUCCESS;
1435 }
1436 
OH_NNExecutor_Destroy(OH_NNExecutor ** executor)1437 NNRT_API void OH_NNExecutor_Destroy(OH_NNExecutor **executor)
1438 {
1439     if (executor == nullptr) {
1440         LOGE("OH_NNExecutor_Destroy failed, executor is nullptr.");
1441         return;
1442     }
1443     if (*executor == nullptr) {
1444         LOGE("OH_NNExecutor_Destroy failed, *executor is nullptr.");
1445         return;
1446     }
1447 
1448     Executor *executorImpl = reinterpret_cast<Executor *>(*executor);
1449     size_t backendID = executorImpl->GetBackendID();
1450     BackendManager& backendManager = BackendManager::GetInstance();
1451     std::shared_ptr<Backend> backend = backendManager.GetBackend(backendID);
1452     if (backend == nullptr) {
1453         LOGE("OH_NNExecutor_Destroy failed, failed to get backend of %{public}zu.", backendID);
1454         return;
1455     }
1456 
1457     OH_NN_ReturnCode ret = Unload(executorImpl->GetExecutorConfig());
1458     if (ret != OH_NN_SUCCESS) {
1459         LOGE("Unload failed, some error happened when unload nnrt service.");
1460     }
1461 
1462     auto returnCode = backend->DestroyExecutor(executorImpl);
1463     if (returnCode != OH_NN_SUCCESS) {
1464         LOGE("OH_NNExecutor_Destroy failed, failed to destroy executor.");
1465         return;
1466     }
1467 
1468     *executor = nullptr;
1469 }
1470 
OH_NNExecutor_GetOutputShape(OH_NNExecutor * executor,uint32_t outputIndex,int32_t ** shape,uint32_t * shapeLength)1471 NNRT_API OH_NN_ReturnCode OH_NNExecutor_GetOutputShape(OH_NNExecutor *executor,
1472                                                        uint32_t outputIndex,
1473                                                        int32_t **shape,
1474                                                        uint32_t *shapeLength)
1475 {
1476     if (executor == nullptr) {
1477         LOGE("OH_NNExecutor_GetOutputShape failed, executor is nullptr.");
1478         return OH_NN_INVALID_PARAMETER;
1479     }
1480     if (shape == nullptr) {
1481         LOGE("OH_NNExecutor_GetOutputShape failed, shape is nullptr.");
1482         return OH_NN_INVALID_PARAMETER;
1483     }
1484     if (*shape != nullptr) {
1485         LOGE("OH_NNExecutor_GetOutputShape failed, *shape is not nullptr.");
1486         return OH_NN_INVALID_PARAMETER;
1487     }
1488     if (shapeLength == nullptr) {
1489         LOGE("OH_NNExecutor_GetOutputShape failed, shapeLength is nullptr.");
1490         return OH_NN_INVALID_PARAMETER;
1491     }
1492 
1493     Executor *executorImpl = reinterpret_cast<Executor *>(executor);
1494     return executorImpl->GetOutputShape(outputIndex, shape, shapeLength);
1495 }
1496 
OH_NNExecutor_GetInputCount(const OH_NNExecutor * executor,size_t * inputCount)1497 NNRT_API OH_NN_ReturnCode OH_NNExecutor_GetInputCount(const OH_NNExecutor *executor, size_t *inputCount)
1498 {
1499     if (executor == nullptr) {
1500         LOGE("OH_NNExecutor_GetInputCount failed, executor is nullptr.");
1501         return OH_NN_INVALID_PARAMETER;
1502     }
1503     if (inputCount == nullptr) {
1504         LOGE("OH_NNExecutor_GetInputCount failed, inputCount is nullptr.");
1505         return OH_NN_INVALID_PARAMETER;
1506     }
1507 
1508     const Executor *executorImpl = reinterpret_cast<const Executor *>(executor);
1509     *inputCount = executorImpl->GetInputNum();
1510     return OH_NN_SUCCESS;
1511 }
1512 
OH_NNExecutor_GetOutputCount(const OH_NNExecutor * executor,size_t * outputCount)1513 NNRT_API OH_NN_ReturnCode OH_NNExecutor_GetOutputCount(const OH_NNExecutor *executor, size_t *outputCount)
1514 {
1515     if (executor == nullptr) {
1516         LOGE("OH_NNExecutor_GetOutputCount failed, executor is nullptr.");
1517         return OH_NN_INVALID_PARAMETER;
1518     }
1519     if (outputCount == nullptr) {
1520         LOGE("OH_NNExecutor_GetOutputCount failed, outputCount is nullptr.");
1521         return OH_NN_INVALID_PARAMETER;
1522     }
1523 
1524     const Executor *executorImpl = reinterpret_cast<const Executor *>(executor);
1525     *outputCount = executorImpl->GetOutputNum();
1526     return OH_NN_SUCCESS;
1527 }
1528 
OH_NNExecutor_CreateInputTensorDesc(const OH_NNExecutor * executor,size_t index)1529 NNRT_API NN_TensorDesc* OH_NNExecutor_CreateInputTensorDesc(const OH_NNExecutor *executor, size_t index)
1530 {
1531     if (executor == nullptr) {
1532         LOGE("OH_NNExecutor_CreateInputTensorDesc failed, executor is nullptr.");
1533         return nullptr;
1534     }
1535 
1536     const Executor *executorImpl = reinterpret_cast<const Executor *>(executor);
1537     return executorImpl->CreateInputTensorDesc(index);
1538 }
1539 
OH_NNExecutor_CreateOutputTensorDesc(const OH_NNExecutor * executor,size_t index)1540 NNRT_API NN_TensorDesc* OH_NNExecutor_CreateOutputTensorDesc(const OH_NNExecutor *executor, size_t index)
1541 {
1542     if (executor == nullptr) {
1543         LOGE("OH_NNExecutor_CreateOutputTensorDesc failed, executor is nullptr.");
1544         return nullptr;
1545     }
1546 
1547     const Executor *executorImpl = reinterpret_cast<const Executor *>(executor);
1548     return executorImpl->CreateOutputTensorDesc(index);
1549 }
1550 
OH_NNExecutor_GetInputDimRange(const OH_NNExecutor * executor,size_t index,size_t ** minInputDims,size_t ** maxInputDims,size_t * shapeLength)1551 NNRT_API OH_NN_ReturnCode OH_NNExecutor_GetInputDimRange(const OH_NNExecutor *executor,
1552                                                          size_t index,
1553                                                          size_t **minInputDims,
1554                                                          size_t **maxInputDims,
1555                                                          size_t *shapeLength)
1556 {
1557     if (executor == nullptr) {
1558         LOGE("OH_NNExecutor_GetInputDimRange failed, executor is nullptr.");
1559         return OH_NN_INVALID_PARAMETER;
1560     }
1561     if (minInputDims == nullptr) {
1562         LOGE("OH_NNExecutor_GetInputDimRange failed, minInputDims is nullptr.");
1563         return OH_NN_INVALID_PARAMETER;
1564     }
1565     if (maxInputDims == nullptr) {
1566         LOGE("OH_NNExecutor_GetInputDimRange failed, maxInputDims is nullptr.");
1567         return OH_NN_INVALID_PARAMETER;
1568     }
1569     if (shapeLength == nullptr) {
1570         LOGE("OH_NNExecutor_GetInputDimRange failed, shapeLength is nullptr.");
1571         return OH_NN_INVALID_PARAMETER;
1572     }
1573 
1574     const Executor *executorImpl = reinterpret_cast<const Executor *>(executor);
1575     return executorImpl->GetInputDimRange(index, minInputDims, maxInputDims, shapeLength);
1576 }
1577 
OH_NNExecutor_SetOnRunDone(OH_NNExecutor * executor,NN_OnRunDone onRunDone)1578 NNRT_API OH_NN_ReturnCode OH_NNExecutor_SetOnRunDone(OH_NNExecutor *executor, NN_OnRunDone onRunDone)
1579 {
1580     if (executor == nullptr) {
1581         LOGE("OH_NNExecutor_SetOnRunDone failed, executor is nullptr.");
1582         return OH_NN_INVALID_PARAMETER;
1583     }
1584     if (onRunDone == nullptr) {
1585         LOGE("OH_NNExecutor_SetOnRunDone failed, onRunDone is nullptr.");
1586         return OH_NN_INVALID_PARAMETER;
1587     }
1588 
1589     Executor *executorImpl = reinterpret_cast<Executor *>(executor);
1590     return executorImpl->SetOnRunDone(onRunDone);
1591 }
1592 
OH_NNExecutor_SetOnServiceDied(OH_NNExecutor * executor,NN_OnServiceDied onServiceDied)1593 NNRT_API OH_NN_ReturnCode OH_NNExecutor_SetOnServiceDied(OH_NNExecutor *executor, NN_OnServiceDied onServiceDied)
1594 {
1595     if (executor == nullptr) {
1596         LOGE("OH_NNExecutor_SetOnServiceDied failed, executor is nullptr.");
1597         return OH_NN_INVALID_PARAMETER;
1598     }
1599     if (onServiceDied == nullptr) {
1600         LOGE("OH_NNExecutor_SetOnServiceDied failed, onServiceDied is nullptr.");
1601         return OH_NN_INVALID_PARAMETER;
1602     }
1603 
1604     Executor *executorImpl = reinterpret_cast<Executor *>(executor);
1605     return executorImpl->SetOnServiceDied(onServiceDied);
1606 }
1607 
UpdateModelLatency(const ExecutorConfig * config,int32_t modelLatency)1608 OH_NN_ReturnCode UpdateModelLatency(const ExecutorConfig* config, int32_t modelLatency)
1609 {
1610     if (config == nullptr) {
1611         LOGE("UpdateModelLatency failed, config is nullptr.");
1612         return OH_NN_INVALID_PARAMETER;
1613     }
1614 
1615     NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
1616     if (!nnrtService.IsServiceAvaliable()) {
1617         LOGW("UpdateModelLatency failed, fail to get nnrt service, skip update model latency.");
1618         return OH_NN_SUCCESS;
1619     }
1620 
1621     if (nnrtService.UpdateModelLatency == nullptr) {
1622         LOGE("UpdateModelLatency failed, nnrtService UpdateModelLatency func is nullptr.");
1623         return OH_NN_INVALID_PARAMETER;
1624     }
1625 
1626     LOGD("UpdateModelLatency, hiaiModelId: %{public}u, modelLatency: %{public}d.", config->hiaiModelId, modelLatency);
1627 
1628     int ret = nnrtService.UpdateModelLatency(config->hiaiModelId, modelLatency);
1629     if (ret != static_cast<int>(OH_NN_SUCCESS)) {
1630         LOGE("UpdateModelLatency failed, nnrtService is not exist, jump over UpdateModelLatency.");
1631         return static_cast<OH_NN_ReturnCode>(ret);
1632     }
1633 
1634     LOGI("UpdateModelLatency success.");
1635     return OH_NN_SUCCESS;
1636 }
1637 
RunSync(Executor * executor,NN_Tensor * inputTensor[],size_t inputCount,NN_Tensor * outputTensor[],size_t outputCount)1638 OH_NN_ReturnCode RunSync(Executor *executor,
1639                          NN_Tensor *inputTensor[],
1640                          size_t inputCount,
1641                          NN_Tensor *outputTensor[],
1642                          size_t outputCount)
1643 {
1644     ExecutorConfig* configPtr = executor->GetExecutorConfig();
1645     if (configPtr == nullptr) {
1646         LOGE("RunSync failed, executor config is nullptr.");
1647         return OH_NN_INVALID_PARAMETER;
1648     }
1649 
1650     long timeStart = 0;
1651     if (configPtr->isNeedModelLatency) {
1652         timeStart = std::chrono::duration_cast<std::chrono::milliseconds>(
1653             std::chrono::system_clock::now().time_since_epoch()).count();
1654     }
1655 
1656     OH_NN_ReturnCode ret = executor->RunSync(inputTensor, inputCount, outputTensor, outputCount);
1657     if (ret != OH_NN_SUCCESS) {
1658         LOGE("OH_NNExecutor_RunSync failed, fail to run executor.");
1659         return ret;
1660     }
1661 
1662     if (configPtr->isNeedModelLatency) {
1663         long timeEnd = std::chrono::duration_cast<std::chrono::milliseconds>(
1664             std::chrono::system_clock::now().time_since_epoch()).count();
1665         int32_t modelLatency = static_cast<int32_t>((timeEnd - timeStart));
1666         std::thread t(UpdateModelLatency, configPtr, modelLatency);
1667         t.detach();
1668         LOGE("update async start.");
1669 
1670         configPtr->isNeedModelLatency = false;
1671         std::unordered_map<std::string, std::vector<char>> configMap;
1672         std::vector<char> vecNeedLatency = { static_cast<char>(configPtr->isNeedModelLatency) };
1673         configMap["isNeedModelLatency"] = vecNeedLatency;
1674 
1675         ret = executor->SetExtensionConfig(configMap);
1676         if (ret != OH_NN_SUCCESS) {
1677             LOGE("OH_NNExecutor_RunSync failed, fail update executor config.");
1678             return ret;
1679         }
1680     }
1681 
1682     return OH_NN_SUCCESS;
1683 }
1684 
OH_NNExecutor_RunSync(OH_NNExecutor * executor,NN_Tensor * inputTensor[],size_t inputCount,NN_Tensor * outputTensor[],size_t outputCount)1685 NNRT_API OH_NN_ReturnCode OH_NNExecutor_RunSync(OH_NNExecutor *executor,
1686                                                 NN_Tensor *inputTensor[],
1687                                                 size_t inputCount,
1688                                                 NN_Tensor *outputTensor[],
1689                                                 size_t outputCount)
1690 {
1691     if (executor == nullptr) {
1692         LOGE("OH_NNExecutor_RunSync failed, executor is nullptr.");
1693         return OH_NN_INVALID_PARAMETER;
1694     }
1695     if (inputTensor == nullptr) {
1696         LOGE("OH_NNExecutor_RunSync failed, inputTensor is nullptr.");
1697         return OH_NN_INVALID_PARAMETER;
1698     }
1699     if (inputCount == 0) {
1700         LOGE("OH_NNExecutor_RunSync failed, inputCount is 0.");
1701         return OH_NN_INVALID_PARAMETER;
1702     }
1703     if (outputTensor == nullptr) {
1704         LOGE("OH_NNExecutor_RunSync failed, outputTensor is nullptr.");
1705         return OH_NN_INVALID_PARAMETER;
1706     }
1707     if (outputCount == 0) {
1708         LOGE("OH_NNExecutor_RunSync failed, outputCount is 0.");
1709         return OH_NN_INVALID_PARAMETER;
1710     }
1711 
1712     Executor *executorImpl = reinterpret_cast<Executor *>(executor);
1713     return RunSync(executorImpl, inputTensor, inputCount, outputTensor, outputCount);
1714 }
1715 
OH_NNExecutor_RunAsync(OH_NNExecutor * executor,NN_Tensor * inputTensor[],size_t inputCount,NN_Tensor * outputTensor[],size_t outputCount,int32_t timeout,void * userData)1716 NNRT_API OH_NN_ReturnCode OH_NNExecutor_RunAsync(OH_NNExecutor *executor,
1717                                                  NN_Tensor* inputTensor[],
1718                                                  size_t inputCount,
1719                                                  NN_Tensor* outputTensor[],
1720                                                  size_t outputCount,
1721                                                  int32_t timeout,
1722                                                  void* userData)
1723 {
1724     if (executor == nullptr) {
1725         LOGE("OH_NNExecutor_RunAsync failed, executor is nullptr.");
1726         return OH_NN_INVALID_PARAMETER;
1727     }
1728     if (inputTensor == nullptr) {
1729         LOGE("OH_NNExecutor_RunAsync failed, inputTensor is nullptr.");
1730         return OH_NN_INVALID_PARAMETER;
1731     }
1732     if (inputCount == 0) {
1733         LOGE("OH_NNExecutor_RunAsync failed, inputCount is 0.");
1734         return OH_NN_INVALID_PARAMETER;
1735     }
1736     if (outputTensor == nullptr) {
1737         LOGE("OH_NNExecutor_RunAsync failed, outputTensor is nullptr.");
1738         return OH_NN_INVALID_PARAMETER;
1739     }
1740     if (outputCount == 0) {
1741         LOGE("OH_NNExecutor_RunAsync failed, outputCount is 0.");
1742         return OH_NN_INVALID_PARAMETER;
1743     }
1744     if (userData == nullptr) {
1745         LOGE("OH_NNExecutor_RunAsync failed, userData is nullptr.");
1746         return OH_NN_INVALID_PARAMETER;
1747     }
1748 
1749     Executor *executorImpl = reinterpret_cast<Executor *>(executor);
1750     return executorImpl->RunAsync(inputTensor, inputCount, outputTensor, outputCount, timeout, userData);
1751 }
1752