1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket.h"
17 
18 #include <chrono>
19 #include <memory>
20 #include <numeric>
21 #include <regex>
22 #include <securec.h>
23 #include <set>
24 #include <thread>
25 #include <poll.h>
26 
27 #include <netinet/tcp.h>
28 #include <openssl/err.h>
29 #include <openssl/ssl.h>
30 
31 #include "base_context.h"
32 #include "netstack_common_utils.h"
33 #include "netstack_log.h"
34 #include "tls.h"
35 #include "socket_exec_common.h"
36 
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
// ---- Numeric tuning values and return-code sentinels ----
// READ_TIMEOUT_MS is the poll() timeout of the reader loop; SSL_WANT_READ_RETURN is compared
// against the internal Recv() result; OFFSET is the length of the ", " alt-name separator.
constexpr int READ_TIMEOUT_MS{500};
constexpr int REMOTE_CERT_LEN{8192};
constexpr int COMMON_NAME_BUF_SIZE{256};
constexpr int BUF_SIZE{2048};
constexpr int SSL_RET_CODE{0};
constexpr int SSL_ERROR_RETURN{-1};
constexpr int SSL_WANT_READ_RETURN{-2};
constexpr int OFFSET{2};
constexpr int DEFAULT_BUFFER_SIZE{8192};
constexpr int DEFAULT_POLL_TIMEOUT_MS{500};
constexpr int SEND_RETRY_TIMES{5};
constexpr int SEND_POLL_TIMEOUT_MS{1000};
constexpr int MAX_RECV_BUFFER_SIZE{1024 * 16};

// ---- Fixed strings: separators, fallback texts and certificate field prefixes ----
constexpr const char *SPLIT_ALT_NAMES{","};
constexpr const char *SPLIT_HOST_NAME{"."};
constexpr const char *PROTOCOL_UNKNOW{"UNKNOW_PROTOCOL"};
constexpr const char *UNKNOW_REASON{"Unknown reason"};
constexpr const char *IP{"IP: "};
constexpr const char *HOST_NAME{"hostname: "};
constexpr const char *DNS{"DNS:"};
constexpr const char *IP_ADDRESS{"IP Address:"};

// ---- Signature-algorithm name fragments ----
constexpr const char *SIGN_NID_RSA{"RSA+"};
constexpr const char *SIGN_NID_RSA_PSS{"RSA-PSS+"};
constexpr const char *SIGN_NID_DSA{"DSA+"};
constexpr const char *SIGN_NID_ECDSA{"ECDSA+"};
constexpr const char *SIGN_NID_ED{"Ed25519+"};
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT{"Ed448+"};
constexpr const char *SIGN_NID_UNDEF_ADD{"UNDEF+"};
constexpr const char *SIGN_NID_UNDEF{"UNDEF"};
constexpr const char *OPERATOR_PLUS_SIGN{"+"};

// Name given to the TLS client reader thread (pthread_setname_np).
static constexpr const char *TLS_SOCKET_CLIENT_READ{"OS_NET_TSCliRD"};

// NOTE(review): this pattern keeps the '/' ... '/' delimiters of a JavaScript regex literal.
// std::regex matches those slashes literally, so it cannot match text that starts with '"'.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};

// Dotted-quad IPv4 address (each octet 0-255), used with regex_match.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76 
/**
 * Process-wide map from a key to a set of strings, guarded by a mutex.
 * Non-copyable singleton; obtain it through GetInstance().
 */
