1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
18 
#include "event_manager.h"
#include "extra_options_base.h"
#include "net_address.h"
#include "socket_error.h"
#include "socket_remote_info.h"
#include "socket_state_base.h"
#include "tcp_connect_options.h"
#include "tcp_extra_options.h"
#include "tcp_send_options.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_context_server.h"
#include "tls_key.h"
#include "tls_socket.h"
#include <any>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <poll.h>
#include <string>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <utility>
#include <vector>
44 
45 namespace OHOS {
46 namespace NetStack {
47 namespace TlsSocketServer {
/**
 * Maximum number of client connections the server polls at once. The server's poll set is
 * sized USER_LIMIT + 1, presumably leaving one extra slot for the listening socket.
 * Declared inline so every translation unit shares a single definition of this header constant.
 */
inline constexpr int USER_LIMIT = 10;
49 struct CacheInfo {
50     std::string data;
51     Socket::SocketRemoteInfo remoteInfo;
52 };
53 using OnMessageCallback =
54     std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
55 using OnCloseCallback = std::function<void(const int &socketFd)>;
56 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>;
57 using ListenCallback = std::function<void(int32_t errorNumber)>;
/**
 * Parameters for TLSSocketServer::Send: the socket of the target connection and the
 * payload to transmit over it.
 */
class TLSServerSendOptions {
public:
    /**
     * Set the socket descriptor of the connection the data is sent on
     * @param socketFd communication descriptor
     */
    void SetSocket(const int &socketFd);

    /**
     * Set the data to send
     * @param data payload to send
     */
    void SetSendData(const std::string &data);

    /**
     * Get the socket descriptor of the target connection
     * @return the communication descriptor
     */
    [[nodiscard]] const int &GetSocket() const;

    /**
     * Get the data to send
     * @return payload to send
     */
    [[nodiscard]] const std::string &GetSendData() const;

private:
    // -1 marks "no socket set yet"; previously left uninitialized, so reading it before
    // SetSocket() was undefined behavior.
    int socketFd_ = -1;
    std::string data_;
};
88 
89 class TLSSocketServer {
90 public:
91     TLSSocketServer(const TLSSocketServer &) = delete;
92     TLSSocketServer(TLSSocketServer &&) = delete;
93 
94     TLSSocketServer &operator=(const TLSSocketServer &) = delete;
95     TLSSocketServer &operator=(TLSSocketServer &&) = delete;
96 
97     TLSSocketServer() = default;
98     ~TLSSocketServer();
99 
100     /**
101      * Create sockets, bind and listen waiting for clients to connect
102      * @param tlsListenOptions Bind the listening connection configuration
103      * @param callback callback to the caller if bind ok or not
104      */
105     void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback);
106 
107     /**
108      * Send data through an established encrypted connection
109      * @param data data sent over an established encrypted connection
110      * @return whether the data is successfully sent to the server
111      */
112     bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback);
113 
114     /**
115      * Disconnect by releasing the socket when communicating
116      * @param socketFd The socket ID of the client
117      * @param callback callback to the caller
118      */
119     void Close(const int socketFd, const TlsSocket::CloseCallback &callback);
120 
121     /**
122      * Disconnect by releasing the socket when communicating
123      * @param callback callback to the caller
124      */
125     void Stop(const TlsSocket::CloseCallback &callback);
126 
127     /**
128      * Get the peer network address
129      * @param socketFd The socket ID of the client
130      * @param callback callback to the caller
131      */
132     void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback);
133 
134     /**
135      * Get the peer network address
136      * @param socketFd The socket ID of the client
137      * @param callback callback to the caller
138      */
139     void GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback);
140 
141     /**
142      * Get the status of the current socket
143      * @param callback callback to the caller
144      */
145     void GetState(const TlsSocket::GetStateCallback &callback);
146 
147     /**
148      * Gets or sets the options associated with the current socket
149      * @param tcpExtraOptions options associated with the current socket
150      * @param callback callback to the caller
151      */
152     bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
153                          const TlsSocket::SetExtraOptionsCallback &callback);
154 
155     /**
156      *  Get a local digital certificate
157      * @param callback callback to the caller
158      */
159     void GetCertificate(const TlsSocket::GetCertificateCallback &callback);
160 
161     /**
162      * Get the peer digital certificate
163      * @param socketFd The socket ID of the client
164      * @param needChain need chain
165      * @param callback callback to the caller
166      */
167     void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback);
168 
169     /**
170      * Obtain the protocol used in communication
171      * @param callback callback to the caller
172      */
173     void GetProtocol(const TlsSocket::GetProtocolCallback &callback);
174 
175     /**
176      * Obtain the cipher suite used in communication
177      * @param socketFd The socket ID of the client
178      * @param callback callback to the caller
179      */
180     void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback);
181 
182     /**
183      * Obtain the encryption algorithm used in the communication process
184      * @param socketFd The socket ID of the client
185      * @param callback callback to the caller
186      */
187     void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback);
188 
189     /**
190      * Register the callback that is called when the connection is disconnected
191      * @param onCloseCallback callback invoked when disconnected
192      */
193 
194     /**
195      * Register the callback that is called when the connection is established
196      * @param onConnectCallback callback invoked when connection is established
197      */
198     void OnConnect(const OnConnectCallback &onConnectCallback);
199 
200     /**
201      * Register the callback that is called when an error occurs
202      * @param onErrorCallback callback invoked when an error occurs
203      */
204     void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
205 
206     /**
207      * Off Connect
208      */
209     void OffConnect();
210 
211     /**
212      * Off Error
213      */
214     void OffError();
215 
216     /**
217      * Get the socket file description of the server
218      */
219     int GetListenSocketFd();
220 
221     /**
222      * Set the current socket file description address of the server
223      */
224     void SetLocalAddress(const Socket::NetAddress &address);
225 
226     /**
227      * Get the current socket file description address of the server
228      */
229     Socket::NetAddress GetLocalAddress();
230 
231 public:
232     class Connection : public std::enable_shared_from_this<Connection> {
233     public:
234         ~Connection();
235         /**
236          * Establish an encrypted accept on the specified socket
237          * @param sock socket for establishing encrypted connection
238          * @param options some options required during tls accept
239          * @return whether the encrypted accept is successfully established
240          */
241         bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options);
242 
243         /**
244          * Set the configuration items for establishing encrypted connections
245          * @param config configuration item when establishing encrypted connection
246          */
247         void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
248 
249         /**
250          * Set address information
251          */
252         void SetAddress(const Socket::NetAddress address);
253 
254         /**
255          * Set local address information
256          */
257         void SetLocalAddress(const Socket::NetAddress address);
258 
259         /**
260          * Send data through an established encrypted connection
261          * @param data data sent over an established encrypted connection
262          * @return whether the data is successfully sent to the server
263          */
264         bool Send(const std::string &data);
265 
266         /**
267          * Receive the data sent by the server through the established encrypted connection
268          * @param buffer receive the data sent by the server
269          * @param maxBufferSize the size of the data received from the server
270          * @return whether the data sent by the server is successfully received
271          */
272         int Recv(char *buffer, int maxBufferSize);
273 
274         /**
275          * Disconnect encrypted connection
276          * @return whether the encrypted connection was successfully disconnected
277          */
278         bool Close();
279 
280         /**
281          * Set the application layer negotiation protocol in the encrypted communication process
282          * @param alpnProtocols application layer negotiation protocol
283          * @return set whether the application layer negotiation protocol is successful during encrypted communication
284          */
285         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
286 
287         /**
288          * Storage of server communication related network information
289          * @param remoteInfo communication related network information
290          */
291         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
292 
293         /**
294          * Get configuration options for encrypted communication process
295          * @return configuration options for encrypted communication processes
296          */
297         [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const;
298 
299         /**
300          * Obtain the cipher suite during encrypted communication
301          * @return crypto suite used in encrypted communication
302          */
303         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
304 
305         /**
306          * Obtain the peer certificate used in encrypted communication
307          * @return peer certificate used in encrypted communication
308          */
309         [[nodiscard]] std::string GetRemoteCertificate() const;
310 
311         /**
312          * Obtain the peer certificate used in encrypted communication
313          * @return peer certificate serialization data used in encrypted communication
314          */
315         [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const;
316 
317         /**
318          * Obtain the certificate used in encrypted communication
319          * @return certificate serialization data used in encrypted communication
320          */
321         [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const;
322 
323         /**
324          * Get the encryption algorithm used in encrypted communication
325          * @return encryption algorithm used in encrypted communication
326          */
327         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
328 
329         /**
330          * Obtain the communication protocol used in encrypted communication
331          * @return communication protocol used in encrypted communication
332          */
333         [[nodiscard]] std::string GetProtocol() const;
334 
335         /**
336          * Set the information about the shared signature algorithm supported by peers during encrypted communication
337          * @return information about peer supported shared signature algorithms
338          */
339         [[nodiscard]] bool SetSharedSigals();
340 
341         /**
342          * Obtain the ssl used in encrypted communication
343          * @return SSL used in encrypted communication
344          */
345         [[nodiscard]] ssl_st *GetSSL() const;
346 
347         /**
348          * Get address information
349          * @return Returns the address information of the remote client
350          */
351         [[nodiscard]] Socket::NetAddress GetAddress() const;
352 
353         /**
354          * Get local address information
355          * @return Returns the address information of the local accept connect
356          */
357         [[nodiscard]] Socket::NetAddress GetLocalAddress() const;
358 
359         /**
360          * Get address information
361          * @return Returns the address information of the remote client
362          */
363         [[nodiscard]] int GetSocketFd() const;
364 
365         /**
366          * Get EventManager information
367          * @return Returns the address information of the remote client
368          */
369         [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const;
370 
371         void OnMessage(const OnMessageCallback &onMessageCallback);
372         /**
373          * Unregister the callback which is called when message is received
374          */
375         void OffMessage();
376 
377         void CallOnMessageCallback(int32_t socketFd, const std::string &data,
378                                    const Socket::SocketRemoteInfo &remoteInfo);
379 
380         void SetEventManager(std::shared_ptr<EventManager> eventManager);
381 
382         void SetClientID(int32_t clientID);
383 
384         [[nodiscard]] int GetClientID();
385 
386         void CallOnCloseCallback(const int32_t socketFd);
387         void OnClose(const OnCloseCallback &onCloseCallback);
388         OnCloseCallback onCloseCallback_;
389 
390         /**
391          * Off Close
392          */
393         void OffClose();
394 
395         /**
396          * Register the callback that is called when an error occurs
397          * @param onErrorCallback callback invoked when an error occurs
398          */
399         void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
400         /**
401          * Off Error
402          */
403         void OffError();
404 
405         void CallOnErrorCallback(int32_t err, const std::string &errString);
406 
407         class DataCache {
408         public:
Get()409             CacheInfo Get()
410             {
411                 std::lock_guard l(mutex_);
412                 CacheInfo cache = cacheDeque_.front();
413                 cacheDeque_.pop_front();
414                 return cache;
415             }
Set(const CacheInfo & data)416             void Set(const CacheInfo &data)
417             {
418                 std::lock_guard l(mutex_);
419                 cacheDeque_.emplace_back(data);
420             }
IsEmpty()421             bool IsEmpty()
422             {
423                 std::lock_guard l(mutex_);
424                 return cacheDeque_.empty();
425             }
426 
427         private:
428             std::deque<CacheInfo> cacheDeque_;
429             std::mutex mutex_;
430         };
431 
432         TlsSocket::OnErrorCallback onErrorCallback_;
433 
434     private:
435         bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options);
436         bool CreatTlsContext();
437         bool StartShakingHands(const TlsSocket::TLSConnectOptions &options);
438         bool GetRemoteCertificateFromPeer();
439         bool SetRemoteCertRawData();
440         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
441         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
442                                              const X509 *x509Certificates);
443 
444     private:
445         ssl_st *ssl_ = nullptr;
446         X509 *peerX509_ = nullptr;
447         int32_t socketFd_ = 0;
448 
449         TlsSocket::TLSContextServer tlsContext_;
450         TlsSocket::TLSConfiguration connectionConfiguration_;
451         Socket::NetAddress address_;
452         Socket::NetAddress localAddress_;
453         TlsSocket::X509CertRawData remoteRawData_;
454 
455         std::string hostName_;
456         std::string remoteCert_;
457         std::string keyPass_;
458 
459         std::vector<std::string> signatureAlgorithms_;
460         std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr;
461 
462         std::shared_ptr<EventManager> eventManager_ = nullptr;
463         int32_t clientID_ = 0;
464         OnMessageCallback onMessageCallback_;
465         std::shared_ptr<DataCache> dataCache_ = std::make_shared<DataCache>();
466     };
467 
468 private:
469     void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
470     int RecvRemoteInfo(int socketFd, int index);
471     void RemoveConnect(int socketFd);
472     void AddConnect(int socketFd, std::shared_ptr<Connection> connection);
473     void CallListenCallback(int32_t err, ListenCallback callback);
474     void CallOnErrorCallback(int32_t err, const std::string &errString);
475 
476     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback);
477     void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager);
478     void CallSendCallback(int32_t err, TlsSocket::SendCallback callback);
479     bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback);
480     void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback);
481     void MakeIpSocket(sa_family_t family);
482     void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
483                  socklen_t *len);
484     static constexpr const size_t MAX_ERROR_LEN = 128;
485     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
486 
487     void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions);
488 
489 private:
490     std::mutex mutex_;
491     std::mutex connectMutex_;
492     int listenSocketFd_ = -1;
493     Socket::NetAddress address_;
494     Socket::NetAddress localAddress_;
495 
496     std::map<int, std::shared_ptr<Connection>> clientIdConnections_;
497     TlsSocket::TLSConfiguration TLSServerConfiguration_;
498 
499     OnConnectCallback onConnectCallback_;
500     TlsSocket::OnErrorCallback onErrorCallback_;
501 
502     bool GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress);
503     void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId);
504     void DropFdFromPollList(int &fd_index);
505     void InitPollList(int &listendFd);
506 
507     struct pollfd fds_[USER_LIMIT + 1];
508 
509     bool isRunning_;
510 
511 public:
512     std::shared_ptr<Connection> GetConnectionByClientID(int clientid);
513     int GetConnectionClientCount();
514 
515     std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager);
516     void CloseConnectionByEventManager(EventManager *eventManager);
517     void DeleteConnectionByEventManager(EventManager *eventManager);
518     void SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID,
519                                        int connectFD, std::shared_ptr<Connection> &connection);
520 };
521 } // namespace TlsSocketServer
522 } // namespace NetStack
523 } // namespace OHOS
524 
525 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
526