1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H
18 
#include <any>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <vector>
28 
29 #include "extra_options_base.h"
30 #include "net_address.h"
31 #include "socket_error.h"
32 #include "socket_remote_info.h"
33 #include "socket_state_base.h"
34 #include "tcp_connect_options.h"
35 #include "tcp_extra_options.h"
36 #include "tcp_send_options.h"
37 #include "tls.h"
38 #include "tls_certificate.h"
39 #include "tls_configuration.h"
40 #include "tls_context.h"
41 #include "tls_key.h"
42 
43 namespace OHOS {
44 namespace NetStack {
45 namespace TlsSocket {
46 
47 using BindCallback = std::function<void(int32_t errorNumber)>;
48 using ConnectCallback = std::function<void(int32_t errorNumber)>;
49 using SendCallback = std::function<void(int32_t errorNumber)>;
50 using CloseCallback = std::function<void(int32_t errorNumber)>;
51 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
52 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
53 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>;
54 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>;
55 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
56 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
57 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>;
58 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>;
59 using GetSignatureAlgorithmsCallback =
60     std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>;
61 
62 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
63 using OnConnectCallback = std::function<void(void)>;
64 using OnCloseCallback = std::function<void(void)>;
65 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>;
66 
67 using CheckServerIdentity =
68     std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>;
69 
// ALPN protocol identifiers, spelled exactly as registered in the IANA
// "TLS Application-Layer Protocol Negotiation (ALPN) Protocol IDs" registry (RFC 7301).
// Peers compare these identifiers as opaque byte strings, so a non-registered spelling
// such as "http1.1" would never be selected by a server.
constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http/1.1";
constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2";

// Maximum error-string length.
// NOTE(review): presumably the size of buffers used to format error messages — confirm at use sites.
constexpr size_t MAX_ERR_LEN = 1024;
74 
75 /**
76  * Parameters required during communication
77  */
78 class TLSSecureOptions {
79 public:
80     TLSSecureOptions() = default;
81     ~TLSSecureOptions() = default;
82 
83     TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions);
84     TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions);
85     /**
86      * Set root CA Chain to verify the server cert
87      * @param caChain root certificate chain used to validate server certificates
88      */
89     void SetCaChain(const std::vector<std::string> &caChain);
90 
91     /**
92      * Set digital certificate for server verification
93      * @param cert digital certificate sent to the server to verify validity
94      */
95     void SetCert(const std::string &cert);
96 
97     /**
98      * Set key to decrypt server data
99      * @param keyChain key used to decrypt server data
100      */
101     void SetKey(const SecureData &key);
102 
103     /**
104      * Set the password to read the private key
105      * @param keyPass read the password of the private key
106      */
107     void SetKeyPass(const SecureData &keyPass);
108 
109     /**
110      * Set the protocol used in communication
111      * @param protocolChain protocol version number used
112      */
113     void SetProtocolChain(const std::vector<std::string> &protocolChain);
114 
115     /**
116      * Whether the peer cipher suite is preferred for communication
117      * @param useRemoteCipherPrefer whether the peer cipher suite is preferred
118      */
119     void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer);
120 
121     /**
122      * Encryption algorithm used in communication
123      * @param signatureAlgorithms encryption algorithm e.g: rsa
124      */
125     void SetSignatureAlgorithms(const std::string &signatureAlgorithms);
126 
127     /**
128      * Crypto suite used in communication
129      * @param cipherSuite cipher suite e.g:AES256-SHA256
130      */
131     void SetCipherSuite(const std::string &cipherSuite);
132 
133     /**
134      * Set a revoked certificate
135      * @param crlChain certificate Revocation List
136      */
137     void SetCrlChain(const std::vector<std::string> &crlChain);
138 
139     /**
140      * Get root CA Chain to verify the server cert
141      * @return root CA chain
142      */
143     [[nodiscard]] const std::vector<std::string> &GetCaChain() const;
144 
145     /**
146      * Obtain a certificate to send to the server for checking
147      * @return digital certificate obtained
148      */
149     [[nodiscard]] const std::string &GetCert() const;
150 
151     /**
152      * Obtain the private key in the communication process
153      * @return private key during communication
154      */
155     [[nodiscard]] const SecureData &GetKey() const;
156 
157     /**
158      * Get the password to read the private key
159      * @return read the password of the private key
160      */
161     [[nodiscard]] const SecureData &GetKeyPass() const;
162 
163     /**
164      * Get the protocol of the communication process
165      * @return protocol of communication process
166      */
167     [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const;
168 
169     /**
170      * Is the remote cipher suite being used for communication
171      * @return is use Remote Cipher Prefer
172      */
173     [[nodiscard]] bool UseRemoteCipherPrefer() const;
174 
175     /**
176      * Obtain the encryption algorithm used in the communication process
177      * @return encryption algorithm used in communication
178      */
179     [[nodiscard]] const std::string &GetSignatureAlgorithms() const;
180 
181     /**
182      * Obtain the cipher suite used in communication
183      * @return crypto suite used in communication
184      */
185     [[nodiscard]] const std::string &GetCipherSuite() const;
186 
187     /**
188      * Get revoked certificate chain
189      * @return revoked certificate chain
190      */
191     [[nodiscard]] const std::vector<std::string> &GetCrlChain() const;
192 
193     void SetVerifyMode(VerifyMode verifyMode);
194 
195     [[nodiscard]] VerifyMode GetVerifyMode() const;
196 
197 private:
198     std::vector<std::string> caChain_;
199     std::string cert_;
200     SecureData key_;
201     SecureData keyPass_;
202     std::vector<std::string> protocolChain_;
203     bool useRemoteCipherPrefer_ = false;
204     std::string signatureAlgorithms_;
205     std::string cipherSuite_;
206     std::vector<std::string> crlChain_;
207     VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE;
208 };
209 
210 /**
211  * Some options required during tls connection
212  */
213 class TLSConnectOptions {
214 public:
215     friend class TLSSocketExec;
216     /**
217      * Communication parameters required for connection establishment
218      * @param address communication parameters during connection
219      */
220     void SetNetAddress(const Socket::NetAddress &address);
221 
222     /**
223      * Parameters required during communication
224      * @param tlsSecureOptions certificate and other relevant parameters
225      */
226     void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions);
227 
228     /**
229      * Set the callback function to check the validity of the server
230      * @param checkServerIdentity callback function passed in by API caller
231      */
232     void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity);
233 
234     /**
235      * Set application layer protocol negotiation
236      * @param alpnProtocols application layer protocol negotiation
237      */
238     void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
239 
240     /**
241      * Set whether to skip remote validation
242      * @param skipRemoteValidation flag to choose whether to skip validation
243      */
244     void SetSkipRemoteValidation(bool skipRemoteValidation);
245 
246     /**
247      * Obtain the network address of the communication process
248      * @return network address
249      */
250     [[nodiscard]] Socket::NetAddress GetNetAddress() const;
251 
252     /**
253      * Obtain the parameters required in the communication process
254      * @return certificate and other relevant parameters
255      */
256     [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const;
257 
258     /**
259      * Get the check server ID callback function passed in by the API caller
260      * @return check the server identity callback function
261      */
262     [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const;
263 
264     /**
265      * Obtain the application layer protocol negotiation in the communication process
266      * @return application layer protocol negotiation
267      */
268     [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const;
269 
270     /**
271      * Get the choice of whether to skip remote validaion
272      * @return skipRemoteValidaion result
273      */
274     [[nodiscard]] bool GetSkipRemoteValidation() const;
275 
276     void SetHostName(const std::string &hostName);
277     [[nodiscard]] std::string GetHostName() const;
278 
279 private:
280     Socket::NetAddress address_;
281     TLSSecureOptions tlsSecureOptions_;
282     CheckServerIdentity checkServerIdentity_;
283     std::vector<std::string> alpnProtocols_;
284     bool skipRemoteValidation_ = false;
285     std::string hostName_;
286 };
287 
288 /**
289  * TLS socket interface class
290  */
291 class TLSSocket {
292 public:
293     TLSSocket(const TLSSocket &) = delete;
294     TLSSocket(TLSSocket &&) = delete;
295 
296     TLSSocket &operator=(const TLSSocket &) = delete;
297     TLSSocket &operator=(TLSSocket &&) = delete;
298 
299     TLSSocket() = default;
300     ~TLSSocket() = default;
301 
TLSSocket(int sockFd)302     explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {}
303 
304     /**
305      * Create a socket and bind to the address specified by address
306      * @param address ip address
307      * @param callback callback to the caller if bind ok or not
308      */
309     void Bind(Socket::NetAddress &address, const BindCallback &callback);
310 
311     /**
312      * Establish a secure connection based on the created socket
313      * @param tlsConnectOptions some options required during tls connection
314      * @param callback callback to the caller if connect ok or not
315      */
316     void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback);
317 
318     /**
319      * Send data based on the created socket
320      * @param tcpSendOptions  some options required during tcp data transmission
321      * @param callback callback to the caller if send ok or not
322      */
323     void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback);
324 
325     /**
326      * Disconnect by releasing the socket when communicating
327      * @param callback callback to the caller
328      */
329     void Close(const CloseCallback &callback);
330 
331     /**
332      * Get the peer network address
333      * @param callback callback to the caller
334      */
335     void GetRemoteAddress(const GetRemoteAddressCallback &callback);
336 
337     /**
338      * Get the status of the current socket
339      * @param callback callback to the caller
340      */
341     void GetState(const GetStateCallback &callback);
342 
343     /**
344      * Gets or sets the options associated with the current socket
345      * @param tcpExtraOptions options associated with the current socket
346      * @param callback callback to the caller
347      */
348     void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback);
349 
350     /**
351      *  Get a local digital certificate
352      * @param callback callback to the caller
353      */
354     void GetCertificate(const GetCertificateCallback &callback);
355 
356     /**
357      * Get the peer digital certificate
358      * @param needChain need chain
359      * @param callback callback to the caller
360      */
361     void GetRemoteCertificate(const GetRemoteCertificateCallback &callback);
362 
363     /**
364      * Obtain the protocol used in communication
365      * @param callback callback to the caller
366      */
367     void GetProtocol(const GetProtocolCallback &callback);
368 
369     /**
370      * Obtain the cipher suite used in communication
371      * @param callback callback to the caller
372      */
373     void GetCipherSuite(const GetCipherSuiteCallback &callback);
374 
375     /**
376      * Obtain the encryption algorithm used in the communication process
377      * @param callback callback to the caller
378      */
379     void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback);
380 
381     /**
382      * Register a callback which is called when message is received
383      * @param onMessageCallback callback which is called when message is received
384      */
385     void OnMessage(const OnMessageCallback &onMessageCallback);
386 
387     /**
388      * Register the callback that is called when the connection is established
389      * @param onConnectCallback callback invoked when connection is established
390      */
391     void OnConnect(const OnConnectCallback &onConnectCallback);
392 
393     /**
394      * Register the callback that is called when the connection is disconnected
395      * @param onCloseCallback callback invoked when disconnected
396      */
397     void OnClose(const OnCloseCallback &onCloseCallback);
398 
399     /**
400      * Register the callback that is called when an error occurs
401      * @param onErrorCallback callback invoked when an error occurs
402      */
403     void OnError(const OnErrorCallback &onErrorCallback);
404 
405     /**
406      * Unregister the callback which is called when message is received
407      */
408     void OffMessage();
409 
410     /**
411      * Off Connect
412      */
413     void OffConnect();
414 
415     /**
416      * Off Close
417      */
418     void OffClose();
419 
420     /**
421      * Off Error
422      */
423     void OffError();
424 
425     /**
426      * Get the socket file description of the server
427      */
428     int GetSocketFd();
429 
430     /**
431      * Set the current socket file description address of the server
432      */
433     void SetLocalAddress(const Socket::NetAddress &address);
434 
435     /**
436      * Get the current socket file description address of the server
437      */
438     Socket::NetAddress GetLocalAddress();
439 
440     bool GetCloseState();
441 
442     void SetCloseState(bool flag);
443 
444     std::mutex &GetCloseLock();
445 private:
446     class TLSSocketInternal final {
447     public:
448         TLSSocketInternal() = default;
449         ~TLSSocketInternal() = default;
450 
451         /**
452          * Establish an encrypted connection on the specified socket
453          * @param sock socket for establishing encrypted connection
454          * @param options some options required during tls connection
455          * @param isExtSock socket fd is originated from external source when constructing tls socket
456          * @return whether the encrypted connection is successfully established
457          */
458         bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock);
459 
460         /**
461          * Set the configuration items for establishing encrypted connections
462          * @param config configuration item when establishing encrypted connection
463          */
464         void SetTlsConfiguration(const TLSConnectOptions &config);
465 
466         /**
467          * Send data through an established encrypted connection
468          * @param data data sent over an established encrypted connection
469          * @return whether the data is successfully sent to the server
470          */
471         bool Send(const std::string &data);
472 
473         /**
474          * Receive the data sent by the server through the established encrypted connection
475          * @param buffer receive the data sent by the server
476          * @param maxBufferSize the size of the data received from the server
477          * @return whether the data sent by the server is successfully received
478          */
479         int Recv(char *buffer, int maxBufferSize);
480 
481         /**
482          * Disconnect encrypted connection
483          * @return whether the encrypted connection was successfully disconnected
484          */
485         bool Close();
486 
487         /**
488          * Set the application layer negotiation protocol in the encrypted communication process
489          * @param alpnProtocols application layer negotiation protocol
490          * @return set whether the application layer negotiation protocol is successful during encrypted communication
491          */
492         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
493 
494         /**
495          * Storage of server communication related network information
496          * @param remoteInfo communication related network information
497          */
498         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
499 
500         /**
501          * convert the code to ssl error code
502          * @return the value for ssl error code.
503          */
504         int ConvertSSLError(void);
505 
506         /**
507          * Get configuration options for encrypted communication process
508          * @return configuration options for encrypted communication processes
509          */
510         [[nodiscard]] TLSConfiguration GetTlsConfiguration() const;
511 
512         /**
513          * Obtain the cipher suite during encrypted communication
514          * @return crypto suite used in encrypted communication
515          */
516         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
517 
518         /**
519          * Obtain the peer certificate used in encrypted communication
520          * @return peer certificate used in encrypted communication
521          */
522         [[nodiscard]] std::string GetRemoteCertificate() const;
523 
524         /**
525          * Obtain the peer certificate used in encrypted communication
526          * @return peer certificate serialization data used in encrypted communication
527          */
528         [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const;
529 
530         /**
531          * Obtain the certificate used in encrypted communication
532          * @return certificate serialization data used in encrypted communication
533          */
534         [[nodiscard]] const X509CertRawData &GetCertificate() const;
535 
536         /**
537          * Get the encryption algorithm used in encrypted communication
538          * @return encryption algorithm used in encrypted communication
539          */
540         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
541 
542         /**
543          * Obtain the communication protocol used in encrypted communication
544          * @return communication protocol used in encrypted communication
545          */
546         [[nodiscard]] std::string GetProtocol() const;
547 
548         /**
549          * Set the information about the shared signature algorithm supported by peers during encrypted communication
550          * @return information about peer supported shared signature algorithms
551          */
552         [[nodiscard]] bool SetSharedSigals();
553 
554         /**
555          * Obtain the ssl used in encrypted communication
556          * @return SSL used in encrypted communication
557          */
558         [[nodiscard]] ssl_st *GetSSL();
559 
560     private:
561         bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd);
562         bool StartTlsConnected(const TLSConnectOptions &options);
563         bool CreatTlsContext();
564         bool StartShakingHands(const TLSConnectOptions &options);
565         bool GetRemoteCertificateFromPeer();
566         bool SetRemoteCertRawData();
567         bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize);
568         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
569         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
570                                              const X509 *x509Certificates);
571 
572     private:
573         std::mutex mutexForSsl_;
574         ssl_st *ssl_ = nullptr;
575         X509 *peerX509_ = nullptr;
576         uint16_t port_ = 0;
577         sa_family_t family_ = 0;
578         int32_t socketDescriptor_ = 0;
579 
580         TLSContext tlsContext_;
581         TLSConfiguration configuration_;
582         Socket::NetAddress address_;
583         X509CertRawData remoteRawData_;
584 
585         std::string hostName_;
586         std::string remoteCert_;
587 
588         std::vector<std::string> signatureAlgorithms_;
589         std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr;
590     };
591 
592 private:
593     TLSSocketInternal tlsSocketInternal_;
594 
595     static std::string MakeAddressString(sockaddr *addr);
596 
597     static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
598                         socklen_t *len);
599 
600     void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo);
601     void CallOnConnectCallback();
602     void CallOnCloseCallback();
603     void CallOnErrorCallback(int32_t err, const std::string &errString);
604 
605     void CallBindCallback(int32_t err, BindCallback callback);
606     void CallConnectCallback(int32_t err, ConnectCallback callback);
607     void CallSendCallback(int32_t err, SendCallback callback);
608     void CallCloseCallback(int32_t err, CloseCallback callback);
609     void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
610                                       GetRemoteAddressCallback callback);
611     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback);
612     void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback);
613     void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback);
614     void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
615                                           GetRemoteCertificateCallback callback);
616     void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback);
617     void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
618                                     GetCipherSuiteCallback callback);
619     void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
620                                             GetSignatureAlgorithmsCallback callback);
621 
622     int ReadMessage();
623     void StartReadMessage();
624 
625     void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback);
626     void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback);
627 
628     [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const;
629     [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const;
630 
631     void MakeIpSocket(sa_family_t family);
632 
633     template<class T>
DealCallback(int32_t err,T & callback)634     void DealCallback(int32_t err, T &callback)
635     {
636         T func = nullptr;
637         {
638             std::lock_guard<std::mutex> lock(mutex_);
639             if (callback) {
640                 func = callback;
641             }
642         }
643 
644         if (func) {
645             func(err);
646         }
647     }
648 
649 private:
650     static constexpr const size_t MAX_ERROR_LEN = 128;
651     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
652 
653     OnMessageCallback onMessageCallback_;
654     OnConnectCallback onConnectCallback_;
655     OnCloseCallback onCloseCallback_;
656     OnErrorCallback onErrorCallback_;
657 
658     std::mutex mutex_;
659     std::mutex recvMutex_;
660     std::mutex cvMutex_;
661     bool isRunning_ = false;
662     bool isRunOver_ = true;
663     std::condition_variable cvSslFree_;
664     int sockFd_ = -1;
665     bool isExtSock_ = false;
666     Socket::NetAddress localAddress_;
667     bool isClosed = false;
668     std::mutex mutexForClose_;
669 };
670 } // namespace TlsSocket
671 } // namespace NetStack
672 } // namespace OHOS
673 
674 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H
675