1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "backend_manager.h"
17 
18 #include <algorithm>
19 #include "cpp_type.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
~BackendManager()23 BackendManager::~BackendManager()
24 {
25     m_backends.clear();
26     m_backendNames.clear();
27     m_backendIDs.clear();
28     m_backendIDGroup.clear();
29 }
30 
GetInstance()31 BackendManager& BackendManager::GetInstance()
32 {
33     // if libneural_network_runtime.so loaded
34     if (dlopen("libneural_network_runtime.so", RTLD_NOLOAD) != nullptr) {
35         // if libneural_network_runtime_ext.so not loaded, try to dlopen it
36         if (dlopen("libneural_network_runtime_ext.so", RTLD_NOLOAD) == nullptr) {
37             LOGI("dlopen libneural_network_runtime_ext.so.");
38             void* libHandle = dlopen("libneural_network_runtime_ext.so", RTLD_NOW | RTLD_GLOBAL);
39             if (libHandle == nullptr) {
40                 LOGW("Failed to dlopen libneural_network_runtime_ext.so.");
41             }
42         }
43     }
44     static BackendManager instance;
45     return instance;
46 }
47 
GetAllBackendsID()48 const std::vector<size_t>& BackendManager::GetAllBackendsID()
49 {
50     const std::lock_guard<std::mutex> lock(m_mtx);
51     return m_backendIDs;
52 }
53 
GetBackend(size_t backendID)54 std::shared_ptr<Backend> BackendManager::GetBackend(size_t backendID)
55 {
56     const std::lock_guard<std::mutex> lock(m_mtx);
57     if (m_backends.empty()) {
58         LOGE("[BackendManager] GetBackend failed, there is no registered backend can be used.");
59         return nullptr;
60     }
61 
62     auto iter = m_backends.begin();
63     if (backendID == static_cast<size_t>(0)) {
64         LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
65         return iter->second;
66     }
67 
68     iter = m_backends.find(backendID);
69     if (iter == m_backends.end()) {
70         LOGE("[BackendManager] GetBackend failed, not find backendId=%{public}zu", backendID);
71         return nullptr;
72     }
73 
74     return iter->second;
75 }
76 
GetBackendName(size_t backendID)77 const std::string& BackendManager::GetBackendName(size_t backendID)
78 {
79     const std::lock_guard<std::mutex> lock(m_mtx);
80     if (m_backendNames.empty()) {
81         LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used.");
82         return m_emptyBackendName;
83     }
84 
85     auto iter = m_backendNames.begin();
86     if (backendID == static_cast<size_t>(0)) {
87         LOGI("[BackendManager] the backendID is 0, default return 1st backend.");
88     } else {
89         iter = m_backendNames.find(backendID);
90     }
91 
92     if (iter == m_backendNames.end()) {
93         LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID);
94         return m_emptyBackendName;
95     }
96 
97     return iter->second;
98 }
99 
RegisterBackend(const std::string & backendName,std::function<std::shared_ptr<Backend> ()> creator)100 OH_NN_ReturnCode BackendManager::RegisterBackend(
101     const std::string& backendName, std::function<std::shared_ptr<Backend>()> creator)
102 {
103     auto regBackend = creator();
104     if (regBackend == nullptr) {
105         LOGE("[BackendManager] RegisterBackend failed, fail to create backend.");
106         return OH_NN_FAILED;
107     }
108 
109     if (!IsValidBackend(regBackend)) {
110         LOGE("[BackendManager] RegisterBackend failed, backend is not available.");
111         return OH_NN_UNAVAILABLE_DEVICE;
112     }
113 
114     size_t backendID = regBackend->GetBackendID();
115 
116     const std::lock_guard<std::mutex> lock(m_mtx);
117     auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
118     if (iter != m_backendIDs.end()) {
119         LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. "
120              "backendID=%{public}zu", backendID);
121         return OH_NN_FAILED;
122     }
123 
124     std::string tmpBackendName;
125     auto ret = regBackend->GetBackendName(tmpBackendName);
126     if (ret != OH_NN_SUCCESS) {
127         LOGE("[BackendManager] RegisterBackend failed, fail to get backend name.");
128         return OH_NN_FAILED;
129     }
130     m_backends.emplace(backendID, regBackend);
131     m_backendIDs.emplace_back(backendID);
132     m_backendNames.emplace(backendID, tmpBackendName);
133     if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
134         std::vector<size_t> backendIDsTmp {backendID};
135         m_backendIDGroup.emplace(backendName, backendIDsTmp);
136     } else {
137         m_backendIDGroup[backendName].emplace_back(backendID);
138     }
139     return OH_NN_SUCCESS;
140 }
141 
RemoveBackend(const std::string & backendName)142 void BackendManager::RemoveBackend(const std::string& backendName)
143 {
144     LOGI("[RemoveBackend] start remove backend for %{public}s.", backendName.c_str());
145     const std::lock_guard<std::mutex> lock(m_mtx);
146     if (m_backendIDGroup.find(backendName) == m_backendIDGroup.end()) {
147         LOGI("[RemoveBackend] No need to remove backend for %{public}s.", backendName.c_str());
148         return;
149     }
150 
151     auto backendIDs = m_backendIDGroup[backendName];
152     for (auto backendID : backendIDs) {
153         if (m_backends.find(backendID) != m_backends.end()) {
154             m_backends.erase(backendID);
155         }
156         auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID);
157         if (iter != m_backendIDs.end()) {
158             m_backendIDs.erase(iter);
159         }
160         if (m_backendNames.find(backendID) != m_backendNames.end()) {
161             m_backendNames.erase(backendID);
162         }
163         LOGI("[RemoveBackend] remove backendID[%{public}zu] for %{public}s success.", backendID, backendName.c_str());
164     }
165     m_backendIDGroup.erase(backendName);
166 }
167 
IsValidBackend(std::shared_ptr<Backend> backend) const168 bool BackendManager::IsValidBackend(std::shared_ptr<Backend> backend) const
169 {
170     DeviceStatus status = UNKNOWN;
171 
172     OH_NN_ReturnCode ret = backend->GetBackendStatus(status);
173     if (ret != OH_NN_SUCCESS || status == UNKNOWN || status == OFFLINE) {
174         return false;
175     }
176 
177     return true;
178 }
} // NeuralNetworkRuntime
180 } // OHOS
181