1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef LOCAL_SOCKET_SERVER_CONTEXT_H
17 #define LOCAL_SOCKET_SERVER_CONTEXT_H
18 
#if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
#include <unistd.h>
#endif
#if !defined(MAC_PLATFORM) && !defined(IOS_PLATFORM)
#include <sys/epoll.h>
#endif
#include <unistd.h>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <map>
#include <mutex>
#include <string>

#include "base_context.h"
#include "event_list.h"
#include "local_socket_context.h"
#include "napi/native_api.h"
#include "nocopyable.h"
#include "socket_state_base.h"
35 
36 namespace OHOS::NetStack::Socket {
37 struct LocalSocketServerManager : public SocketBaseManager {
38     static constexpr int MAX_EVENTS = 10;
39     static constexpr int EPOLL_TIMEOUT_MS = 500;
40     int clientId_ = 0;
41 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
42     int threadCounts_ = 0;
43 #endif
44     LocalExtraOptions extraOptions_;
45     bool alreadySetExtraOptions_ = false;
46     std::atomic_bool isServerDestruct_;
47     bool isLoopFinished_ = false;
48     int epollFd_ = 0;
49 #if !defined(MAC_PLATFORM) && !defined(IOS_PLATFORM)
50     epoll_event events_[MAX_EVENTS] = {};
51 #endif
52     std::mutex finishMutex_;
53     std::condition_variable finishCond_;
54     std::mutex clientMutex_;
55     std::condition_variable cond_;
56     std::map<int, int> acceptFds_;                      // id & fd
57     std::map<int, EventManager *> clientEventManagers_; // id & EventManager*
LocalSocketServerManagerLocalSocketServerManager58     explicit LocalSocketServerManager(int sockfd) : SocketBaseManager(sockfd) {}
59 
SetServerDestructStatusLocalSocketServerManager60     void SetServerDestructStatus(bool flag)
61     {
62         isServerDestruct_.store(flag, std::memory_order_relaxed);
63     }
GetServerDestructStatusLocalSocketServerManager64     bool GetServerDestructStatus()
65     {
66         return isServerDestruct_.load(std::memory_order_relaxed);
67     }
68 #if !defined(MAC_PLATFORM) && !defined(IOS_PLATFORM)
StartEpollLocalSocketServerManager69     int StartEpoll()
70     {
71         epollFd_ = epoll_create1(0);
72         return epollFd_;
73     }
EpollWaitLocalSocketServerManager74     int EpollWait()
75     {
76         return epoll_wait(epollFd_, events_, MAX_EVENTS - 1, EPOLL_TIMEOUT_MS);
77     }
RegisterEpollEventLocalSocketServerManager78     int RegisterEpollEvent(int sockfd, int events)
79     {
80         epoll_event event;
81         event.events = events;
82         event.data.fd = sockfd;
83         return epoll_ctl(epollFd_, EPOLL_CTL_ADD, sockfd, &event);
84     }
WaitRegisteringEventLocalSocketServerManager85     void WaitRegisteringEvent(int id)
86     {
87         std::unique_lock<std::mutex> lock(clientMutex_);
88         cond_.wait(lock, [&id, this]() {
89             if (auto iter = clientEventManagers_.find(id); iter != clientEventManagers_.end()) {
90                 if (iter->second->HasEventListener(EVENT_MESSAGE)) {
91                     return true;
92                 }
93             }
94             return false;
95         });
96     }
GetClientIdLocalSocketServerManager97     int GetClientId(int fd)
98     {
99         std::lock_guard<std::mutex> lock(clientMutex_);
100         for (const auto &[clientId, connectFd] : acceptFds_) {
101             if (fd == connectFd) {
102                 return clientId;
103             }
104         }
105         return -1;
106     }
GetManagerLocalSocketServerManager107     EventManager *GetManager(int id)
108     {
109         std::lock_guard<std::mutex> lock(clientMutex_);
110         if (auto ite = clientEventManagers_.find(id); ite != clientEventManagers_.end()) {
111             return ite->second;
112         }
113         return nullptr;
114     }
115 #endif
AddAcceptLocalSocketServerManager116     int AddAccept(int accpetFd)
117     {
118         std::lock_guard<std::mutex> lock(clientMutex_);
119         auto res = acceptFds_.emplace(++clientId_, accpetFd);
120         return res.second ? clientId_ : -1;
121     }
RemoveAllAcceptLocalSocketServerManager122     void RemoveAllAccept()
123     {
124         std::lock_guard<std::mutex> lock(clientMutex_);
125         for (const auto &[id, fd] : acceptFds_) {
126             if (fd > 0) {
127                 close(fd);
128             }
129         }
130         acceptFds_.clear();
131     }
RemoveAcceptLocalSocketServerManager132     void RemoveAccept(int clientId)
133     {
134         std::lock_guard<std::mutex> lock(clientMutex_);
135         if (auto ite = acceptFds_.find(clientId); ite != acceptFds_.end()) {
136 #if !defined(MAC_PLATFORM) && !defined(IOS_PLATFORM)
137             epoll_ctl(epollFd_, EPOLL_CTL_DEL, ite->second, nullptr);
138 #endif
139             close(ite->second);
140             acceptFds_.erase(ite);
141         }
142     }
GetAcceptFdLocalSocketServerManager143     int GetAcceptFd(int clientId)
144     {
145         std::lock_guard<std::mutex> lock(clientMutex_);
146         if (auto ite = acceptFds_.find(clientId); ite != acceptFds_.end()) {
147             return ite->second;
148         }
149         return -1;
150     }
GetClientCountsLocalSocketServerManager151     size_t GetClientCounts()
152     {
153         std::lock_guard<std::mutex> lock(clientMutex_);
154         return acceptFds_.size();
155     }
156 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
WaitForManagerLocalSocketServerManager157     EventManager *WaitForManager(int clientId)
158     {
159         EventManager *manager = nullptr;
160         std::unique_lock<std::mutex> lock(clientMutex_);
161         cond_.wait(lock, [&manager, &clientId, this]() {
162             if (auto iter = clientEventManagers_.find(clientId); iter != clientEventManagers_.end()) {
163                 manager = iter->second;
164                 if (manager->HasEventListener(EVENT_MESSAGE)) {
165                     return true;
166                 }
167             }
168             return false;
169         });
170         return manager;
171     }
172 #endif
NotifyRegisterEventLocalSocketServerManager173     void NotifyRegisterEvent()
174     {
175         std::lock_guard<std::mutex> lock(clientMutex_);
176         cond_.notify_one();
177     }
AddEventManagerLocalSocketServerManager178     void AddEventManager(int clientId, EventManager *manager)
179     {
180         std::lock_guard<std::mutex> lock(clientMutex_);
181         clientEventManagers_.insert(std::make_pair(clientId, manager));
182         cond_.notify_one();
183     }
RemoveEventManagerLocalSocketServerManager184     void RemoveEventManager(int clientId)
185     {
186         std::lock_guard<std::mutex> lock(clientMutex_);
187         if (auto ite = clientEventManagers_.find(clientId); ite != clientEventManagers_.end()) {
188             EventManager::SetInvalid(ite->second);
189             clientEventManagers_.erase(ite);
190         }
191     }
RemoveAllEventManagerLocalSocketServerManager192     void RemoveAllEventManager()
193     {
194         std::lock_guard<std::mutex> lock(clientMutex_);
195         for (const auto &[id, manager] : clientEventManagers_) {
196             EventManager::SetInvalid(manager);
197         }
198         clientEventManagers_.clear();
199     }
200 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
IncreaseThreadCountsLocalSocketServerManager201     void IncreaseThreadCounts()
202     {
203         std::lock_guard<std::mutex> lock(finishMutex_);
204         ++threadCounts_;
205     }
206 #endif
NotifyLoopFinishedLocalSocketServerManager207     void NotifyLoopFinished()
208     {
209         std::lock_guard<std::mutex> lock(finishMutex_);
210         isLoopFinished_ = true;
211         finishCond_.notify_one();
212     }
WaitForEndingLoopLocalSocketServerManager213     void WaitForEndingLoop()
214     {
215         std::unique_lock<std::mutex> lock(finishMutex_);
216         finishCond_.wait(lock, [this]() {
217             return isLoopFinished_;
218         });
219     }
220 };
221 
222 class LocalSocketServerBaseContext : public LocalSocketBaseContext {
223 public:
LocalSocketServerBaseContext(napi_env env,EventManager * manager)224     LocalSocketServerBaseContext(napi_env env, EventManager *manager) : LocalSocketBaseContext(env, manager) {}
225     [[nodiscard]] int GetSocketFd() const override;
226     void SetSocketFd(int sock) override;
227 };
228 
229 class LocalSocketServerListenContext final : public LocalSocketServerBaseContext {
230 public:
LocalSocketServerListenContext(napi_env env,EventManager * manager)231     LocalSocketServerListenContext(napi_env env, EventManager *manager) : LocalSocketServerBaseContext(env, manager) {}
232     void ParseParams(napi_value *params, size_t paramsCount) override;
233     const std::string &GetSocketPath() const;
234 
235 private:
236     std::string socketPath_;
237 };
238 
239 class LocalSocketServerEndContext final : public LocalSocketServerBaseContext {
240 public:
LocalSocketServerEndContext(napi_env env,EventManager * manager)241     LocalSocketServerEndContext(napi_env env, EventManager *manager) : LocalSocketServerBaseContext(env, manager) {}
242     void ParseParams(napi_value *params, size_t paramsCount) override;
243 };
244 
245 class LocalSocketServerGetStateContext final : public LocalSocketServerBaseContext {
246 public:
LocalSocketServerGetStateContext(napi_env env,EventManager * manager)247     LocalSocketServerGetStateContext(napi_env env, EventManager *manager) : LocalSocketServerBaseContext(env, manager)
248     {
249     }
250     void ParseParams(napi_value *params, size_t paramsCount) override;
251     SocketStateBase &GetStateRef();
252 
253 private:
254     SocketStateBase state_;
255 };
256 
257 class LocalSocketServerGetLocalAddressContext final : public LocalSocketServerBaseContext {
258 public:
LocalSocketServerGetLocalAddressContext(napi_env env,EventManager * manager)259     LocalSocketServerGetLocalAddressContext(napi_env env, EventManager *manager)
260         : LocalSocketServerBaseContext(env, manager) {}
261     void ParseParams(napi_value *params, size_t paramsCount) override;
262     void SetSocketPath(const std::string socketPath);
263     std::string GetSocketPath();
264     int GetClientId() const;
265     void SetClientId(int clientId);
266 
267 private:
268     std::string socketPath_;
269     int clientId_ = 0;
270 };
271 
272 class LocalSocketServerSetExtraOptionsContext final : public LocalSocketServerBaseContext {
273 public:
LocalSocketServerSetExtraOptionsContext(napi_env env,EventManager * manager)274     LocalSocketServerSetExtraOptionsContext(napi_env env, EventManager *manager)
275         : LocalSocketServerBaseContext(env, manager)
276     {
277     }
278     void ParseParams(napi_value *params, size_t paramsCount) override;
279     LocalExtraOptions &GetOptionsRef();
280 
281 private:
282     LocalExtraOptions options_;
283 };
284 
285 class LocalSocketServerGetExtraOptionsContext final : public LocalSocketServerBaseContext {
286 public:
LocalSocketServerGetExtraOptionsContext(napi_env env,EventManager * manager)287     LocalSocketServerGetExtraOptionsContext(napi_env env, EventManager *manager)
288         : LocalSocketServerBaseContext(env, manager)
289     {
290     }
291     void ParseParams(napi_value *params, size_t paramsCount) override;
292     LocalExtraOptions &GetOptionsRef();
293 
294 private:
295     LocalExtraOptions options_;
296 };
297 
298 class LocalSocketServerSendContext final : public LocalSocketServerBaseContext {
299 public:
LocalSocketServerSendContext(napi_env env,EventManager * manager)300     LocalSocketServerSendContext(napi_env env, EventManager *manager) : LocalSocketServerBaseContext(env, manager) {}
301     void ParseParams(napi_value *params, size_t paramsCount) override;
302     int GetAcceptFd();
303     LocalSocketOptions &GetOptionsRef();
304     int GetClientId() const;
305     void SetClientId(int clientId);
306 
307 private:
308     bool GetData(napi_value sendOptions);
309     LocalSocketOptions options_;
310     int clientId_ = 0;
311 };
312 
313 class LocalSocketServerCloseContext final : public LocalSocketServerBaseContext {
314 public:
LocalSocketServerCloseContext(napi_env env,EventManager * manager)315     LocalSocketServerCloseContext(napi_env env, EventManager *manager) : LocalSocketServerBaseContext(env, manager) {}
316     void ParseParams(napi_value *params, size_t paramsCount) override;
317     int GetClientId() const;
318     void SetClientId(int clientId);
319 
320 private:
321     int clientId_ = 0;
322 };
323 } // namespace OHOS::NetStack::Socket
324 #endif
325