class CaCertCache {
public:
    // Returns the single instance (function-local static, initialised once).
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    // Returns a copy of the set stored under `key`, or an empty set if the key is absent.
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto found = map_.find(key);
        return found == map_.end() ? std::set<std::string>{} : found->second;
    }

    // Adds `val` to the set stored under `key`, creating the set on first use.
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        auto &entries = map_[key];
        entries.insert(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;
    CaCertCache(const CaCertCache &) = delete;
    CaCertCache &operator=(const CaCertCache &) = delete;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110 
ConvertErrno()111 int ConvertErrno()
112 {
113     return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115 
MakeErrnoString()116 std::string MakeErrnoString()
117 {
118     return strerror(errno);
119 }
120 
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123     char err[MAX_ERR_LEN] = {0};
124     ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125     return err;
126 }
127 
SplitEscapedAltNames(std::string & altNames)128 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
129 {
130     std::vector<std::string> result;
131     std::string currentToken;
132     size_t offset = 0;
133     while (offset != altNames.length()) {
134         auto nextSep = altNames.find_first_of(", ");
135         auto nextQuote = altNames.find_first_of('\"');
136         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
137             currentToken += altNames.substr(offset, nextQuote);
138             std::regex jsonStringPattern(JSON_STRING_PATTERN);
139             std::smatch match;
140             std::string altNameSubStr = altNames.substr(nextQuote);
141             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
142             if (!ret) {
143                 return {""};
144             }
145             currentToken += result[0];
146             offset = nextQuote + result[0].length();
147         } else if (nextSep != std::string::npos) {
148             currentToken += altNames.substr(offset, nextSep);
149             result.push_back(currentToken);
150             currentToken = "";
151             offset = nextSep + OFFSET;
152         } else {
153             currentToken += altNames.substr(offset);
154             offset = altNames.length();
155         }
156     }
157     result.push_back(currentToken);
158     return result;
159 }
160 
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163     std::regex pattern(PATTERN);
164     std::smatch res;
165     return regex_match(ip, res, pattern);
166 }
167 
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173 
SeekIntersection(std::vector<std::string> & vecA,std::vector<std::string> & vecB)174 bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
175 {
176     std::vector<std::string> result;
177     set_intersection(vecA.begin(), vecA.end(), vecB.begin(), vecB.end(), inserter(result, result.begin()));
178     return !result.empty();
179 }
180 } // namespace
181 
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184     int flags = fcntl(sock, F_GETFL, 0);
185     while (flags == -1 && errno == EINTR) {
186         flags = fcntl(sock, F_GETFL, 0);
187     }
188     if (flags == -1) {
189         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190         return false;
191     }
192 
193     auto newFlags = static_cast<size_t>(flags);
194     if (noneBlock) {
195         newFlags |= static_cast<size_t>(O_NONBLOCK);
196     } else {
197         newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198     }
199 
200     int ret = fcntl(sock, F_SETFL, newFlags);
201     while (ret == -1 && errno == EINTR) {
202         ret = fcntl(sock, F_SETFL, newFlags);
203     }
204     if (ret == -1) {
205         NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206         return false;
207     }
208     return true;
209 }
210 
TLSSecureOptions(const TLSSecureOptions & tlsSecureOptions)211 TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
212 {
213     *this = tlsSecureOptions;
214 }
215 
operator =(const TLSSecureOptions & tlsSecureOptions)216 TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
217 {
218     key_ = tlsSecureOptions.GetKey();
219     caChain_ = tlsSecureOptions.GetCaChain();
220     cert_ = tlsSecureOptions.GetCert();
221     protocolChain_ = tlsSecureOptions.GetProtocolChain();
222     crlChain_ = tlsSecureOptions.GetCrlChain();
223     keyPass_ = tlsSecureOptions.GetKeyPass();
224     key_ = tlsSecureOptions.GetKey();
225     signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
226     cipherSuite_ = tlsSecureOptions.GetCipherSuite();
227     useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
228     TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
229     return *this;
230 }
231 
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234     caChain_ = caChain;
235 }
236 
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239     cert_ = cert;
240 }
241 
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244     key_ = key;
245 }
246 
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249     keyPass_ = keyPass;
250 }
251 
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254     protocolChain_ = protocolChain;
255 }
256 
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259     useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261 
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264     signatureAlgorithms_ = signatureAlgorithms;
265 }
266 
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269     cipherSuite_ = cipherSuite;
270 }
271 
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274     crlChain_ = crlChain;
275 }
276 
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279     return caChain_;
280 }
281 
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284     return cert_;
285 }
286 
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289     return key_;
290 }
291 
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294     return keyPass_;
295 }
296 
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299     return protocolChain_;
300 }
301 
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304     return useRemoteCipherPrefer_;
305 }
306 
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309     return signatureAlgorithms_;
310 }
311 
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314     return cipherSuite_;
315 }
316 
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319     return crlChain_;
320 }
321 
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324     TLSVerifyMode_ = verifyMode;
325 }
326 
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329     return TLSVerifyMode_;
330 }
331 
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334     address_.SetFamilyBySaFamily(address.GetSaFamily());
335     address_.SetRawAddress(address.GetAddress());
336     address_.SetPort(address.GetPort());
337 }
338 
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341     tlsSecureOptions_ = tlsSecureOptions;
342 }
343 
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346     checkServerIdentity_ = checkServerIdentity;
347 }
348 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351     alpnProtocols_ = alpnProtocols;
352 }
353 
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356     skipRemoteValidation_ = skipRemoteValidation;
357 }
358 
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361     return address_;
362 }
363 
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366     return tlsSecureOptions_;
367 }
368 
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371     return checkServerIdentity_;
372 }
373 
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376     return alpnProtocols_;
377 }
378 
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381     return skipRemoteValidation_;
382 }
383 
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386     hostName_ = hostName;
387 }
388 
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391     return hostName_;
392 }
393 
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396     if (!addr) {
397         return {};
398     }
399     if (addr->sa_family == AF_INET) {
400         auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401         const char *str = inet_ntoa(addr4->sin_addr);
402         if (str == nullptr || strlen(str) == 0) {
403             return {};
404         }
405         return str;
406     } else if (addr->sa_family == AF_INET6) {
407         auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408         char str[INET6_ADDRSTRLEN] = {0};
409         if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410             return {};
411         }
412         return str;
413     }
414     return {};
415 }
416 
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418                         socklen_t *len)
419 {
420     if (!addr6 || !addr4 || !len) {
421         return;
422     }
423     sa_family_t family = address.GetSaFamily();
424     if (family == AF_INET) {
425         addr4->sin_family = AF_INET;
426         addr4->sin_port = htons(address.GetPort());
427         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428         *addr = reinterpret_cast<sockaddr *>(addr4);
429         *len = sizeof(sockaddr_in);
430     } else if (family == AF_INET6) {
431         addr6->sin6_family = AF_INET6;
432         addr6->sin6_port = htons(address.GetPort());
433         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434         *addr = reinterpret_cast<sockaddr *>(addr6);
435         *len = sizeof(sockaddr_in6);
436     }
437 }
438 
MakeIpSocket(sa_family_t family)439 void TLSSocket::MakeIpSocket(sa_family_t family)
440 {
441     if (family != AF_INET && family != AF_INET6) {
442         return;
443     }
444     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
445     if (sock < 0) {
446         int resErr = ConvertErrno();
447         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
448         CallOnErrorCallback(resErr, MakeErrnoString());
449         return;
450     }
451     sockFd_ = sock;
452 }
453 
ReadMessage()454 int TLSSocket::ReadMessage()
455 {
456     char buffer[MAX_RECV_BUFFER_SIZE];
457     if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
458         NETSTACK_LOGE("memset_s failed!");
459         return -1;
460     }
461     nfds_t num = 1;
462     pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
463     int ret = poll(fds, num, READ_TIMEOUT_MS);
464     if (ret < 0) {
465         if (errno == EAGAIN || errno == EINTR) {
466             return 0;
467         }
468         int resErr = ConvertErrno();
469         NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
470         CallOnErrorCallback(resErr, MakeErrnoString());
471         return ret;
472     } else if (ret == 0) {
473         NETSTACK_LOGD("tls recv poll timeout");
474         return ret;
475     }
476 
477     std::lock_guard<std::mutex> lock(recvMutex_);
478     if (!isRunning_) {
479         return -1;
480     }
481     int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
482     if (len < 0) {
483         if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
484             return 0;
485         }
486         int resErr = tlsSocketInternal_.ConvertSSLError();
487         NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s",
488                       resErr, MakeSSLErrorString(resErr).c_str());
489         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
490         return len;
491     } else if (len == 0) {
492         NETSTACK_LOGI("Message recv len 0, session is closed by peer");
493         CallOnCloseCallback();
494         return -1;
495     }
496     Socket::SocketRemoteInfo remoteInfo;
497     remoteInfo.SetSize(len);
498     tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
499     std::string bufContent(buffer, len);
500     CallOnMessageCallback(bufContent, remoteInfo);
501 
502     return ret;
503 }
504 
StartReadMessage()505 void TLSSocket::StartReadMessage()
506 {
507     std::thread thread([this]() {
508         isRunning_ = true;
509         isRunOver_ = false;
510 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
511         pthread_setname_np(TLS_SOCKET_CLIENT_READ);
512 #else
513         pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
514 #endif
515         while (isRunning_) {
516             int ret = ReadMessage();
517             if (ret < 0) {
518                 break;
519             }
520         }
521         isRunOver_ = true;
522         cvSslFree_.notify_one();
523     });
524     thread.detach();
525 }
526 
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)527 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
528 {
529     OnMessageCallback func = nullptr;
530     {
531         std::lock_guard<std::mutex> lock(mutex_);
532         if (onMessageCallback_) {
533             func = onMessageCallback_;
534         }
535     }
536 
537     if (func) {
538         func(data, remoteInfo);
539     }
540 }
541 
CallOnConnectCallback()542 void TLSSocket::CallOnConnectCallback()
543 {
544     OnConnectCallback func = nullptr;
545     {
546         std::lock_guard<std::mutex> lock(mutex_);
547         if (onConnectCallback_) {
548             func = onConnectCallback_;
549         }
550     }
551 
552     if (func) {
553         func();
554     }
555 }
556 
CallOnCloseCallback()557 void TLSSocket::CallOnCloseCallback()
558 {
559     OnCloseCallback func = nullptr;
560     {
561         std::lock_guard<std::mutex> lock(mutex_);
562         if (onCloseCallback_) {
563             func = onCloseCallback_;
564         }
565     }
566 
567     if (func) {
568         func();
569     }
570 }
571 
CallOnErrorCallback(int32_t err,const std::string & errString)572 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
573 {
574     OnErrorCallback func = nullptr;
575     {
576         std::lock_guard<std::mutex> lock(mutex_);
577         if (onErrorCallback_) {
578             func = onErrorCallback_;
579         }
580     }
581 
582     if (func) {
583         func(err, errString);
584     }
585 }
586 
CallBindCallback(int32_t err,BindCallback callback)587 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
588 {
589     DealCallback<BindCallback>(err, callback);
590 }
591 
CallConnectCallback(int32_t err,ConnectCallback callback)592 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
593 {
594     DealCallback<ConnectCallback>(err, callback);
595 }
596 
CallSendCallback(int32_t err,SendCallback callback)597 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
598 {
599     DealCallback<SendCallback>(err, callback);
600 }
601 
CallCloseCallback(int32_t err,CloseCallback callback)602 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
603 {
604     DealCallback<CloseCallback>(err, callback);
605 }
606 
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)607 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
608                                              GetRemoteAddressCallback callback)
609 {
610     GetRemoteAddressCallback func = nullptr;
611     {
612         std::lock_guard<std::mutex> lock(mutex_);
613         if (callback) {
614             func = callback;
615         }
616     }
617 
618     if (func) {
619         func(err, address);
620     }
621 }
622 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)623 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
624 {
625     GetStateCallback func = nullptr;
626     {
627         std::lock_guard<std::mutex> lock(mutex_);
628         if (callback) {
629             func = callback;
630         }
631     }
632 
633     if (func) {
634         func(err, state);
635     }
636 }
637 
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)638 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
639 {
640     DealCallback<SetExtraOptionsCallback>(err, callback);
641 }
642 
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)643 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
644 {
645     GetCertificateCallback func = nullptr;
646     {
647         std::lock_guard<std::mutex> lock(mutex_);
648         if (callback) {
649             func = callback;
650         }
651     }
652 
653     if (func) {
654         func(err, cert);
655     }
656 }
657 
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)658 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
659                                                  GetRemoteCertificateCallback callback)
660 {
661     GetRemoteCertificateCallback func = nullptr;
662     {
663         std::lock_guard<std::mutex> lock(mutex_);
664         if (callback) {
665             func = callback;
666         }
667     }
668 
669     if (func) {
670         func(err, cert);
671     }
672 }
673 
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)674 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
675 {
676     GetProtocolCallback func = nullptr;
677     {
678         std::lock_guard<std::mutex> lock(mutex_);
679         if (callback) {
680             func = callback;
681         }
682     }
683 
684     if (func) {
685         func(err, protocol);
686     }
687 }
688 
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)689 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
690                                            GetCipherSuiteCallback callback)
691 {
692     GetCipherSuiteCallback func = nullptr;
693     {
694         std::lock_guard<std::mutex> lock(mutex_);
695         if (callback) {
696             func = callback;
697         }
698     }
699 
700     if (func) {
701         func(err, suite);
702     }
703 }
704 
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)705 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
706                                                    GetSignatureAlgorithmsCallback callback)
707 {
708     GetSignatureAlgorithmsCallback func = nullptr;
709     {
710         std::lock_guard<std::mutex> lock(mutex_);
711         if (callback) {
712             func = callback;
713         }
714     }
715 
716     if (func) {
717         func(err, algorithms);
718     }
719 }
720 
Bind(Socket::NetAddress & address,const BindCallback & callback)721 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
722 {
723     static constexpr int32_t PARSE_ERROR_CODE = 401;
724     if (!CommonUtils::HasInternetPermission()) {
725         CallBindCallback(PERMISSION_DENIED_CODE, callback);
726         return;
727     }
728     if (sockFd_ >= 0) {
729         CallBindCallback(TLSSOCKET_SUCCESS, callback);
730         return;
731     }
732 
733     MakeIpSocket(address.GetSaFamily());
734     if (sockFd_ < 0) {
735         int resErr = ConvertErrno();
736         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
737         CallOnErrorCallback(resErr, MakeErrnoString());
738         CallBindCallback(resErr, callback);
739         return;
740     }
741 
742     auto temp = address.GetAddress();
743     address.SetRawAddress("");
744     address.SetAddress(temp);
745     if (address.GetAddress().empty()) {
746         CallBindCallback(PARSE_ERROR_CODE, callback);
747         return;
748     }
749 
750     sockaddr_in addr4 = {0};
751     sockaddr_in6 addr6 = {0};
752     sockaddr *addr = nullptr;
753     socklen_t len;
754     GetAddr(address, &addr4, &addr6, &addr, &len);
755     if (addr == nullptr) {
756         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
757         CallOnErrorCallback(-1, "Address Is Invalid");
758         CallBindCallback(ConvertErrno(), callback);
759         return;
760     }
761     CallBindCallback(TLSSOCKET_SUCCESS, callback);
762 }
763 
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)764 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
765                         const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
766 {
767     if (sockFd_ < 0) {
768         int resErr = ConvertErrno();
769         NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
770         CallOnErrorCallback(resErr, MakeErrnoString());
771         callback(resErr);
772         return;
773     }
774 
775     if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
776         int resErr = ConvertErrno();
777         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
778         CallOnErrorCallback(resErr, MakeErrnoString());
779         callback(resErr);
780         return;
781     }
782 
783     auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
784     if (!res) {
785         int resErr = tlsSocketInternal_.ConvertSSLError();
786         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
787         callback(resErr);
788         return;
789     }
790     if (!SetSockBlockFlag(sockFd_, true)) {
791         int resErr = ConvertErrno();
792         NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
793         CallOnErrorCallback(resErr, MakeErrnoString());
794         callback(resErr);
795         return;
796     }
797     StartReadMessage();
798     CallOnConnectCallback();
799     callback(TLSSOCKET_SUCCESS);
800 }
801 
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)802 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
803 {
804     (void)tcpSendOptions;
805 
806     auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
807     if (!res) {
808         int resErr = tlsSocketInternal_.ConvertSSLError();
809         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
810         CallSendCallback(resErr, callback);
811         return;
812     }
813     CallSendCallback(TLSSOCKET_SUCCESS, callback);
814 }
815 
Close(const CloseCallback & callback)816 void TLSSocket::Close(const CloseCallback &callback)
817 {
818     isRunning_ = false;
819     std::unique_lock<std::mutex> cvLock(cvMutex_);
820     cvSslFree_.wait(cvLock, [this]() -> bool { return isRunOver_; });
821 
822     std::lock_guard<std::mutex> lock(recvMutex_);
823     auto res = tlsSocketInternal_.Close();
824     if (!res) {
825         int resErr = tlsSocketInternal_.ConvertSSLError();
826         NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
827         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
828         callback(resErr);
829         return;
830     }
831     CallOnCloseCallback();
832     callback(TLSSOCKET_SUCCESS);
833 }
834 
GetRemoteAddress(const GetRemoteAddressCallback & callback)835 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
836 {
837     sockaddr sockAddr = {0};
838     socklen_t len = sizeof(sockaddr);
839     int ret = getsockname(sockFd_, &sockAddr, &len);
840     if (ret < 0) {
841         int resErr = ConvertErrno();
842         NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
843         CallOnErrorCallback(resErr, MakeErrnoString());
844         CallGetRemoteAddressCallback(resErr, {}, callback);
845         return;
846     }
847 
848     if (sockAddr.sa_family == AF_INET) {
849         GetIp4RemoteAddress(callback);
850     } else if (sockAddr.sa_family == AF_INET6) {
851         GetIp6RemoteAddress(callback);
852     }
853 }
854 
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)855 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
856 {
857     sockaddr_in addr4 = {0};
858     socklen_t len4 = sizeof(sockaddr_in);
859 
860     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
861     if (ret < 0) {
862         int resErr = ConvertErrno();
863         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
864         CallOnErrorCallback(resErr, MakeErrnoString());
865         CallGetRemoteAddressCallback(resErr, {}, callback);
866         return;
867     }
868 
869     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
870     if (address.empty()) {
871         NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
872         CallOnErrorCallback(-1, "Address is invalid");
873         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
874         return;
875     }
876     Socket::NetAddress netAddress;
877     netAddress.SetFamilyBySaFamily(AF_INET);
878     netAddress.SetRawAddress(address);
879     netAddress.SetPort(ntohs(addr4.sin_port));
880     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
881 }
882 
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)883 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
884 {
885     sockaddr_in6 addr6 = {0};
886     socklen_t len6 = sizeof(sockaddr_in6);
887 
888     int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
889     if (ret < 0) {
890         int resErr = ConvertErrno();
891         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
892         CallOnErrorCallback(resErr, MakeErrnoString());
893         CallGetRemoteAddressCallback(resErr, {}, callback);
894         return;
895     }
896 
897     std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
898     if (address.empty()) {
899         NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
900         CallOnErrorCallback(-1, "Address is invalid");
901         CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
902         return;
903     }
904     Socket::NetAddress netAddress;
905     netAddress.SetFamilyBySaFamily(AF_INET6);
906     netAddress.SetRawAddress(address);
907     netAddress.SetPort(ntohs(addr6.sin6_port));
908     CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
909 }
910 
GetState(const GetStateCallback & callback)911 void TLSSocket::GetState(const GetStateCallback &callback)
912 {
913     int opt;
914     socklen_t optLen = sizeof(int);
915     int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
916     if (r < 0) {
917         Socket::SocketStateBase state;
918         state.SetIsClose(true);
919         CallGetStateCallback(ConvertErrno(), state, callback);
920         return;
921     }
922     sockaddr sockAddr = {0};
923     socklen_t len = sizeof(sockaddr);
924     Socket::SocketStateBase state;
925     int ret = getsockname(sockFd_, &sockAddr, &len);
926     state.SetIsBound(ret == 0);
927     ret = getpeername(sockFd_, &sockAddr, &len);
928     state.SetIsConnected(ret == 0);
929     CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
930 }
931 
SetBaseOptions(const Socket::ExtraOptionsBase & option) const932 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
933 {
934     if (option.GetReceiveBufferSize() != 0) {
935         int size = (int)option.GetReceiveBufferSize();
936         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
937             return false;
938         }
939     }
940 
941     if (option.GetSendBufferSize() != 0) {
942         int size = (int)option.GetSendBufferSize();
943         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
944             return false;
945         }
946     }
947 
948     if (option.IsReuseAddress()) {
949         int reuse = 1;
950         if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
951             return false;
952         }
953     }
954 
955     if (option.GetSocketTimeout() != 0) {
956         timeval timeout = {(int)option.GetSocketTimeout(), 0};
957         if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
958             return false;
959         }
960         if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
961             return false;
962         }
963     }
964 
965     return true;
966 }
967 
SetExtraOptions(const Socket::TCPExtraOptions & option) const968 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
969 {
970     if (option.IsKeepAlive()) {
971         int keepalive = 1;
972         if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
973             return false;
974         }
975     }
976 
977     if (option.IsOOBInline()) {
978         int oobInline = 1;
979         if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
980             return false;
981         }
982     }
983 
984     if (option.IsTCPNoDelay()) {
985         int tcpNoDelay = 1;
986         if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
987             return false;
988         }
989     }
990 
991     linger soLinger = {0};
992     soLinger.l_onoff = option.socketLinger.IsOn();
993     soLinger.l_linger = (int)option.socketLinger.GetLinger();
994     if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
995         return false;
996     }
997 
998     return true;
999 }
1000 
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)1001 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
1002                                 const SetExtraOptionsCallback &callback)
1003 {
1004     if (!SetBaseOptions(tcpExtraOptions)) {
1005         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1006         CallOnErrorCallback(errno, MakeErrnoString());
1007         CallSetExtraOptionsCallback(ConvertErrno(), callback);
1008         return;
1009     }
1010 
1011     if (!SetExtraOptions(tcpExtraOptions)) {
1012         NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1013         CallOnErrorCallback(errno, MakeErrnoString());
1014         CallSetExtraOptionsCallback(ConvertErrno(), callback);
1015         return;
1016     }
1017 
1018     CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
1019 }
1020 
GetCertificate(const GetCertificateCallback & callback)1021 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
1022 {
1023     const auto &cert = tlsSocketInternal_.GetCertificate();
1024     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1025 
1026     if (!cert.data.Length()) {
1027         int resErr = tlsSocketInternal_.ConvertSSLError();
1028         NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1029         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1030         callback(resErr, {});
1031         return;
1032     }
1033     callback(TLSSOCKET_SUCCESS, cert);
1034 }
1035 
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1036 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1037 {
1038     const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1039     if (!remoteCert.data.Length()) {
1040         int resErr = tlsSocketInternal_.ConvertSSLError();
1041         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1042         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1043         callback(resErr, {});
1044         return;
1045     }
1046     callback(TLSSOCKET_SUCCESS, remoteCert);
1047 }
1048 
GetProtocol(const GetProtocolCallback & callback)1049 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1050 {
1051     const auto &protocol = tlsSocketInternal_.GetProtocol();
1052     if (protocol.empty()) {
1053         NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1054         int resErr = tlsSocketInternal_.ConvertSSLError();
1055         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1056         callback(resErr, "");
1057         return;
1058     }
1059     callback(TLSSOCKET_SUCCESS, protocol);
1060 }
1061 
GetCipherSuite(const GetCipherSuiteCallback & callback)1062 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1063 {
1064     const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1065     if (cipherSuite.empty()) {
1066         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
1067         int resErr = tlsSocketInternal_.ConvertSSLError();
1068         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1069         callback(resErr, cipherSuite);
1070         return;
1071     }
1072     callback(TLSSOCKET_SUCCESS, cipherSuite);
1073 }
1074 
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1075 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1076 {
1077     const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1078     if (signatureAlgorithms.empty()) {
1079         NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
1080         int resErr = tlsSocketInternal_.ConvertSSLError();
1081         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1082         callback(resErr, {});
1083         return;
1084     }
1085     callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1086 }
1087 
OnMessage(const OnMessageCallback & onMessageCallback)1088 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1089 {
1090     std::lock_guard<std::mutex> lock(mutex_);
1091     onMessageCallback_ = onMessageCallback;
1092 }
1093 
OffMessage()1094 void TLSSocket::OffMessage()
1095 {
1096     std::lock_guard<std::mutex> lock(mutex_);
1097     if (onMessageCallback_) {
1098         onMessageCallback_ = nullptr;
1099     }
1100 }
1101 
OnConnect(const OnConnectCallback & onConnectCallback)1102 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1103 {
1104     std::lock_guard<std::mutex> lock(mutex_);
1105     onConnectCallback_ = onConnectCallback;
1106 }
1107 
OffConnect()1108 void TLSSocket::OffConnect()
1109 {
1110     std::lock_guard<std::mutex> lock(mutex_);
1111     if (onConnectCallback_) {
1112         onConnectCallback_ = nullptr;
1113     }
1114 }
1115 
OnClose(const OnCloseCallback & onCloseCallback)1116 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1117 {
1118     std::lock_guard<std::mutex> lock(mutex_);
1119     onCloseCallback_ = onCloseCallback;
1120 }
1121 
OffClose()1122 void TLSSocket::OffClose()
1123 {
1124     std::lock_guard<std::mutex> lock(mutex_);
1125     if (onCloseCallback_) {
1126         onCloseCallback_ = nullptr;
1127     }
1128 }
1129 
OnError(const OnErrorCallback & onErrorCallback)1130 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1131 {
1132     std::lock_guard<std::mutex> lock(mutex_);
1133     onErrorCallback_ = onErrorCallback;
1134 }
1135 
OffError()1136 void TLSSocket::OffError()
1137 {
1138     std::lock_guard<std::mutex> lock(mutex_);
1139     if (onErrorCallback_) {
1140         onErrorCallback_ = nullptr;
1141     }
1142 }
1143 
GetSocketFd()1144 int TLSSocket::GetSocketFd()
1145 {
1146     return sockFd_;
1147 }
1148 
SetLocalAddress(const Socket::NetAddress & address)1149 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1150 {
1151     localAddress_ = address;
1152 }
1153 
GetLocalAddress()1154 Socket::NetAddress TLSSocket::GetLocalAddress()
1155 {
1156     return localAddress_;
1157 }
1158 
GetCloseState()1159 bool TLSSocket::GetCloseState()
1160 {
1161     return isClosed;
1162 }
1163 
SetCloseState(bool flag)1164 void TLSSocket::SetCloseState(bool flag)
1165 {
1166     isClosed = flag;
1167 }
1168 
GetCloseLock()1169 std::mutex &TLSSocket::GetCloseLock()
1170 {
1171     return mutexForClose_;
1172 }
1173 
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1174 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1175 {
1176     auto hostName = ConvertAddressToIp(host, family);
1177     struct sockaddr_in dest = {0};
1178     dest.sin_family = family;
1179     dest.sin_port = htons(port);
1180 
1181     sockaddr_in addr4 = {0};
1182     sockaddr_in6 addr6 = {0};
1183     sockaddr *addr = nullptr;
1184     socklen_t len = 0;
1185     if (family == AF_INET) {
1186         if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1187             return false;
1188         }
1189         addr4.sin_family = family;
1190         addr4.sin_port = htons(port);
1191         addr = reinterpret_cast<sockaddr *>(&addr4);
1192         len = sizeof(sockaddr_in);
1193     } else {
1194         if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1195             return false;
1196         }
1197         addr6.sin6_family = family;
1198         addr6.sin6_port = htons(port);
1199         addr = reinterpret_cast<sockaddr *>(&addr6);
1200         len = sizeof(sockaddr_in6);
1201     }
1202 
1203     int connectResult = connect(socketDescriptor, addr, len);
1204     if (connectResult == -1) {
1205         NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1206                       strerror(errno));
1207         return false;
1208     }
1209     return true;
1210 }
1211 
ConvertSSLError(void)1212 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1213 {
1214     std::lock_guard<std::mutex> lock(mutexForSsl_);
1215     if (!ssl_) {
1216         return TLS_ERR_SSL_NULL;
1217     }
1218     return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1219 }
1220 
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1221 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1222 {
1223     SetTlsConfiguration(options);
1224     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1225     if (!cipherSuite.empty()) {
1226         configuration_.SetCipherSuite(cipherSuite);
1227     }
1228     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1229     if (!signatureAlgorithms.empty()) {
1230         configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1231     }
1232     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1233     if (!protocolVec.empty()) {
1234         configuration_.SetProtocol(protocolVec);
1235     }
1236     configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1237     hostName_ = options.GetNetAddress().GetAddress();
1238     port_ = options.GetNetAddress().GetPort();
1239     family_ = options.GetNetAddress().GetSaFamily();
1240     socketDescriptor_ = sock;
1241     if (!isExtSock && !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1242                                          options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1243         return false;
1244     }
1245     return StartTlsConnected(options);
1246 }
1247 
SetTlsConfiguration(const TLSConnectOptions & config)1248 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1249 {
1250     configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1251     configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1252     configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1253     configuration_.SetNetAddress(config.GetNetAddress());
1254 }
1255 
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1256 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1257 {
1258     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1259     for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1260         int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1261         if (ret < 0) {
1262             if (errno == EAGAIN || errno == EINTR) {
1263                 continue;
1264             }
1265             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1266             return false;
1267         } else if (ret == 0) {
1268             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1269             continue;
1270         }
1271         int len = SSL_write(ssl, curPos, curSendSize);
1272         if (len < 0) {
1273             int err = SSL_get_error(ssl, SSL_RET_CODE);
1274             if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1275                 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1276                 continue;
1277             } else {
1278                 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1279                 return false;
1280             }
1281         } else if (len == 0) {
1282             NETSTACK_LOGI("send len is 0, should have sent len");
1283             return false;
1284         } else {
1285             return true;
1286         }
1287     }
1288     return false;
1289 }
1290 
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1291 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1292 {
1293     int bufferSize = DEFAULT_BUFFER_SIZE;
1294     auto curPos = pdata;
1295     nfds_t num = 1;
1296     pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1297     while (sendSize > 0) {
1298         int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1299         if (ret < 0) {
1300             if (errno == EAGAIN || errno == EINTR) {
1301                 continue;
1302             }
1303             NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1304             return false;
1305         } else if (ret == 0) {
1306             NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1307             continue;
1308         }
1309         std::lock_guard<std::mutex> lock(mutexForSsl_);
1310         if (!ssl) {
1311             NETSTACK_LOGE("ssl is null");
1312             return false;
1313         }
1314         size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1315         int len = SSL_write(ssl, curPos, curSendSize);
1316         if (len < 0) {
1317             int err = SSL_get_error(ssl, SSL_RET_CODE);
1318             if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1319                 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1320                 return false;
1321             } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1322                 return false;
1323             }
1324         } else if (len == 0) {
1325             NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1326             return false;
1327         }
1328         curPos += len;
1329         sendSize -= len;
1330     }
1331     return true;
1332 }
1333 
Send(const std::string & data)1334 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1335 {
1336     {
1337         std::lock_guard<std::mutex> lock(mutexForSsl_);
1338         if (!ssl_) {
1339             NETSTACK_LOGE("ssl is null");
1340             return false;
1341         }
1342     }
1343 
1344     if (data.empty()) {
1345         NETSTACK_LOGE("data is empty");
1346         return true;
1347     }
1348 
1349     if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1350         return false;
1351     }
1352     return true;
1353 }
Recv(char * buffer,int maxBufferSize)1354 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1355 {
1356     if (!ssl_) {
1357         NETSTACK_LOGE("ssl is null");
1358         return SSL_ERROR_RETURN;
1359     }
1360 
1361     int ret = SSL_read(ssl_, buffer, maxBufferSize);
1362     if (ret < 0) {
1363         int err = SSL_get_error(ssl_, SSL_RET_CODE);
1364         switch (err) {
1365             case SSL_ERROR_SSL:
1366                 NETSTACK_LOGE("An error occurred in the SSL library");
1367                 return SSL_ERROR_RETURN;
1368             case SSL_ERROR_ZERO_RETURN:
1369                 NETSTACK_LOGE("peer disconnected...");
1370                 return SSL_ERROR_RETURN;
1371             case SSL_ERROR_WANT_READ:
1372                 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1373                 return SSL_WANT_READ_RETURN;
1374             default:
1375                 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1376                 return SSL_ERROR_RETURN;
1377         }
1378     }
1379     return ret;
1380 }
1381 
Close()1382 bool TLSSocket::TLSSocketInternal::Close()
1383 {
1384     std::lock_guard<std::mutex> lock(mutexForSsl_);
1385     if (!ssl_) {
1386         NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1387         return false;
1388     }
1389     int result = SSL_shutdown(ssl_);
1390     if (result < 0) {
1391         int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1392         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1393                       MakeSSLErrorString(resErr).c_str());
1394     }
1395     NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1396     SSL_free(ssl_);
1397     ssl_ = nullptr;
1398     close(socketDescriptor_);
1399     socketDescriptor_ = -1;
1400     if (!tlsContextPointer_) {
1401         NETSTACK_LOGE("Tls context pointer is null");
1402         return false;
1403     }
1404     tlsContextPointer_->CloseCtx();
1405     return true;
1406 }
1407 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1408 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1409 {
1410     if (!ssl_) {
1411         NETSTACK_LOGE("ssl is null");
1412         return false;
1413     }
1414     size_t pos = 0;
1415     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1416                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1417     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1418     for (const auto &str : alpnProtocols) {
1419         len = str.length();
1420         result[pos++] = len;
1421         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1422             NETSTACK_LOGE("strcpy_s failed");
1423             return false;
1424         }
1425         pos += len;
1426     }
1427     result[pos] = '\0';
1428 
1429     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1430     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1431         int resErr = ConvertSSLError();
1432         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1433                       MakeSSLErrorString(resErr).c_str());
1434         return false;
1435     }
1436     return true;
1437 }
1438 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1439 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1440 {
1441     remoteInfo.SetFamily(family_);
1442     remoteInfo.SetAddress(hostName_);
1443     remoteInfo.SetPort(port_);
1444 }
1445 
GetTlsConfiguration() const1446 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1447 {
1448     return configuration_;
1449 }
1450 
GetCipherSuite() const1451 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1452 {
1453     if (!ssl_) {
1454         NETSTACK_LOGE("ssl in null");
1455         return {};
1456     }
1457     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1458     if (!sk) {
1459         NETSTACK_LOGE("get ciphers failed");
1460         return {};
1461     }
1462     CipherSuite cipherSuite;
1463     std::vector<std::string> cipherSuiteVec;
1464     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1465         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1466         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1467         cipherSuiteVec.push_back(cipherSuite.cipherName_);
1468     }
1469     return cipherSuiteVec;
1470 }
1471 
GetRemoteCertificate() const1472 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1473 {
1474     return remoteCert_;
1475 }
1476 
GetCertificate() const1477 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1478 {
1479     return configuration_.GetCertificate();
1480 }
1481 
GetSignatureAlgorithms() const1482 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1483 {
1484     return signatureAlgorithms_;
1485 }
1486 
GetProtocol() const1487 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1488 {
1489     if (!ssl_) {
1490         NETSTACK_LOGE("ssl in null");
1491         return PROTOCOL_UNKNOW;
1492     }
1493     if (configuration_.GetProtocol() == TLS_V1_3) {
1494         return PROTOCOL_TLS_V13;
1495     }
1496     return PROTOCOL_TLS_V12;
1497 }
1498 
SetSharedSigals()1499 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1500 {
1501     if (!ssl_) {
1502         NETSTACK_LOGE("ssl is null");
1503         return false;
1504     }
1505     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1506     if (!number) {
1507         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1508         return false;
1509     }
1510     for (int i = 0; i < number; i++) {
1511         int hash_nid;
1512         int sign_nid;
1513         std::string sig_with_md;
1514         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1515         switch (sign_nid) {
1516             case EVP_PKEY_RSA:
1517                 sig_with_md = SIGN_NID_RSA;
1518                 break;
1519             case EVP_PKEY_RSA_PSS:
1520                 sig_with_md = SIGN_NID_RSA_PSS;
1521                 break;
1522             case EVP_PKEY_DSA:
1523                 sig_with_md = SIGN_NID_DSA;
1524                 break;
1525             case EVP_PKEY_EC:
1526                 sig_with_md = SIGN_NID_ECDSA;
1527                 break;
1528             case NID_ED25519:
1529                 sig_with_md = SIGN_NID_ED;
1530                 break;
1531             case NID_ED448:
1532                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1533                 break;
1534             default:
1535                 const char *sn = OBJ_nid2sn(sign_nid);
1536                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1537         }
1538         const char *sn_hash = OBJ_nid2sn(hash_nid);
1539         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1540         signatureAlgorithms_.push_back(sig_with_md);
1541     }
1542     return true;
1543 }
1544 
StartTlsConnected(const TLSConnectOptions & options)1545 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1546 {
1547     if (!CreatTlsContext()) {
1548         NETSTACK_LOGE("failed to create tls context");
1549         return false;
1550     }
1551     if (!StartShakingHands(options)) {
1552         NETSTACK_LOGE("failed to shaking hands");
1553         return false;
1554     }
1555     return true;
1556 }
1557 
CreatTlsContext()1558 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1559 {
1560     tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1561     if (!tlsContextPointer_) {
1562         NETSTACK_LOGE("failed to create tls context pointer");
1563         return false;
1564     }
1565 
1566     std::lock_guard<std::mutex> lock(mutexForSsl_);
1567     if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1568         NETSTACK_LOGE("failed to create ssl session");
1569         return false;
1570     }
1571 
1572     SSL_set_fd(ssl_, socketDescriptor_);
1573     SSL_set_connect_state(ssl_);
1574     return true;
1575 }
1576 
// Returns true when s begins with prefix; an empty prefix matches every string.
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1581 
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1582 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1583                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1584 {
1585     bool valid = false;
1586     std::string reason = UNKNOW_REASON;
1587     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1588     if (IsIP(hostName)) {
1589         auto it = find(ips.begin(), ips.end(), hostName);
1590         if (it == ips.end()) {
1591             reason = IP + hostName + " is not in the cert's list";
1592         }
1593         result = {valid, reason};
1594         return;
1595     }
1596     std::string tempHostName = "" + hostName;
1597     if (!dnsNames.empty() || index > 0) {
1598         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1599         if (!dnsNames.empty()) {
1600             valid = SeekIntersection(hostParts, dnsNames);
1601             if (!valid) {
1602                 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1603             }
1604         } else {
1605             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1606             X509_NAME *pSubName = nullptr;
1607             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1608             if (len > 0) {
1609                 std::vector<std::string> commonNameVec;
1610                 commonNameVec.emplace_back(commonNameBuf);
1611                 valid = SeekIntersection(hostParts, commonNameVec);
1612                 if (!valid) {
1613                     reason = HOST_NAME + tempHostName + ". is not cert's CN";
1614                 }
1615             }
1616         }
1617         result = {valid, reason};
1618         return;
1619     }
1620     reason = "Cert does not contain a DNS name";
1621     result = {valid, reason};
1622 }
1623 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1624 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1625                                                                    const X509 *x509Certificates)
1626 {
1627     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1628     if (!subjectName) {
1629         return "subject name is null";
1630     }
1631     char subNameBuf[BUF_SIZE] = {0};
1632     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1633 
1634     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1635     if (index < 0) {
1636         return "X509 get ext nid error";
1637     }
1638     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1639     if (ext == nullptr) {
1640         return "X509 get ext error";
1641     }
1642     ASN1_OBJECT *obj = nullptr;
1643     obj = X509_EXTENSION_get_object(ext);
1644     char subAltNameBuf[BUF_SIZE] = {0};
1645     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1646     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1647 
1648     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1649 }
1650 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1651 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1652                                                                    const X509 *x509Certificates)
1653 {
1654     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1655     if (!extData) {
1656         NETSTACK_LOGE("extData is nullptr");
1657         return "";
1658     }
1659 
1660     std::string altNames = reinterpret_cast<char *>(extData->data);
1661     std::string hostname = " " + hostName;
1662     BIO *bio = BIO_new(BIO_s_file());
1663     if (!bio) {
1664         return "bio is null";
1665     }
1666     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1667     ASN1_STRING_print(bio, extData);
1668     std::vector<std::string> dnsNames = {};
1669     std::vector<std::string> ips = {};
1670     constexpr int DNS_NAME_IDX = 4;
1671     constexpr int IP_NAME_IDX = 11;
1672     if (!altNames.empty()) {
1673         std::vector<std::string> splitAltNames;
1674         if (altNames.find('\"') != std::string::npos) {
1675             splitAltNames = SplitEscapedAltNames(altNames);
1676         } else {
1677             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1678         }
1679         for (auto const &iter : splitAltNames) {
1680             if (StartsWith(iter, DNS)) {
1681                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1682             } else if (StartsWith(iter, IP_ADDRESS)) {
1683                 ips.push_back(iter.substr(IP_NAME_IDX));
1684             }
1685         }
1686     }
1687     std::tuple<bool, std::string> result;
1688     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1689     if (!std::get<0>(result)) {
1690         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1691     }
1692     return HOST_NAME + hostname + ". is cert's CN";
1693 }
1694 
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1695 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1696 {
1697     if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1698         return;
1699     }
1700 
1701     auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1702     if (!cbio) {
1703         return;
1704     }
1705 
1706     auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1707     if (!inf) {
1708         BIO_free(cbio);
1709         return;
1710     }
1711 
1712     /* add each entry from PEM file to x509_store */
1713     for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1714         auto itmp = sk_X509_INFO_value(inf, i);
1715         if (!itmp) {
1716             continue;
1717         }
1718         if (itmp->x509) {
1719             X509_STORE_add_cert(store, itmp->x509);
1720         }
1721         if (itmp->crl) {
1722             X509_STORE_add_crl(store, itmp->crl);
1723         }
1724     }
1725 
1726     sk_X509_INFO_pop_free(inf, X509_INFO_free);
1727     BIO_free(cbio);
1728 }
1729 
X509_to_PEM(X509 * cert)1730 static std::string X509_to_PEM(X509 *cert)
1731 {
1732     if (!cert) {
1733         return {};
1734     }
1735     BIO *bio = BIO_new(BIO_s_mem());
1736     if (!bio) {
1737         return {};
1738     }
1739     if (!PEM_write_bio_X509(bio, cert)) {
1740         BIO_free(bio);
1741         return {};
1742     }
1743 
1744     char *data = nullptr;
1745     auto pemStringLength = BIO_get_mem_data(bio, &data);
1746     if (!data) {
1747         BIO_free(bio);
1748         return {};
1749     }
1750     std::string certificateInPEM(data, pemStringLength);
1751     BIO_free(bio);
1752     return certificateInPEM;
1753 }
1754 
CacheCertificates(const std::string & hostName,SSL * ssl)1755 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1756 {
1757     if (!ssl || hostName.empty()) {
1758         return;
1759     }
1760     auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1761     if (!certificatesStack) {
1762         return;
1763     }
1764     auto numCertificates = sk_X509_num(certificatesStack);
1765     for (auto i = 0; i < numCertificates; ++i) {
1766         auto cert = sk_X509_value(certificatesStack, i);
1767         auto certificateInPEM = X509_to_PEM(cert);
1768         if (!certificateInPEM.empty()) {
1769             CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1770         }
1771     }
1772 }
1773 
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1774 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1775 {
1776     if (!ssl) {
1777         return;
1778     }
1779     auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1780     auto sslCtx = SSL_get_SSL_CTX(ssl);
1781     if (!sslCtx) {
1782         return;
1783     }
1784     auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1785     if (!x509Store) {
1786         return;
1787     }
1788     for (const auto &pem : cachedPem) {
1789         LoadCaCertFromMemory(x509Store, pem);
1790     }
1791 }
1792 
StartShakingHands(const TLSConnectOptions & options)1793 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1794 {
1795     {
1796         std::lock_guard<std::mutex> lock(mutexForSsl_);
1797         if (!ssl_) {
1798             NETSTACK_LOGE("ssl is null");
1799             return false;
1800         }
1801 
1802         auto hostName = options.GetHostName();
1803         // indicates hostName is not ip address
1804         if (hostName != options.GetNetAddress().GetAddress()) {
1805             LoadCachedCaCert(hostName, ssl_);
1806         }
1807 
1808         int result = SSL_connect(ssl_);
1809         if (result == -1) {
1810             char err[MAX_ERR_LEN] = {0};
1811             auto code = ERR_get_error();
1812             ERR_error_string_n(code, err, MAX_ERR_LEN);
1813             int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1814             NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1815                           errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1816             return false;
1817         }
1818 
1819         // indicates hostName is not ip address
1820         if (hostName != options.GetNetAddress().GetAddress()) {
1821             CacheCertificates(hostName, ssl_);
1822         }
1823 
1824         std::string list = SSL_get_cipher_list(ssl_, 0);
1825         NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1826                       SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1827         configuration_.SetCipherSuite(list);
1828     }
1829     if (!SetSharedSigals()) {
1830         NETSTACK_LOGE("Failed to set sharedSigalgs");
1831     }
1832     if (!GetRemoteCertificateFromPeer()) {
1833         NETSTACK_LOGE("Failed to get remote certificate");
1834     }
1835     if (!peerX509_) {
1836         NETSTACK_LOGE("peer x509Certificates is null");
1837         return false;
1838     }
1839     if (!SetRemoteCertRawData()) {
1840         NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1841     }
1842     CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1843     if (!checkServerIdentity) {
1844         CheckServerIdentityLegal(hostName_, peerX509_);
1845     } else {
1846         checkServerIdentity(hostName_, {remoteCert_});
1847     }
1848     return true;
1849 }
1850 
GetRemoteCertificateFromPeer()1851 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1852 {
1853     peerX509_ = SSL_get_peer_certificate(ssl_);
1854     if (peerX509_ == nullptr) {
1855         int resErr = ConvertSSLError();
1856         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1857                       MakeSSLErrorString(resErr).c_str());
1858         return false;
1859     }
1860     BIO *bio = BIO_new(BIO_s_mem());
1861     if (!bio) {
1862         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1863         return false;
1864     }
1865     X509_print(bio, peerX509_);
1866     char data[REMOTE_CERT_LEN] = {0};
1867     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1868         NETSTACK_LOGE("BIO_read function returns error");
1869         BIO_free(bio);
1870         return false;
1871     }
1872     BIO_free(bio);
1873     remoteCert_ = std::string(data);
1874     return true;
1875 }
1876 
SetRemoteCertRawData()1877 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1878 {
1879     if (peerX509_ == nullptr) {
1880         NETSTACK_LOGE("peerX509 is null");
1881         return false;
1882     }
1883     int32_t length = i2d_X509(peerX509_, nullptr);
1884     if (length <= 0) {
1885         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1886         return false;
1887     }
1888     unsigned char *der = nullptr;
1889     (void)i2d_X509(peerX509_, &der);
1890     SecureData data(der, length);
1891     remoteRawData_.data = data;
1892     OPENSSL_free(der);
1893     remoteRawData_.encodingFormat = DER;
1894     return true;
1895 }
1896 
GetRemoteCertRawData() const1897 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1898 {
1899     return remoteRawData_;
1900 }
1901 
GetSSL()1902 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1903 {
1904     std::lock_guard<std::mutex> lock(mutexForSsl_);
1905     return ssl_;
1906 }
1907 } // namespace TlsSocket
1908 } // namespace NetStack
1909 } // namespace OHOS
1910