1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "communication_adapter/include/adapter_wrapper.h"
17 
18 #include <cstring>
19 #include <map>
20 
21 #include "communication_adapter/include/sa_async_handler.h"
22 #include "communication_adapter/include/sa_server_adapter.h"
23 #include "protocol/retcode_inner/aie_retcode_inner.h"
24 #include "utils/aie_macros.h"
25 #include "utils/constants/constants.h"
26 #include "utils/log/aie_log.h"
27 
28 using namespace OHOS::AI;
29 namespace {
30 constexpr int STARTING_CLIENT_ID = 1;
31 constexpr int MAX_NUM_CLIENTS = 1024;
32 using ServerAdapters = std::map<int, SaServerAdapter*>;
33 ServerAdapters g_saServerAdapters;
34 std::atomic<int> g_clientIdAtomic(0);
35 std::mutex g_serverAdapterMutex;
36 std::mutex g_connectMutex;
37 
FindValidClientId()38 int FindValidClientId()
39 {
40     if (g_saServerAdapters.size() > MAX_NUM_CLIENTS) {
41         HILOGE("[AdapterWrapper]Num of valid clients reaches max.");
42         return INVALID_CLIENT_ID;
43     }
44 
45     do {
46         ++g_clientIdAtomic;
47         if (g_clientIdAtomic < STARTING_CLIENT_ID) {
48             HILOGI("[AdapterWrapper]Client id reaches max, reset to starting value [%d].", STARTING_CLIENT_ID);
49             g_clientIdAtomic = STARTING_CLIENT_ID;
50         }
51     } while (g_saServerAdapters.find(g_clientIdAtomic) != g_saServerAdapters.end());
52 
53     return g_clientIdAtomic;
54 }
55 
AllocateClientAdapter()56 int AllocateClientAdapter()
57 {
58     SaServerAdapter *adapter = nullptr;
59     AIE_NEW(adapter, SaServerAdapter(g_clientIdAtomic));
60     if (adapter == nullptr) {
61         HILOGE("[AdapterWrapper]Failed to new adapter.");
62         return INVALID_CLIENT_ID;
63     }
64 
65     int clientId = adapter->GetAdapterId();
66 
67     std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
68     g_saServerAdapters[clientId] = adapter;
69     return clientId;
70 }
71 }
72 
73 class AdapterWrapper {
74 public:
AdapterWrapper(SaServerAdapter * adapter)75     explicit AdapterWrapper(SaServerAdapter *adapter) : adapter_(adapter)
76     {
77         if (adapter_) {
78             adapter_->IncRef();
79         }
80     }
81 
~AdapterWrapper()82     ~AdapterWrapper()
83     {
84         if (adapter_) {
85             adapter_->DecRef();
86             adapter_ = nullptr;
87         }
88     }
89 
90 private:
91     SaServerAdapter *adapter_ = nullptr;
92 };
93 
FindAdapter(const int clientId)94 SaServerAdapter* FindAdapter(const int clientId)
95 {
96     std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
97     ServerAdapters::iterator iter = g_saServerAdapters.find(clientId);
98     if (iter != g_saServerAdapters.end()) {
99         return iter->second;
100     }
101     return nullptr;
102 }
103 
GenerateClient()104 int GenerateClient()
105 {
106     HILOGI("[AdapterWrapper]Begin to call GenerateClient.");
107     std::lock_guard<std::mutex> guard(g_connectMutex);
108 
109     if (FindValidClientId() == INVALID_CLIENT_ID) {
110         return INVALID_CLIENT_ID;
111     }
112 
113     return AllocateClientAdapter();
114 }
115 
SyncExecAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo,DataInfo * outputInfo)116 int SyncExecAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo,
117     DataInfo *outputInfo)
118 {
119     HILOGI("[AdapterWrapper]Begin to call SyncExecAlgoWrapper.");
120     if (clientInfo == nullptr || algoInfo == nullptr) {
121         HILOGE("[AdapterWrapper]The clientInfo or algoInfo is nullptr");
122         return RETCODE_NULL_PARAM;
123     }
124 
125     if (algoInfo->isAsync) {
126         HILOGW("[AdapterWrapper]SyncExecute but the algoInfo is AsyncExecute");
127         return RETCODE_WRONG_INFER_MODE;
128     }
129 
130     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
131     if (adapter == nullptr) {
132         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
133         return RETCODE_NO_CLIENT_FOUND;
134     }
135 
136     AdapterWrapper adapterGuard(adapter);
137     return adapter->SyncExecute(*clientInfo, *algoInfo, *inputInfo, *outputInfo);
138 }
139 
AsyncExecAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo)140 int AsyncExecAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo)
141 {
142     HILOGI("[AdapterWrapper]Begin to call AsyncExecAlgoWrapper.");
143     if (clientInfo == nullptr || algoInfo == nullptr) {
144         HILOGE("[AdapterWrapper]The clientInfo or algoInfo is nullptr.");
145         return RETCODE_NULL_PARAM;
146     }
147 
148     if (!algoInfo->isAsync) {
149         HILOGW("[AdapterWrapper]AsyncExecute but the algoInfo is SyncExecute.");
150         return RETCODE_WRONG_INFER_MODE;
151     }
152 
153     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
154     if (adapter == nullptr) {
155         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
156         return RETCODE_NO_CLIENT_FOUND;
157     }
158     AdapterWrapper adapterGuard(adapter);
159 
160     return adapter->AsyncExecute(*clientInfo, *algoInfo, *inputInfo);
161 }
162 
LoadAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo,DataInfo * outputInfo)163 int LoadAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo,
164     DataInfo *outputInfo)
165 {
166     HILOGI("[AdapterWrapper]Begin to call LoadAlgoWrapper.");
167     if (clientInfo == nullptr || algoInfo == nullptr) {
168         HILOGE("[AdapterWrapper]The clientInfo or algoInfo is null");
169         return RETCODE_NULL_PARAM;
170     }
171 
172     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
173     if (adapter == nullptr) {
174         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
175         return RETCODE_NO_CLIENT_FOUND;
176     }
177 
178     AdapterWrapper adapterGuard(adapter);
179     long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
180     int retCode = adapter->LoadAlgorithm(transactionId, *algoInfo, *inputInfo, *outputInfo);
181     if (retCode != RETCODE_SUCCESS) {
182         HILOGE("[AdapterWrapper][transactionId:%lld]Failed to load algorithm, retCode[%d], aid[%d].",
183             transactionId, retCode, algoInfo->algorithmType);
184         return retCode;
185     }
186 
187     if (algoInfo->isAsync) {
188         SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
189         CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
190         retCode = saAsyncHandler->StartAsyncTransaction(transactionId, clientInfo->clientId);
191         HILOGI("[AdapterWrapper]StartAsyncTransaction retCode is [%d].", retCode);
192     }
193     return retCode;
194 }
195 
UnloadAlgoWrapper(const ClientInfo * clientInfo,const AlgorithmInfo * algoInfo,const DataInfo * inputInfo)196 int UnloadAlgoWrapper(const ClientInfo *clientInfo, const AlgorithmInfo *algoInfo, const DataInfo *inputInfo)
197 {
198     HILOGI("[AdapterWrapper]Begin to call UnloadAlgoWrapper.");
199     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
200     if (adapter == nullptr) {
201         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
202         return RETCODE_NO_CLIENT_FOUND;
203     }
204 
205     AdapterWrapper adapterGuard(adapter);
206     long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
207     if (algoInfo == nullptr) {
208         HILOGE("[AdapterWrapper]AlgoInfo is nullptr.");
209         return RETCODE_NULL_PARAM;
210     }
211     if (algoInfo->isAsync) {
212         SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
213         if (saAsyncHandler != nullptr) {
214             saAsyncHandler->StopAsyncTransaction(transactionId);
215         }
216     }
217 
218     return adapter->UnloadAlgorithm(transactionId, *inputInfo);
219 }
220 
RemoveAdapterWrapper(const ClientInfo * clientInfo)221 int RemoveAdapterWrapper(const ClientInfo *clientInfo)
222 {
223     HILOGI("[AdapterWrapper]Begin to call RemoveAdapterWrapper.");
224     std::lock_guard<std::mutex> guard(g_serverAdapterMutex);
225     ServerAdapters::iterator iter = g_saServerAdapters.find(clientInfo->clientId);
226     if (iter == g_saServerAdapters.end()) {
227         HILOGE("[AdapterWrapper]Failed to find serverAdapter for client[%d].", clientInfo->clientId);
228         return RETCODE_FAILURE;
229     }
230 
231     AIE_DELETE(iter->second);
232     g_saServerAdapters.erase(iter);
233     return RETCODE_SUCCESS;
234 }
235 
SetOptionWrapper(const ClientInfo * clientInfo,int optionType,const DataInfo * inputInfo)236 int SetOptionWrapper(const ClientInfo *clientInfo, int optionType, const DataInfo *inputInfo)
237 {
238     HILOGI("[AdapterWrapper]Begin to call SetOptionWrapper.");
239     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
240     if (adapter == nullptr) {
241         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
242         return RETCODE_NO_CLIENT_FOUND;
243     }
244 
245     AdapterWrapper adapterGuard(adapter);
246     long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
247     return adapter->SetOption(transactionId, optionType, *inputInfo);
248 }
249 
GetOptionWrapper(const ClientInfo * clientInfo,int optionType,const DataInfo * inputInfo,DataInfo * outputInfo)250 int GetOptionWrapper(const ClientInfo *clientInfo, int optionType, const DataInfo *inputInfo, DataInfo *outputInfo)
251 {
252     HILOGI("[AdapterWrapper]Begin to call GetOptionWrapper.");
253     if (clientInfo == nullptr) {
254         HILOGE("[AdapterWrapper]ClientInfo is nullptr.");
255         return RETCODE_NULL_PARAM;
256     }
257     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
258     if (adapter == nullptr) {
259         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
260         return RETCODE_NO_CLIENT_FOUND;
261     }
262 
263     AdapterWrapper adapterGuard(adapter);
264     long long transactionId = adapter->GetTransactionId(clientInfo->sessionId);
265     return adapter->GetOption(transactionId, optionType, *inputInfo, *outputInfo);
266 }
267 
RegisterCallbackWrapper(const ClientInfo * clientInfo,SvcIdentity * sid)268 int RegisterCallbackWrapper(const ClientInfo *clientInfo, SvcIdentity *sid)
269 {
270     HILOGI("[AdapterWrapper]Begin to call RegisterCallbackWrapper.");
271     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
272     if (adapter == nullptr) {
273         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
274         return RETCODE_NO_CLIENT_FOUND;
275     }
276     AdapterWrapper adapterGuard(adapter);
277 
278     adapter->SaveEngineListener(sid);
279 
280     SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
281     CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
282     int retCode = saAsyncHandler->RegisterAsyncHandler(clientInfo->clientId);
283     if (retCode != RETCODE_SUCCESS) {
284         HILOGE("[AdapterWrapper]Client[%d] session[%d] RegisterAsyncHandler result is [%d].", clientInfo->clientId,
285             clientInfo->sessionId, retCode);
286         return retCode;
287     }
288 
289     return saAsyncHandler->StartAsyncProcess(clientInfo->clientId, adapter);
290 }
291 
UnregisterCallbackWrapper(const ClientInfo * clientInfo)292 int UnregisterCallbackWrapper(const ClientInfo *clientInfo)
293 {
294     HILOGI("[AdapterWrapper]Begin to call UnregisterCallbackWrapper.");
295     SaServerAdapter *adapter = FindAdapter(clientInfo->clientId);
296     if (adapter == nullptr) {
297         HILOGE("[AdapterWrapper]No adapter found for client[%d].", clientInfo->clientId);
298         return RETCODE_NO_CLIENT_FOUND;
299     }
300 
301     AdapterWrapper adapterGuard(adapter);
302     SaAsyncHandler *saAsyncHandler = SaAsyncHandler::GetInstance();
303     CHK_RET(saAsyncHandler == nullptr, RETCODE_NULL_PARAM);
304     saAsyncHandler->StopAsyncProcess(clientInfo->clientId);
305     adapter->ClearEngineListener();
306     return RETCODE_SUCCESS;
307 }
308