1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "backend_manager.h"

#include <algorithm>
#include <utility>
#include "cpp_type.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
~BackendManager()23 BackendManager::~BackendManager()
24 {
25 m_backends.clear();
26 m_backendNames.clear();
27 m_backendIDs.clear();
28 m_backendIDGroup.clear();
29 }
30
GetInstance()31 BackendManager& BackendManager::GetInstance()
32 {
33 // if libneural_network_runtime.so loaded
34 if (dlopen("libneural_network_runtime.so", RTLD_NOLOAD) != nullptr) {
35 // if libneural_network_runtime_ext.so not loaded, try to dlopen it
36 if (dlopen("libneural_network_runtime_ext.so", RTLD_NOLOAD) == nullptr) {
37 LOGI("dlopen libneural_network_runtime_ext.so.");
38 void* libHandle = dlopen("libneural_network_runtime_ext.so", RTLD_NOW | RTLD_GLOBAL);
39 if (libHandle == nullptr) {
40 LOGW("Failed to dlopen libneural_network_runtime_ext.so.");
41 }
42 }
43 }
44 static BackendManager instance;
45 return instance;
46 }
47
GetAllBackendsID()48 const std::vector<size_t>& BackendManager::GetAllBackendsID()
49 {
50 const std::lock_guard<std::mutex> lock(m_mtx);
51 return m_backendIDs;
52 }
53
GetBackend(size_t backendID)54 std::shared_ptr<Backend> BackendManager::GetBackend(size_t backendID)
55 {
56 const std::lock_guard<std::mutex> lock(m_mtx);
57 if (m_backends.empty()) {
58 LOGE("[BackendManager] GetBackend failed, there is no registered backend can be used.");
59 return nullptr;
60 }
61
62 auto iter = m_backends.begin();
63 if (backendID == static_cast<size_t>(0)) {
64 LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
65 return iter->second;
66 }
67
68 iter = m_backends.find(backendID);
69 if (iter == m_backends.end()) {
70 LOGE("[BackendManager] GetBackend failed, not find backendId=%{public}zu", backendID);
71 return nullptr;
72 }
73
74 return iter->second;
75 }
76
GetBackendName(size_t backendID)77 const std::string& BackendManager::GetBackendName(size_t backendID)
78 {
79 const std::lock_guard<std::mutex> lock(m_mtx);
80 if (m_backendNames.empty()) {
81 LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used.");
82 return m_emptyBackendName;
83 }
84
85 auto iter = m_backendNames.begin();
86 if (backendID == static_cast<size_t>(0)) {
87 LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
88 } else {
89 iter = m_backendNames.find(backendID);
90 }
91
92 if (iter == m_backendNames.end()) {
93 LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID);
94 return m_emptyBackendName;
95 }
96
97 return iter->second;
98 }
99
RegisterBackend(const std::string & backendName,std::function<std::shared_ptr<Backend> ()> creator)100 OH_NN_ReturnCode BackendManager::RegisterBackend(
101 const std::string& backendName, std::function<std::shared_ptr<Backend>()> creator)
102 {
103 auto regBackend = creator();
104 if (regBackend == nullptr) {
105 LOGE("[BackendManager] RegisterBackend failed, fail to create backend.");
106 return OH_NN_FAILED;
107 }
108
109 if (!IsValidBackend(regBackend)) {
110 LOGE("[BackendManager] RegisterBackend failed, backend is not available.");
111 return OH_NN_UNAVAILABLE_DEVICE;
112 }
113
114 size_t backendID = regBackend->GetBackendID();
115
116 const std::lock_guard<std::mutex> lock(m_mtx);
117 auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
118 if (iter != m_backendIDs.end()) {
119 LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. "
120 "backendID=%{public}zu", backendID);
121 return OH_NN_FAILED;
122 }
123
124 std::string tmpBackendName;
125 auto ret = regBackend->GetBackendName(tmpBackendName);
126 if (ret != OH_NN_SUCCESS) {
127 LOGE("[BackendManager] RegisterBackend failed, fail to get backend name.");
128 return OH_NN_FAILED;
129 }
130 m_backends.emplace(backendID, regBackend);
131 m_backendIDs.emplace_back(backendID);
132 m_backendNames.emplace(backendID, tmpBackendName);
133 if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
134 std::vector<size_t> backendIDsTmp {backendID};
135 m_backendIDGroup.emplace(backendName, backendIDsTmp);
136 } else {
137 m_backendIDGroup[backendName].emplace_back(backendID);
138 }
139 return OH_NN_SUCCESS;
140 }
141
RemoveBackend(const std::string & backendName)142 void BackendManager::RemoveBackend(const std::string& backendName)
143 {
144 LOGI("[RemoveBackend] start remove backend for %{public}s.", backendName.c_str());
145 const std::lock_guard<std::mutex> lock(m_mtx);
146 if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
147 LOGI("[RemoveBackend] No need to remove backend for %{public}s.", backendName.c_str());
148 return;
149 }
150
151 auto backendIDs = m_backendIDGroup[backendName];
152 for (auto backendID : backendIDs) {
153 if (m_backends.find(backendID) != m_backends.end()) {
154 m_backends.erase(backendID);
155 }
156 auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
157 if (iter != m_backendIDs.end()) {
158 m_backendIDs.erase(iter);
159 }
160 if (m_backendNames.find(backendID) != m_backendNames.end()) {
161 m_backendNames.erase(backendID);
162 }
163 LOGI("[RemoveBackend] remove backendID[%{public}zu] for %{public}s success.", backendID, backendName.c_str());
164 }
165 m_backendIDGroup.erase(backendName);
166 }
167
IsValidBackend(std::shared_ptr<Backend> backend) const168 bool BackendManager::IsValidBackend(std::shared_ptr<Backend> backend) const
169 {
170 DeviceStatus status = UNKNOWN;
171
172 OH_NN_ReturnCode ret = backend->GetBackendStatus(status);
173 if (ret != OH_NN_SUCCESS || status == UNKNOWN || status == OFFLINE) {
174 return false;
175 }
176
177 return true;
178 }
} // namespace NeuralNetworkRuntime
180 } // OHOS
181