1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "communication_adapter/include/adapter_wrapper.h"
17
18 #include <cstring>
19 #include <map>
20
21 #include "communication_adapter/include/sa_async_handler.h"
22 #include "communication_adapter/include/sa_server_adapter.h"
23 #include "protocol/retcode_inner/aie_retcode_inner.h"
24 #include "utils/aie_macros.h"
25 #include "utils/constants/constants.h"
26 #include "utils/log/aie_log.h"
27
28 using namespace OHOS::AI;
29 namespace {
30 constexpr int STARTING_CLIENT_ID = 1;
31 constexpr int MAX_NUM_CLIENTS = 1024;
32 using ServerAdapters = std::map<int, SaServerAdapter*>;
33 ServerAdapters g_saServerAdapters;
34 std::atomic<int> g_clientIdAtomic(0);
35 std::mutex g_serverAdapterMutex;
36 std::mutex g_connectMutex;
37
FindValidClientId()38 int FindValidClientId()
39 {
40 if (g_saServerAdapters.size() > MAX_NUM_CLIENTS) {
41 HILOGE("[AdapterWrapper]Num of valid clients reaches max.");
42 return INVALID_CLIENT_ID;
43 }
44
45 do {
46 ++g_clientIdAtomic;
47 if (g_clientIdAtomic < STARTING_CLIENT_ID) {
48 HILOGI("[AdapterWrapper]Client id reaches max, reset to starting value [%d].", STARTING_CLIENT_ID);
49 g_clientIdAtomic = STARTING_CLIENT_ID;
50 }
51 } while (g_saServerAdapters.find(g_clientIdAtomic) != g_saServerAdapters.end());
52
53 return g_clientIdAtomic;
54 }
55
AllocateClientAdapter()56 int AllocateClientAdapter()
57 {
58 SaServerAdapter *adapter = nullptr;
59 AIE_NEW(adapter, SaServerAdapter(g_clientIdAtomic));
60 if (adapter == nullptr) {
61 HILOGE("[AdapterWrapper]Failed to new adapter.");
62 return INVALID_CLIENT_ID;
63 }
64
65 int clientId = adapter->GetAdapterId();
66
67 std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
68 g_saServerAdapters[clientId] = adapter;
69 return clientId;
70 }
71 }
72
73 class AdapterWrapper {
74 public:
AdapterWrapper(SaServerAdapter * adapter)75 explicit AdapterWrapper(SaServerAdapter *adapter) : adapter_(adapter)
76 {
77 if (adapter_) {
78 adapter_->IncRef();
79 }
80 }
81
~AdapterWrapper()82 ~AdapterWrapper()
83 {
84 if (adapter_) {
85 adapter_->DecRef();
86 adapter_ = nullptr;
87 }
88 }
89
90 private:
91 SaServerAdapter *adapter_ = nullptr;
92 };
93
FindAdapter(const int clientId)94 SaServerAdapter* FindAdapter(const int clientId)
95 {
96 std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
97 ServerAdapters::iterator iter = g_saServerAdapters.find(clientId);
98 if (iter != g_saServerAdapters.end()) {
99 return iter->second;
100 }
101 return nullptr;
102 }
103
GenerateClient()104 int GenerateClient()
105 {
106 HILOGI("[AdapterWrapper]Begin to call GenerateClient.");
107 std::lock_guard<std::mutex> guard(g_connectMutex);
108
109 if (FindValidClientId() == INVALID_CLIENT_ID) {
110 return INVALID_CLIENT_ID;
111 }
112
113 return AllocateClientAdapter();
114 }
115
SyncExecAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo,DataInfo * outputInfo)116 int SyncExecAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo,
117 DataInfo *outputInfo)
118 {
119 HILOGI("[AdapterWrapper]Begin to call SyncExecAlgoWrapper.");
120 if (clientInfo == nullptr || algoInfo == nullptr) {
121 HILOGE("[AdapterWrapper]The clientInfo or algoInfo is nullptr");
122 return RETCODE_NULL_PARAM;
123 }
124
125 if (algoInfo->isAsync) {
126 HILOGW("[AdapterWrapper]SyncExecute but the algoInfo is AsyncExecute");
127 return RETCODE_WRONG_INFER_MODE;
128 }
129
130 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
131 if (adapter == nullptr) {
132 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
133 return RETCODE_NO_CLIENT_FOUND;
134 }
135
136 AdapterWrapper adapterGuard(adapter);
137 return adapter->SyncExecute(*clientInfo, *algoInfo, *inputInfo, *outputInfo);
138 }
139
AsyncExecAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo)140 int AsyncExecAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo)
141 {
142 HILOGI("[AdapterWrapper]Begin to call AsyncExecAlgoWrapper.");
143 if (clientInfo == nullptr || algoInfo == nullptr) {
144 HILOGE("[AdapterWrapper]The clientInfo or algoInfo is nullptr.");
145 return RETCODE_NULL_PARAM;
146 }
147
148 if (!algoInfo->isAsync) {
149 HILOGW("[AdapterWrapper]AsyncExecute but the algoInfo is SyncExecute.");
150 return RETCODE_WRONG_INFER_MODE;
151 }
152
153 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
154 if (adapter == nullptr) {
155 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
156 return RETCODE_NO_CLIENT_FOUND;
157 }
158 AdapterWrapper adapterGuard(adapter);
159
160 return adapter->AsyncExecute(*clientInfo, *algoInfo, *inputInfo);
161 }
162
LoadAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo,DataInfo * outputInfo)163 int LoadAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo,
164 DataInfo *outputInfo)
165 {
166 HILOGI("[AdapterWrapper]Begin to call LoadAlgoWrapper.");
167 if (clientInfo == nullptr || algoInfo == nullptr) {
168 HILOGE("[AdapterWrapper]The clientInfo or algoInfo is null");
169 return RETCODE_NULL_PARAM;
170 }
171
172 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
173 if (adapter == nullptr) {
174 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
175 return RETCODE_NO_CLIENT_FOUND;
176 }
177
178 AdapterWrapper adapterGuard(adapter);
179 long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
180 int retCode = adapter->LoadAlgorithm(transactionId, *algoInfo, *inputInfo, *outputInfo);
181 if (retCode != RETCODE_SUCCESS) {
182 HILOGE("[AdapterWrapper][transactionId:%lld]Failed to load algorithm, retCode[%d], aid[%d].",
183 transactionId, retCode, algoInfo->algorithmType);
184 return retCode;
185 }
186
187 if (algoInfo->isAsync) {
188 SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
189 CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
190 retCode = saAsyncHandler->StartAsyncTransaction(transactionId, clientInfo->clientId);
191 HILOGI("[AdapterWrapper]StartAsyncTransaction retCode is [%d].", retCode);
192 }
193 return retCode;
194 }
195
UnloadAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo)196 int UnloadAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo)
197 {
198 HILOGI("[AdapterWrapper]Begin to call UnloadAlgoWrapper.");
199 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
200 if (adapter == nullptr) {
201 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
202 return RETCODE_NO_CLIENT_FOUND;
203 }
204
205 AdapterWrapper adapterGuard(adapter);
206 long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
207 if (algoInfo == nullptr) {
208 HILOGE("[AdapterWrapper]AlgoInfo is nullptr.");
209 return RETCODE_NULL_PARAM;
210 }
211 if (algoInfo->isAsync) {
212 SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
213 if (saAsyncHandler != nullptr) {
214 saAsyncHandler->StopAsyncTransaction(transactionId);
215 }
216 }
217
218 return adapter->UnloadAlgorithm(transactionId, *inputInfo);
219 }
220
RemoveAdapterWrapper(const ClientInfo * clientInfo)221 int RemoveAdapterWrapper(const ClientInfo *clientInfo)
222 {
223 HILOGI("[AdapterWrapper]Begin to call RemoveAdapterWrapper.");
224 std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
225 ServerAdapters::iterator iter = g_saServerAdapters.find(clientInfo->clientId);
226 if (iter == g_saServerAdapters.end()) {
227 HILOGE("[AdapterWrapper]Failed to find serverAdapter for client[%d].", clientInfo->clientId);
228 return RETCODE_FAILURE;
229 }
230
231 AIE_DELETE(iter->second);
232 g_saServerAdapters.erase(iter);
233 return RETCODE_SUCCESS;
234 }
235
SetOptionWrapper(const ClientInfo * clientInfo,int optionType,const DataInfo * inputInfo)236 int SetOptionWrapper(const ClientInfo *clientInfo, int optionType, const DataInfo *inputInfo)
237 {
238 HILOGI("[AdapterWrapper]Begin to call SetOptionWrapper.");
239 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
240 if (adapter == nullptr) {
241 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
242 return RETCODE_NO_CLIENT_FOUND;
243 }
244
245 AdapterWrapper adapterGuard(adapter);
246 long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
247 return adapter->SetOption(transactionId, optionType, *inputInfo);
248 }
249
GetOptionWrapper(const ClientInfo * clientInfo,int optionType,const DataInfo * inputInfo,DataInfo * outputInfo)250 int GetOptionWrapper(const ClientInfo *clientInfo, int optionType, const DataInfo *inputInfo, DataInfo *outputInfo)
251 {
252 HILOGI("[AdapterWrapper]Begin to call GetOptionWrapper.");
253 if (clientInfo == nullptr) {
254 HILOGE("[AdapterWrapper]ClientInfo is nullptr.");
255 return RETCODE_NULL_PARAM;
256 }
257 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
258 if (adapter == nullptr) {
259 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
260 return RETCODE_NO_CLIENT_FOUND;
261 }
262
263 AdapterWrapper adapterGuard(adapter);
264 long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
265 return adapter->GetOption(transactionId, optionType, *inputInfo, *outputInfo);
266 }
267
RegisterCallbackWrapper(const ClientInfo * clientInfo,SvcIdentity * sid)268 int RegisterCallbackWrapper(const ClientInfo *clientInfo, SvcIdentity *sid)
269 {
270 HILOGI("[AdapterWrapper]Begin to call RegisterCallbackWrapper.");
271 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
272 if (adapter == nullptr) {
273 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
274 return RETCODE_NO_CLIENT_FOUND;
275 }
276 AdapterWrapper adapterGuard(adapter);
277
278 adapter->SaveEngineListener(sid);
279
280 SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
281 CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
282 int retCode = saAsyncHandler->RegisterAsyncHandler(clientInfo->clientId);
283 if (retCode != RETCODE_SUCCESS) {
284 HILOGE("[AdapterWrapper]Client[%d] session[%d] RegisterAsyncHandler result is [%d].", clientInfo->clientId,
285 clientInfo->sessionId, retCode);
286 return retCode;
287 }
288
289 return saAsyncHandler->StartAsyncProcess(clientInfo->clientId, adapter);
290 }
291
UnregisterCallbackWrapper(const ClientInfo * clientInfo)292 int UnregisterCallbackWrapper(const ClientInfo *clientInfo)
293 {
294 HILOGI("[AdapterWrapper]Begin to call UnregisterCallbackWrapper.");
295 SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
296 if (adapter == nullptr) {
297 HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
298 return RETCODE_NO_CLIENT_FOUND;
299 }
300
301 AdapterWrapper adapterGuard(adapter);
302 SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
303 CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
304 saAsyncHandler->StopAsyncProcess(clientInfo->clientId);
305 adapter->ClearEngineListener();
306 return RETCODE_SUCCESS;
307 }
308