1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tls_socket.h"
17
18 #include <chrono>
19 #include <memory>
20 #include <numeric>
21 #include <regex>
22 #include <securec.h>
23 #include <set>
24 #include <thread>
25 #include <poll.h>
26
27 #include <netinet/tcp.h>
28 #include <openssl/err.h>
29 #include <openssl/ssl.h>
30
31 #include "base_context.h"
32 #include "netstack_common_utils.h"
33 #include "netstack_log.h"
34 #include "tls.h"
35 #include "socket_exec_common.h"
36
37 namespace OHOS {
38 namespace NetStack {
39 namespace TlsSocket {
40 namespace {
constexpr int READ_TIMEOUT_MS = 500;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int BUF_SIZE = 2048;
constexpr int SSL_RET_CODE = 0;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int SSL_WANT_READ_RETURN = -2;
// Number of characters SplitEscapedAltNames steps over after a separator (", " is two chars).
constexpr int OFFSET = 2;
constexpr int DEFAULT_BUFFER_SIZE = 8192;
constexpr int DEFAULT_POLL_TIMEOUT_MS = 500;
constexpr int SEND_RETRY_TIMES = 5;
constexpr int SEND_POLL_TIMEOUT_MS = 1000;
constexpr int MAX_RECV_BUFFER_SIZE = 1024 * 16;
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *DNS = "DNS:";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
static constexpr const char *TLS_SOCKET_CLIENT_READ = "OS_NET_TSCliRD";
// One JSON string literal (both quotes included) anchored at the start of the input.
// This is the Node.js pattern /^"(...)*"/ WITHOUT the JavaScript "/" delimiters: keeping them
// made the expression require a literal '/' and it could never match. Use it with
// std::regex_search(..., std::regex_constants::match_continuous) to match a prefix.
const std::regex JSON_STRING_PATTERN{R"(^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*")"};
// Dotted-quad IPv4 address, each octet in the range 0..255.
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
76
// Process-wide, mutex-guarded map from a string key to the set of string values recorded for it.
class CaCertCache {
public:
    // Returns the single shared instance (a function-local static).
    static CaCertCache &GetInstance()
    {
        static CaCertCache cache;
        return cache;
    }

    // Returns a copy of the values stored for `key`, or an empty set when the key is unknown.
    std::set<std::string> Get(const std::string &key)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto found = map_.find(key);
        return found == map_.end() ? std::set<std::string>{} : found->second;
    }

    // Adds `val` to the set stored under `key` (creating the entry if needed).
    void Set(const std::string &key, const std::string &val)
    {
        std::lock_guard<std::mutex> guard(mutex_);
        map_[key].emplace(val);
    }

private:
    CaCertCache() = default;
    ~CaCertCache() = default;
    CaCertCache(const CaCertCache &) = delete;
    CaCertCache &operator=(const CaCertCache &) = delete;

    std::map<std::string, std::set<std::string>> map_;
    std::mutex mutex_;
};
110
ConvertErrno()111 int ConvertErrno()
112 {
113 return TlsSocketError::TLS_ERR_SYS_BASE + errno;
114 }
115
// Returns strerror()'s description of the current errno value.
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
120
MakeSSLErrorString(int error)121 std::string MakeSSLErrorString(int error)
122 {
123 char err[MAX_ERR_LEN] = {0};
124 ERR_error_string_n(error - TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
125 return err;
126 }
127
SplitEscapedAltNames(std::string & altNames)128 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
129 {
130 std::vector<std::string> result;
131 std::string currentToken;
132 size_t offset = 0;
133 while (offset != altNames.length()) {
134 auto nextSep = altNames.find_first_of(", ");
135 auto nextQuote = altNames.find_first_of('\"');
136 if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
137 currentToken += altNames.substr(offset, nextQuote);
138 std::regex jsonStringPattern(JSON_STRING_PATTERN);
139 std::smatch match;
140 std::string altNameSubStr = altNames.substr(nextQuote);
141 bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
142 if (!ret) {
143 return {""};
144 }
145 currentToken += result[0];
146 offset = nextQuote + result[0].length();
147 } else if (nextSep != std::string::npos) {
148 currentToken += altNames.substr(offset, nextSep);
149 result.push_back(currentToken);
150 currentToken = "";
151 offset = nextSep + OFFSET;
152 } else {
153 currentToken += altNames.substr(offset);
154 offset = altNames.length();
155 }
156 }
157 result.push_back(currentToken);
158 return result;
159 }
160
IsIP(const std::string & ip)161 bool IsIP(const std::string &ip)
162 {
163 std::regex pattern(PATTERN);
164 std::smatch res;
165 return regex_match(ip, res, pattern);
166 }
167
SplitHostName(std::string & hostName)168 std::vector<std::string> SplitHostName(std::string &hostName)
169 {
170 transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
171 return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
172 }
173
// Returns true when at least one string occurs in both vectors. Neither vector is modified.
// std::set_intersection requires both ranges to be sorted. Callers may pass unsorted vectors,
// such as SplitHostName output, which presumably keeps label order. So look elements up in a set instead.
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    const std::set<std::string> lookup(vecA.begin(), vecA.end());
    return std::any_of(vecB.begin(), vecB.end(),
                       [&lookup](const std::string &item) { return lookup.count(item) != 0; });
}
180 } // namespace
181
SetSockBlockFlag(int sock,bool noneBlock)182 static bool SetSockBlockFlag(int sock, bool noneBlock)
183 {
184 int flags = fcntl(sock, F_GETFL, 0);
185 while (flags == -1 && errno == EINTR) {
186 flags = fcntl(sock, F_GETFL, 0);
187 }
188 if (flags == -1) {
189 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
190 return false;
191 }
192
193 auto newFlags = static_cast<size_t>(flags);
194 if (noneBlock) {
195 newFlags |= static_cast<size_t>(O_NONBLOCK);
196 } else {
197 newFlags &= ~static_cast<size_t>(O_NONBLOCK);
198 }
199
200 int ret = fcntl(sock, F_SETFL, newFlags);
201 while (ret == -1 && errno == EINTR) {
202 ret = fcntl(sock, F_SETFL, newFlags);
203 }
204 if (ret == -1) {
205 NETSTACK_LOGE("set block flags failed, socket is %{public}d, errno is %{public}d", sock, errno);
206 return false;
207 }
208 return true;
209 }
210
// Copy constructor: members are default-constructed, then filled in by the copy-assignment operator.
TLSSecureOptions::TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions)
{
    operator=(tlsSecureOptions);
}
215
// Copies every option from `tlsSecureOptions` through its public getters.
// Guards against self-assignment and assigns key_ once; it was previously assigned twice.
TLSSecureOptions &TLSSecureOptions::operator=(const TLSSecureOptions &tlsSecureOptions)
{
    if (this == &tlsSecureOptions) {
        return *this;
    }
    key_ = tlsSecureOptions.GetKey();
    caChain_ = tlsSecureOptions.GetCaChain();
    cert_ = tlsSecureOptions.GetCert();
    protocolChain_ = tlsSecureOptions.GetProtocolChain();
    crlChain_ = tlsSecureOptions.GetCrlChain();
    keyPass_ = tlsSecureOptions.GetKeyPass();
    signatureAlgorithms_ = tlsSecureOptions.GetSignatureAlgorithms();
    cipherSuite_ = tlsSecureOptions.GetCipherSuite();
    useRemoteCipherPrefer_ = tlsSecureOptions.UseRemoteCipherPrefer();
    TLSVerifyMode_ = tlsSecureOptions.GetVerifyMode();
    return *this;
}
231
SetCaChain(const std::vector<std::string> & caChain)232 void TLSSecureOptions::SetCaChain(const std::vector<std::string> &caChain)
233 {
234 caChain_ = caChain;
235 }
236
SetCert(const std::string & cert)237 void TLSSecureOptions::SetCert(const std::string &cert)
238 {
239 cert_ = cert;
240 }
241
SetKey(const SecureData & key)242 void TLSSecureOptions::SetKey(const SecureData &key)
243 {
244 key_ = key;
245 }
246
SetKeyPass(const SecureData & keyPass)247 void TLSSecureOptions::SetKeyPass(const SecureData &keyPass)
248 {
249 keyPass_ = keyPass;
250 }
251
SetProtocolChain(const std::vector<std::string> & protocolChain)252 void TLSSecureOptions::SetProtocolChain(const std::vector<std::string> &protocolChain)
253 {
254 protocolChain_ = protocolChain;
255 }
256
SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)257 void TLSSecureOptions::SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer)
258 {
259 useRemoteCipherPrefer_ = useRemoteCipherPrefer;
260 }
261
SetSignatureAlgorithms(const std::string & signatureAlgorithms)262 void TLSSecureOptions::SetSignatureAlgorithms(const std::string &signatureAlgorithms)
263 {
264 signatureAlgorithms_ = signatureAlgorithms;
265 }
266
SetCipherSuite(const std::string & cipherSuite)267 void TLSSecureOptions::SetCipherSuite(const std::string &cipherSuite)
268 {
269 cipherSuite_ = cipherSuite;
270 }
271
SetCrlChain(const std::vector<std::string> & crlChain)272 void TLSSecureOptions::SetCrlChain(const std::vector<std::string> &crlChain)
273 {
274 crlChain_ = crlChain;
275 }
276
GetCaChain() const277 const std::vector<std::string> &TLSSecureOptions::GetCaChain() const
278 {
279 return caChain_;
280 }
281
GetCert() const282 const std::string &TLSSecureOptions::GetCert() const
283 {
284 return cert_;
285 }
286
GetKey() const287 const SecureData &TLSSecureOptions::GetKey() const
288 {
289 return key_;
290 }
291
GetKeyPass() const292 const SecureData &TLSSecureOptions::GetKeyPass() const
293 {
294 return keyPass_;
295 }
296
GetProtocolChain() const297 const std::vector<std::string> &TLSSecureOptions::GetProtocolChain() const
298 {
299 return protocolChain_;
300 }
301
UseRemoteCipherPrefer() const302 bool TLSSecureOptions::UseRemoteCipherPrefer() const
303 {
304 return useRemoteCipherPrefer_;
305 }
306
GetSignatureAlgorithms() const307 const std::string &TLSSecureOptions::GetSignatureAlgorithms() const
308 {
309 return signatureAlgorithms_;
310 }
311
GetCipherSuite() const312 const std::string &TLSSecureOptions::GetCipherSuite() const
313 {
314 return cipherSuite_;
315 }
316
GetCrlChain() const317 const std::vector<std::string> &TLSSecureOptions::GetCrlChain() const
318 {
319 return crlChain_;
320 }
321
SetVerifyMode(VerifyMode verifyMode)322 void TLSSecureOptions::SetVerifyMode(VerifyMode verifyMode)
323 {
324 TLSVerifyMode_ = verifyMode;
325 }
326
GetVerifyMode() const327 VerifyMode TLSSecureOptions::GetVerifyMode() const
328 {
329 return TLSVerifyMode_;
330 }
331
SetNetAddress(const Socket::NetAddress & address)332 void TLSConnectOptions::SetNetAddress(const Socket::NetAddress &address)
333 {
334 address_.SetFamilyBySaFamily(address.GetSaFamily());
335 address_.SetRawAddress(address.GetAddress());
336 address_.SetPort(address.GetPort());
337 }
338
SetTlsSecureOptions(TLSSecureOptions & tlsSecureOptions)339 void TLSConnectOptions::SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions)
340 {
341 tlsSecureOptions_ = tlsSecureOptions;
342 }
343
SetCheckServerIdentity(const CheckServerIdentity & checkServerIdentity)344 void TLSConnectOptions::SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity)
345 {
346 checkServerIdentity_ = checkServerIdentity;
347 }
348
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)349 void TLSConnectOptions::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
350 {
351 alpnProtocols_ = alpnProtocols;
352 }
353
SetSkipRemoteValidation(bool skipRemoteValidation)354 void TLSConnectOptions::SetSkipRemoteValidation(bool skipRemoteValidation)
355 {
356 skipRemoteValidation_ = skipRemoteValidation;
357 }
358
GetNetAddress() const359 Socket::NetAddress TLSConnectOptions::GetNetAddress() const
360 {
361 return address_;
362 }
363
GetTlsSecureOptions() const364 TLSSecureOptions TLSConnectOptions::GetTlsSecureOptions() const
365 {
366 return tlsSecureOptions_;
367 }
368
GetCheckServerIdentity() const369 CheckServerIdentity TLSConnectOptions::GetCheckServerIdentity() const
370 {
371 return checkServerIdentity_;
372 }
373
GetAlpnProtocols() const374 const std::vector<std::string> &TLSConnectOptions::GetAlpnProtocols() const
375 {
376 return alpnProtocols_;
377 }
378
GetSkipRemoteValidation() const379 bool TLSConnectOptions::GetSkipRemoteValidation() const
380 {
381 return skipRemoteValidation_;
382 }
383
SetHostName(const std::string & hostName)384 void TLSConnectOptions::SetHostName(const std::string &hostName)
385 {
386 hostName_ = hostName;
387 }
388
GetHostName() const389 std::string TLSConnectOptions::GetHostName() const
390 {
391 return hostName_;
392 }
393
MakeAddressString(sockaddr * addr)394 std::string TLSSocket::MakeAddressString(sockaddr *addr)
395 {
396 if (!addr) {
397 return {};
398 }
399 if (addr->sa_family == AF_INET) {
400 auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
401 const char *str = inet_ntoa(addr4->sin_addr);
402 if (str == nullptr || strlen(str) == 0) {
403 return {};
404 }
405 return str;
406 } else if (addr->sa_family == AF_INET6) {
407 auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
408 char str[INET6_ADDRSTRLEN] = {0};
409 if (inet_ntop(AF_INET6, &addr6->sin6_addr, str, INET6_ADDRSTRLEN) == nullptr || strlen(str) == 0) {
410 return {};
411 }
412 return str;
413 }
414 return {};
415 }
416
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)417 void TLSSocket::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
418 socklen_t *len)
419 {
420 if (!addr6 || !addr4 || !len) {
421 return;
422 }
423 sa_family_t family = address.GetSaFamily();
424 if (family == AF_INET) {
425 addr4->sin_family = AF_INET;
426 addr4->sin_port = htons(address.GetPort());
427 addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
428 *addr = reinterpret_cast<sockaddr *>(addr4);
429 *len = sizeof(sockaddr_in);
430 } else if (family == AF_INET6) {
431 addr6->sin6_family = AF_INET6;
432 addr6->sin6_port = htons(address.GetPort());
433 inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
434 *addr = reinterpret_cast<sockaddr *>(addr6);
435 *len = sizeof(sockaddr_in6);
436 }
437 }
438
MakeIpSocket(sa_family_t family)439 void TLSSocket::MakeIpSocket(sa_family_t family)
440 {
441 if (family != AF_INET && family != AF_INET6) {
442 return;
443 }
444 int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
445 if (sock < 0) {
446 int resErr = ConvertErrno();
447 NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
448 CallOnErrorCallback(resErr, MakeErrnoString());
449 return;
450 }
451 sockFd_ = sock;
452 }
453
ReadMessage()454 int TLSSocket::ReadMessage()
455 {
456 char buffer[MAX_RECV_BUFFER_SIZE];
457 if (memset_s(buffer, MAX_RECV_BUFFER_SIZE, 0, MAX_RECV_BUFFER_SIZE) != EOK) {
458 NETSTACK_LOGE("memset_s failed!");
459 return -1;
460 }
461 nfds_t num = 1;
462 pollfd fds[1] = {{.fd = sockFd_, .events = POLLIN}};
463 int ret = poll(fds, num, READ_TIMEOUT_MS);
464 if (ret < 0) {
465 if (errno == EAGAIN || errno == EINTR) {
466 return 0;
467 }
468 int resErr = ConvertErrno();
469 NETSTACK_LOGE("Message poll errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
470 CallOnErrorCallback(resErr, MakeErrnoString());
471 return ret;
472 } else if (ret == 0) {
473 NETSTACK_LOGD("tls recv poll timeout");
474 return ret;
475 }
476
477 std::lock_guard<std::mutex> lock(recvMutex_);
478 if (!isRunning_) {
479 return -1;
480 }
481 int len = tlsSocketInternal_.Recv(buffer, MAX_RECV_BUFFER_SIZE);
482 if (len < 0) {
483 if (errno == EAGAIN || errno == EINTR || len == SSL_WANT_READ_RETURN) {
484 return 0;
485 }
486 int resErr = tlsSocketInternal_.ConvertSSLError();
487 NETSTACK_LOGE("SSL_read function read error, errno is %{public}d, errno info is %{public}s",
488 resErr, MakeSSLErrorString(resErr).c_str());
489 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
490 return len;
491 } else if (len == 0) {
492 NETSTACK_LOGI("Message recv len 0, session is closed by peer");
493 CallOnCloseCallback();
494 return -1;
495 }
496 Socket::SocketRemoteInfo remoteInfo;
497 remoteInfo.SetSize(len);
498 tlsSocketInternal_.MakeRemoteInfo(remoteInfo);
499 std::string bufContent(buffer, len);
500 CallOnMessageCallback(bufContent, remoteInfo);
501
502 return ret;
503 }
504
StartReadMessage()505 void TLSSocket::StartReadMessage()
506 {
507 std::thread thread([this]() {
508 isRunning_ = true;
509 isRunOver_ = false;
510 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
511 pthread_setname_np(TLS_SOCKET_CLIENT_READ);
512 #else
513 pthread_setname_np(pthread_self(), TLS_SOCKET_CLIENT_READ);
514 #endif
515 while (isRunning_) {
516 int ret = ReadMessage();
517 if (ret < 0) {
518 break;
519 }
520 }
521 isRunOver_ = true;
522 cvSslFree_.notify_one();
523 });
524 thread.detach();
525 }
526
CallOnMessageCallback(const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)527 void TLSSocket::CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)
528 {
529 OnMessageCallback func = nullptr;
530 {
531 std::lock_guard<std::mutex> lock(mutex_);
532 if (onMessageCallback_) {
533 func = onMessageCallback_;
534 }
535 }
536
537 if (func) {
538 func(data, remoteInfo);
539 }
540 }
541
CallOnConnectCallback()542 void TLSSocket::CallOnConnectCallback()
543 {
544 OnConnectCallback func = nullptr;
545 {
546 std::lock_guard<std::mutex> lock(mutex_);
547 if (onConnectCallback_) {
548 func = onConnectCallback_;
549 }
550 }
551
552 if (func) {
553 func();
554 }
555 }
556
CallOnCloseCallback()557 void TLSSocket::CallOnCloseCallback()
558 {
559 OnCloseCallback func = nullptr;
560 {
561 std::lock_guard<std::mutex> lock(mutex_);
562 if (onCloseCallback_) {
563 func = onCloseCallback_;
564 }
565 }
566
567 if (func) {
568 func();
569 }
570 }
571
CallOnErrorCallback(int32_t err,const std::string & errString)572 void TLSSocket::CallOnErrorCallback(int32_t err, const std::string &errString)
573 {
574 OnErrorCallback func = nullptr;
575 {
576 std::lock_guard<std::mutex> lock(mutex_);
577 if (onErrorCallback_) {
578 func = onErrorCallback_;
579 }
580 }
581
582 if (func) {
583 func(err, errString);
584 }
585 }
586
CallBindCallback(int32_t err,BindCallback callback)587 void TLSSocket::CallBindCallback(int32_t err, BindCallback callback)
588 {
589 DealCallback<BindCallback>(err, callback);
590 }
591
CallConnectCallback(int32_t err,ConnectCallback callback)592 void TLSSocket::CallConnectCallback(int32_t err, ConnectCallback callback)
593 {
594 DealCallback<ConnectCallback>(err, callback);
595 }
596
CallSendCallback(int32_t err,SendCallback callback)597 void TLSSocket::CallSendCallback(int32_t err, SendCallback callback)
598 {
599 DealCallback<SendCallback>(err, callback);
600 }
601
CallCloseCallback(int32_t err,CloseCallback callback)602 void TLSSocket::CallCloseCallback(int32_t err, CloseCallback callback)
603 {
604 DealCallback<CloseCallback>(err, callback);
605 }
606
CallGetRemoteAddressCallback(int32_t err,const Socket::NetAddress & address,GetRemoteAddressCallback callback)607 void TLSSocket::CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
608 GetRemoteAddressCallback callback)
609 {
610 GetRemoteAddressCallback func = nullptr;
611 {
612 std::lock_guard<std::mutex> lock(mutex_);
613 if (callback) {
614 func = callback;
615 }
616 }
617
618 if (func) {
619 func(err, address);
620 }
621 }
622
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,GetStateCallback callback)623 void TLSSocket::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback)
624 {
625 GetStateCallback func = nullptr;
626 {
627 std::lock_guard<std::mutex> lock(mutex_);
628 if (callback) {
629 func = callback;
630 }
631 }
632
633 if (func) {
634 func(err, state);
635 }
636 }
637
CallSetExtraOptionsCallback(int32_t err,SetExtraOptionsCallback callback)638 void TLSSocket::CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback)
639 {
640 DealCallback<SetExtraOptionsCallback>(err, callback);
641 }
642
CallGetCertificateCallback(int32_t err,const X509CertRawData & cert,GetCertificateCallback callback)643 void TLSSocket::CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback)
644 {
645 GetCertificateCallback func = nullptr;
646 {
647 std::lock_guard<std::mutex> lock(mutex_);
648 if (callback) {
649 func = callback;
650 }
651 }
652
653 if (func) {
654 func(err, cert);
655 }
656 }
657
CallGetRemoteCertificateCallback(int32_t err,const X509CertRawData & cert,GetRemoteCertificateCallback callback)658 void TLSSocket::CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
659 GetRemoteCertificateCallback callback)
660 {
661 GetRemoteCertificateCallback func = nullptr;
662 {
663 std::lock_guard<std::mutex> lock(mutex_);
664 if (callback) {
665 func = callback;
666 }
667 }
668
669 if (func) {
670 func(err, cert);
671 }
672 }
673
CallGetProtocolCallback(int32_t err,const std::string & protocol,GetProtocolCallback callback)674 void TLSSocket::CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback)
675 {
676 GetProtocolCallback func = nullptr;
677 {
678 std::lock_guard<std::mutex> lock(mutex_);
679 if (callback) {
680 func = callback;
681 }
682 }
683
684 if (func) {
685 func(err, protocol);
686 }
687 }
688
CallGetCipherSuiteCallback(int32_t err,const std::vector<std::string> & suite,GetCipherSuiteCallback callback)689 void TLSSocket::CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
690 GetCipherSuiteCallback callback)
691 {
692 GetCipherSuiteCallback func = nullptr;
693 {
694 std::lock_guard<std::mutex> lock(mutex_);
695 if (callback) {
696 func = callback;
697 }
698 }
699
700 if (func) {
701 func(err, suite);
702 }
703 }
704
CallGetSignatureAlgorithmsCallback(int32_t err,const std::vector<std::string> & algorithms,GetSignatureAlgorithmsCallback callback)705 void TLSSocket::CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
706 GetSignatureAlgorithmsCallback callback)
707 {
708 GetSignatureAlgorithmsCallback func = nullptr;
709 {
710 std::lock_guard<std::mutex> lock(mutex_);
711 if (callback) {
712 func = callback;
713 }
714 }
715
716 if (func) {
717 func(err, algorithms);
718 }
719 }
720
Bind(Socket::NetAddress & address,const BindCallback & callback)721 void TLSSocket::Bind(Socket::NetAddress &address, const BindCallback &callback)
722 {
723 static constexpr int32_t PARSE_ERROR_CODE = 401;
724 if (!CommonUtils::HasInternetPermission()) {
725 CallBindCallback(PERMISSION_DENIED_CODE, callback);
726 return;
727 }
728 if (sockFd_ >= 0) {
729 CallBindCallback(TLSSOCKET_SUCCESS, callback);
730 return;
731 }
732
733 MakeIpSocket(address.GetSaFamily());
734 if (sockFd_ < 0) {
735 int resErr = ConvertErrno();
736 NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
737 CallOnErrorCallback(resErr, MakeErrnoString());
738 CallBindCallback(resErr, callback);
739 return;
740 }
741
742 auto temp = address.GetAddress();
743 address.SetRawAddress("");
744 address.SetAddress(temp);
745 if (address.GetAddress().empty()) {
746 CallBindCallback(PARSE_ERROR_CODE, callback);
747 return;
748 }
749
750 sockaddr_in addr4 = {0};
751 sockaddr_in6 addr6 = {0};
752 sockaddr *addr = nullptr;
753 socklen_t len;
754 GetAddr(address, &addr4, &addr6, &addr, &len);
755 if (addr == nullptr) {
756 NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
757 CallOnErrorCallback(-1, "Address Is Invalid");
758 CallBindCallback(ConvertErrno(), callback);
759 return;
760 }
761 CallBindCallback(TLSSOCKET_SUCCESS, callback);
762 }
763
Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions & tlsConnectOptions,const OHOS::NetStack::TlsSocket::ConnectCallback & callback)764 void TLSSocket::Connect(OHOS::NetStack::TlsSocket::TLSConnectOptions &tlsConnectOptions,
765 const OHOS::NetStack::TlsSocket::ConnectCallback &callback)
766 {
767 if (sockFd_ < 0) {
768 int resErr = ConvertErrno();
769 NETSTACK_LOGE("connect error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
770 CallOnErrorCallback(resErr, MakeErrnoString());
771 callback(resErr);
772 return;
773 }
774
775 if (isExtSock_ && !SetSockBlockFlag(sockFd_, false)) {
776 int resErr = ConvertErrno();
777 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
778 CallOnErrorCallback(resErr, MakeErrnoString());
779 callback(resErr);
780 return;
781 }
782
783 auto res = tlsSocketInternal_.TlsConnectToHost(sockFd_, tlsConnectOptions, isExtSock_);
784 if (!res) {
785 int resErr = tlsSocketInternal_.ConvertSSLError();
786 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
787 callback(resErr);
788 return;
789 }
790 if (!SetSockBlockFlag(sockFd_, true)) {
791 int resErr = ConvertErrno();
792 NETSTACK_LOGE("SetSockBlockFlag error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
793 CallOnErrorCallback(resErr, MakeErrnoString());
794 callback(resErr);
795 return;
796 }
797 StartReadMessage();
798 CallOnConnectCallback();
799 callback(TLSSOCKET_SUCCESS);
800 }
801
Send(const OHOS::NetStack::Socket::TCPSendOptions & tcpSendOptions,const SendCallback & callback)802 void TLSSocket::Send(const OHOS::NetStack::Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback)
803 {
804 (void)tcpSendOptions;
805
806 auto res = tlsSocketInternal_.Send(tcpSendOptions.GetData());
807 if (!res) {
808 int resErr = tlsSocketInternal_.ConvertSSLError();
809 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
810 CallSendCallback(resErr, callback);
811 return;
812 }
813 CallSendCallback(TLSSOCKET_SUCCESS, callback);
814 }
815
Close(const CloseCallback & callback)816 void TLSSocket::Close(const CloseCallback &callback)
817 {
818 isRunning_ = false;
819 std::unique_lock<std::mutex> cvLock(cvMutex_);
820 cvSslFree_.wait(cvLock, [this]() -> bool { return isRunOver_; });
821
822 std::lock_guard<std::mutex> lock(recvMutex_);
823 auto res = tlsSocketInternal_.Close();
824 if (!res) {
825 int resErr = tlsSocketInternal_.ConvertSSLError();
826 NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
827 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
828 callback(resErr);
829 return;
830 }
831 CallOnCloseCallback();
832 callback(TLSSOCKET_SUCCESS);
833 }
834
GetRemoteAddress(const GetRemoteAddressCallback & callback)835 void TLSSocket::GetRemoteAddress(const GetRemoteAddressCallback &callback)
836 {
837 sockaddr sockAddr = {0};
838 socklen_t len = sizeof(sockaddr);
839 int ret = getsockname(sockFd_, &sockAddr, &len);
840 if (ret < 0) {
841 int resErr = ConvertErrno();
842 NETSTACK_LOGE("getsockname failed errno %{public}d", resErr);
843 CallOnErrorCallback(resErr, MakeErrnoString());
844 CallGetRemoteAddressCallback(resErr, {}, callback);
845 return;
846 }
847
848 if (sockAddr.sa_family == AF_INET) {
849 GetIp4RemoteAddress(callback);
850 } else if (sockAddr.sa_family == AF_INET6) {
851 GetIp6RemoteAddress(callback);
852 }
853 }
854
GetIp4RemoteAddress(const GetRemoteAddressCallback & callback)855 void TLSSocket::GetIp4RemoteAddress(const GetRemoteAddressCallback &callback)
856 {
857 sockaddr_in addr4 = {0};
858 socklen_t len4 = sizeof(sockaddr_in);
859
860 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr4), &len4);
861 if (ret < 0) {
862 int resErr = ConvertErrno();
863 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", resErr);
864 CallOnErrorCallback(resErr, MakeErrnoString());
865 CallGetRemoteAddressCallback(resErr, {}, callback);
866 return;
867 }
868
869 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr4));
870 if (address.empty()) {
871 NETSTACK_LOGE("GetIp4RemoteAddress failed errno %{public}d", errno);
872 CallOnErrorCallback(-1, "Address is invalid");
873 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
874 return;
875 }
876 Socket::NetAddress netAddress;
877 netAddress.SetFamilyBySaFamily(AF_INET);
878 netAddress.SetRawAddress(address);
879 netAddress.SetPort(ntohs(addr4.sin_port));
880 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
881 }
882
GetIp6RemoteAddress(const GetRemoteAddressCallback & callback)883 void TLSSocket::GetIp6RemoteAddress(const GetRemoteAddressCallback &callback)
884 {
885 sockaddr_in6 addr6 = {0};
886 socklen_t len6 = sizeof(sockaddr_in6);
887
888 int ret = getpeername(sockFd_, reinterpret_cast<sockaddr *>(&addr6), &len6);
889 if (ret < 0) {
890 int resErr = ConvertErrno();
891 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", resErr);
892 CallOnErrorCallback(resErr, MakeErrnoString());
893 CallGetRemoteAddressCallback(resErr, {}, callback);
894 return;
895 }
896
897 std::string address = MakeAddressString(reinterpret_cast<sockaddr *>(&addr6));
898 if (address.empty()) {
899 NETSTACK_LOGE("GetIp6RemoteAddress failed errno %{public}d", errno);
900 CallOnErrorCallback(-1, "Address is invalid");
901 CallGetRemoteAddressCallback(ConvertErrno(), {}, callback);
902 return;
903 }
904 Socket::NetAddress netAddress;
905 netAddress.SetFamilyBySaFamily(AF_INET6);
906 netAddress.SetRawAddress(address);
907 netAddress.SetPort(ntohs(addr6.sin6_port));
908 CallGetRemoteAddressCallback(TLSSOCKET_SUCCESS, netAddress, callback);
909 }
910
GetState(const GetStateCallback & callback)911 void TLSSocket::GetState(const GetStateCallback &callback)
912 {
913 int opt;
914 socklen_t optLen = sizeof(int);
915 int r = getsockopt(sockFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
916 if (r < 0) {
917 Socket::SocketStateBase state;
918 state.SetIsClose(true);
919 CallGetStateCallback(ConvertErrno(), state, callback);
920 return;
921 }
922 sockaddr sockAddr = {0};
923 socklen_t len = sizeof(sockaddr);
924 Socket::SocketStateBase state;
925 int ret = getsockname(sockFd_, &sockAddr, &len);
926 state.SetIsBound(ret == 0);
927 ret = getpeername(sockFd_, &sockAddr, &len);
928 state.SetIsConnected(ret == 0);
929 CallGetStateCallback(TLSSOCKET_SUCCESS, state, callback);
930 }
931
SetBaseOptions(const Socket::ExtraOptionsBase & option) const932 bool TLSSocket::SetBaseOptions(const Socket::ExtraOptionsBase &option) const
933 {
934 if (option.GetReceiveBufferSize() != 0) {
935 int size = (int)option.GetReceiveBufferSize();
936 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
937 return false;
938 }
939 }
940
941 if (option.GetSendBufferSize() != 0) {
942 int size = (int)option.GetSendBufferSize();
943 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<void *>(&size), sizeof(size)) < 0) {
944 return false;
945 }
946 }
947
948 if (option.IsReuseAddress()) {
949 int reuse = 1;
950 if (setsockopt(sockFd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<void *>(&reuse), sizeof(reuse)) < 0) {
951 return false;
952 }
953 }
954
955 if (option.GetSocketTimeout() != 0) {
956 timeval timeout = {(int)option.GetSocketTimeout(), 0};
957 if (setsockopt(sockFd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
958 return false;
959 }
960 if (setsockopt(sockFd_, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<void *>(&timeout), sizeof(timeout)) < 0) {
961 return false;
962 }
963 }
964
965 return true;
966 }
967
SetExtraOptions(const Socket::TCPExtraOptions & option) const968 bool TLSSocket::SetExtraOptions(const Socket::TCPExtraOptions &option) const
969 {
970 if (option.IsKeepAlive()) {
971 int keepalive = 1;
972 if (setsockopt(sockFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
973 return false;
974 }
975 }
976
977 if (option.IsOOBInline()) {
978 int oobInline = 1;
979 if (setsockopt(sockFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
980 return false;
981 }
982 }
983
984 if (option.IsTCPNoDelay()) {
985 int tcpNoDelay = 1;
986 if (setsockopt(sockFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
987 return false;
988 }
989 }
990
991 linger soLinger = {0};
992 soLinger.l_onoff = option.socketLinger.IsOn();
993 soLinger.l_linger = (int)option.socketLinger.GetLinger();
994 if (setsockopt(sockFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
995 return false;
996 }
997
998 return true;
999 }
1000
SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions & tcpExtraOptions,const SetExtraOptionsCallback & callback)1001 void TLSSocket::SetExtraOptions(const OHOS::NetStack::Socket::TCPExtraOptions &tcpExtraOptions,
1002 const SetExtraOptionsCallback &callback)
1003 {
1004 if (!SetBaseOptions(tcpExtraOptions)) {
1005 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1006 CallOnErrorCallback(errno, MakeErrnoString());
1007 CallSetExtraOptionsCallback(ConvertErrno(), callback);
1008 return;
1009 }
1010
1011 if (!SetExtraOptions(tcpExtraOptions)) {
1012 NETSTACK_LOGE("SetExtraOptions errno %{public}d", errno);
1013 CallOnErrorCallback(errno, MakeErrnoString());
1014 CallSetExtraOptionsCallback(ConvertErrno(), callback);
1015 return;
1016 }
1017
1018 CallSetExtraOptionsCallback(TLSSOCKET_SUCCESS, callback);
1019 }
1020
GetCertificate(const GetCertificateCallback & callback)1021 void TLSSocket::GetCertificate(const GetCertificateCallback &callback)
1022 {
1023 const auto &cert = tlsSocketInternal_.GetCertificate();
1024 NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
1025
1026 if (!cert.data.Length()) {
1027 int resErr = tlsSocketInternal_.ConvertSSLError();
1028 NETSTACK_LOGE("GetCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1029 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1030 callback(resErr, {});
1031 return;
1032 }
1033 callback(TLSSOCKET_SUCCESS, cert);
1034 }
1035
GetRemoteCertificate(const GetRemoteCertificateCallback & callback)1036 void TLSSocket::GetRemoteCertificate(const GetRemoteCertificateCallback &callback)
1037 {
1038 const auto &remoteCert = tlsSocketInternal_.GetRemoteCertRawData();
1039 if (!remoteCert.data.Length()) {
1040 int resErr = tlsSocketInternal_.ConvertSSLError();
1041 NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
1042 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1043 callback(resErr, {});
1044 return;
1045 }
1046 callback(TLSSOCKET_SUCCESS, remoteCert);
1047 }
1048
GetProtocol(const GetProtocolCallback & callback)1049 void TLSSocket::GetProtocol(const GetProtocolCallback &callback)
1050 {
1051 const auto &protocol = tlsSocketInternal_.GetProtocol();
1052 if (protocol.empty()) {
1053 NETSTACK_LOGE("GetProtocol errno %{public}d", errno);
1054 int resErr = tlsSocketInternal_.ConvertSSLError();
1055 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1056 callback(resErr, "");
1057 return;
1058 }
1059 callback(TLSSOCKET_SUCCESS, protocol);
1060 }
1061
GetCipherSuite(const GetCipherSuiteCallback & callback)1062 void TLSSocket::GetCipherSuite(const GetCipherSuiteCallback &callback)
1063 {
1064 const auto &cipherSuite = tlsSocketInternal_.GetCipherSuite();
1065 if (cipherSuite.empty()) {
1066 NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
1067 int resErr = tlsSocketInternal_.ConvertSSLError();
1068 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1069 callback(resErr, cipherSuite);
1070 return;
1071 }
1072 callback(TLSSOCKET_SUCCESS, cipherSuite);
1073 }
1074
GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback & callback)1075 void TLSSocket::GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback)
1076 {
1077 const auto &signatureAlgorithms = tlsSocketInternal_.GetSignatureAlgorithms();
1078 if (signatureAlgorithms.empty()) {
1079 NETSTACK_LOGE("GetSignatureAlgorithms errno %{public}d", errno);
1080 int resErr = tlsSocketInternal_.ConvertSSLError();
1081 CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1082 callback(resErr, {});
1083 return;
1084 }
1085 callback(TLSSOCKET_SUCCESS, signatureAlgorithms);
1086 }
1087
OnMessage(const OnMessageCallback & onMessageCallback)1088 void TLSSocket::OnMessage(const OnMessageCallback &onMessageCallback)
1089 {
1090 std::lock_guard<std::mutex> lock(mutex_);
1091 onMessageCallback_ = onMessageCallback;
1092 }
1093
OffMessage()1094 void TLSSocket::OffMessage()
1095 {
1096 std::lock_guard<std::mutex> lock(mutex_);
1097 if (onMessageCallback_) {
1098 onMessageCallback_ = nullptr;
1099 }
1100 }
1101
OnConnect(const OnConnectCallback & onConnectCallback)1102 void TLSSocket::OnConnect(const OnConnectCallback &onConnectCallback)
1103 {
1104 std::lock_guard<std::mutex> lock(mutex_);
1105 onConnectCallback_ = onConnectCallback;
1106 }
1107
OffConnect()1108 void TLSSocket::OffConnect()
1109 {
1110 std::lock_guard<std::mutex> lock(mutex_);
1111 if (onConnectCallback_) {
1112 onConnectCallback_ = nullptr;
1113 }
1114 }
1115
OnClose(const OnCloseCallback & onCloseCallback)1116 void TLSSocket::OnClose(const OnCloseCallback &onCloseCallback)
1117 {
1118 std::lock_guard<std::mutex> lock(mutex_);
1119 onCloseCallback_ = onCloseCallback;
1120 }
1121
OffClose()1122 void TLSSocket::OffClose()
1123 {
1124 std::lock_guard<std::mutex> lock(mutex_);
1125 if (onCloseCallback_) {
1126 onCloseCallback_ = nullptr;
1127 }
1128 }
1129
OnError(const OnErrorCallback & onErrorCallback)1130 void TLSSocket::OnError(const OnErrorCallback &onErrorCallback)
1131 {
1132 std::lock_guard<std::mutex> lock(mutex_);
1133 onErrorCallback_ = onErrorCallback;
1134 }
1135
OffError()1136 void TLSSocket::OffError()
1137 {
1138 std::lock_guard<std::mutex> lock(mutex_);
1139 if (onErrorCallback_) {
1140 onErrorCallback_ = nullptr;
1141 }
1142 }
1143
GetSocketFd()1144 int TLSSocket::GetSocketFd()
1145 {
1146 return sockFd_;
1147 }
1148
SetLocalAddress(const Socket::NetAddress & address)1149 void TLSSocket::SetLocalAddress(const Socket::NetAddress &address)
1150 {
1151 localAddress_ = address;
1152 }
1153
GetLocalAddress()1154 Socket::NetAddress TLSSocket::GetLocalAddress()
1155 {
1156 return localAddress_;
1157 }
1158
GetCloseState()1159 bool TLSSocket::GetCloseState()
1160 {
1161 return isClosed;
1162 }
1163
SetCloseState(bool flag)1164 void TLSSocket::SetCloseState(bool flag)
1165 {
1166 isClosed = flag;
1167 }
1168
GetCloseLock()1169 std::mutex &TLSSocket::GetCloseLock()
1170 {
1171 return mutexForClose_;
1172 }
1173
ExecSocketConnect(const std::string & host,int port,sa_family_t family,int socketDescriptor)1174 bool ExecSocketConnect(const std::string &host, int port, sa_family_t family, int socketDescriptor)
1175 {
1176 auto hostName = ConvertAddressToIp(host, family);
1177 struct sockaddr_in dest = {0};
1178 dest.sin_family = family;
1179 dest.sin_port = htons(port);
1180
1181 sockaddr_in addr4 = {0};
1182 sockaddr_in6 addr6 = {0};
1183 sockaddr *addr = nullptr;
1184 socklen_t len = 0;
1185 if (family == AF_INET) {
1186 if (inet_pton(AF_INET, hostName.c_str(), &addr4.sin_addr.s_addr) <= 0) {
1187 return false;
1188 }
1189 addr4.sin_family = family;
1190 addr4.sin_port = htons(port);
1191 addr = reinterpret_cast<sockaddr *>(&addr4);
1192 len = sizeof(sockaddr_in);
1193 } else {
1194 if (inet_pton(AF_INET6, hostName.c_str(), &addr6.sin6_addr) <= 0) {
1195 return false;
1196 }
1197 addr6.sin6_family = family;
1198 addr6.sin6_port = htons(port);
1199 addr = reinterpret_cast<sockaddr *>(&addr6);
1200 len = sizeof(sockaddr_in6);
1201 }
1202
1203 int connectResult = connect(socketDescriptor, addr, len);
1204 if (connectResult == -1) {
1205 NETSTACK_LOGE("socket connect error!The error code is %{public}d, The error message is %{public}s", errno,
1206 strerror(errno));
1207 return false;
1208 }
1209 return true;
1210 }
1211
ConvertSSLError(void)1212 int TLSSocket::TLSSocketInternal::ConvertSSLError(void)
1213 {
1214 std::lock_guard<std::mutex> lock(mutexForSsl_);
1215 if (!ssl_) {
1216 return TLS_ERR_SSL_NULL;
1217 }
1218 return TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1219 }
1220
TlsConnectToHost(int sock,const TLSConnectOptions & options,bool isExtSock)1221 bool TLSSocket::TLSSocketInternal::TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock)
1222 {
1223 SetTlsConfiguration(options);
1224 std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
1225 if (!cipherSuite.empty()) {
1226 configuration_.SetCipherSuite(cipherSuite);
1227 }
1228 std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
1229 if (!signatureAlgorithms.empty()) {
1230 configuration_.SetSignatureAlgorithms(signatureAlgorithms);
1231 }
1232 const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
1233 if (!protocolVec.empty()) {
1234 configuration_.SetProtocol(protocolVec);
1235 }
1236 configuration_.SetSkipFlag(options.GetSkipRemoteValidation());
1237 hostName_ = options.GetNetAddress().GetAddress();
1238 port_ = options.GetNetAddress().GetPort();
1239 family_ = options.GetNetAddress().GetSaFamily();
1240 socketDescriptor_ = sock;
1241 if (!isExtSock && !ExecSocketConnect(options.GetNetAddress().GetAddress(), options.GetNetAddress().GetPort(),
1242 options.GetNetAddress().GetSaFamily(), socketDescriptor_)) {
1243 return false;
1244 }
1245 return StartTlsConnected(options);
1246 }
1247
SetTlsConfiguration(const TLSConnectOptions & config)1248 void TLSSocket::TLSSocketInternal::SetTlsConfiguration(const TLSConnectOptions &config)
1249 {
1250 configuration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass());
1251 configuration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
1252 configuration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
1253 configuration_.SetNetAddress(config.GetNetAddress());
1254 }
1255
SendRetry(ssl_st * ssl,const char * curPos,size_t curSendSize,int sockfd)1256 bool TLSSocket::TLSSocketInternal::SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd)
1257 {
1258 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1259 for (int i = 0; i <= SEND_RETRY_TIMES; i++) {
1260 int ret = poll(fds, 1, SEND_POLL_TIMEOUT_MS);
1261 if (ret < 0) {
1262 if (errno == EAGAIN || errno == EINTR) {
1263 continue;
1264 }
1265 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1266 return false;
1267 } else if (ret == 0) {
1268 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1269 continue;
1270 }
1271 int len = SSL_write(ssl, curPos, curSendSize);
1272 if (len < 0) {
1273 int err = SSL_get_error(ssl, SSL_RET_CODE);
1274 if (err == SSL_ERROR_WANT_WRITE || errno == EAGAIN) {
1275 NETSTACK_LOGI("write retry times: %{public}d err: %{public}d errno: %{public}d", i, err, errno);
1276 continue;
1277 } else {
1278 NETSTACK_LOGE("write failed err: %{public}d errno: %{public}d", err, errno);
1279 return false;
1280 }
1281 } else if (len == 0) {
1282 NETSTACK_LOGI("send len is 0, should have sent len");
1283 return false;
1284 } else {
1285 return true;
1286 }
1287 }
1288 return false;
1289 }
1290
PollSend(int sockfd,ssl_st * ssl,const char * pdata,int sendSize)1291 bool TLSSocket::TLSSocketInternal::PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize)
1292 {
1293 int bufferSize = DEFAULT_BUFFER_SIZE;
1294 auto curPos = pdata;
1295 nfds_t num = 1;
1296 pollfd fds[1] = {{.fd = sockfd, .events = POLLOUT}};
1297 while (sendSize > 0) {
1298 int ret = poll(fds, num, DEFAULT_POLL_TIMEOUT_MS);
1299 if (ret < 0) {
1300 if (errno == EAGAIN || errno == EINTR) {
1301 continue;
1302 }
1303 NETSTACK_LOGE("send poll error, fd: %{public}d, errno: %{public}d", sockfd, errno);
1304 return false;
1305 } else if (ret == 0) {
1306 NETSTACK_LOGI("send poll timeout, fd: %{public}d, errno: %{public}d", sockfd, errno);
1307 continue;
1308 }
1309 std::lock_guard<std::mutex> lock(mutexForSsl_);
1310 if (!ssl) {
1311 NETSTACK_LOGE("ssl is null");
1312 return false;
1313 }
1314 size_t curSendSize = std::min<size_t>(sendSize, bufferSize);
1315 int len = SSL_write(ssl, curPos, curSendSize);
1316 if (len < 0) {
1317 int err = SSL_get_error(ssl, SSL_RET_CODE);
1318 if (err != SSL_ERROR_WANT_WRITE || errno != EAGAIN) {
1319 NETSTACK_LOGE("write failed, return, err: %{public}d errno: %{public}d", err, errno);
1320 return false;
1321 } else if (!SendRetry(ssl, curPos, curSendSize, sockfd)) {
1322 return false;
1323 }
1324 } else if (len == 0) {
1325 NETSTACK_LOGI("send len is 0, should have sent len is %{public}d", sendSize);
1326 return false;
1327 }
1328 curPos += len;
1329 sendSize -= len;
1330 }
1331 return true;
1332 }
1333
Send(const std::string & data)1334 bool TLSSocket::TLSSocketInternal::Send(const std::string &data)
1335 {
1336 {
1337 std::lock_guard<std::mutex> lock(mutexForSsl_);
1338 if (!ssl_) {
1339 NETSTACK_LOGE("ssl is null");
1340 return false;
1341 }
1342 }
1343
1344 if (data.empty()) {
1345 NETSTACK_LOGE("data is empty");
1346 return true;
1347 }
1348
1349 if (!PollSend(socketDescriptor_, ssl_, data.c_str(), data.size())) {
1350 return false;
1351 }
1352 return true;
1353 }
Recv(char * buffer,int maxBufferSize)1354 int TLSSocket::TLSSocketInternal::Recv(char *buffer, int maxBufferSize)
1355 {
1356 if (!ssl_) {
1357 NETSTACK_LOGE("ssl is null");
1358 return SSL_ERROR_RETURN;
1359 }
1360
1361 int ret = SSL_read(ssl_, buffer, maxBufferSize);
1362 if (ret < 0) {
1363 int err = SSL_get_error(ssl_, SSL_RET_CODE);
1364 switch (err) {
1365 case SSL_ERROR_SSL:
1366 NETSTACK_LOGE("An error occurred in the SSL library");
1367 return SSL_ERROR_RETURN;
1368 case SSL_ERROR_ZERO_RETURN:
1369 NETSTACK_LOGE("peer disconnected...");
1370 return SSL_ERROR_RETURN;
1371 case SSL_ERROR_WANT_READ:
1372 NETSTACK_LOGD("SSL_read function no data available for reading, try again at a later time");
1373 return SSL_WANT_READ_RETURN;
1374 default:
1375 NETSTACK_LOGE("SSL_read function failed, error code is %{public}d", err);
1376 return SSL_ERROR_RETURN;
1377 }
1378 }
1379 return ret;
1380 }
1381
Close()1382 bool TLSSocket::TLSSocketInternal::Close()
1383 {
1384 std::lock_guard<std::mutex> lock(mutexForSsl_);
1385 if (!ssl_) {
1386 NETSTACK_LOGE("ssl is null, fd =%{public}d", socketDescriptor_);
1387 return false;
1388 }
1389 int result = SSL_shutdown(ssl_);
1390 if (result < 0) {
1391 int resErr = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1392 NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
1393 MakeSSLErrorString(resErr).c_str());
1394 }
1395 NETSTACK_LOGI("tls socket close, fd =%{public}d", socketDescriptor_);
1396 SSL_free(ssl_);
1397 ssl_ = nullptr;
1398 close(socketDescriptor_);
1399 socketDescriptor_ = -1;
1400 if (!tlsContextPointer_) {
1401 NETSTACK_LOGE("Tls context pointer is null");
1402 return false;
1403 }
1404 tlsContextPointer_->CloseCtx();
1405 return true;
1406 }
1407
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)1408 bool TLSSocket::TLSSocketInternal::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
1409 {
1410 if (!ssl_) {
1411 NETSTACK_LOGE("ssl is null");
1412 return false;
1413 }
1414 size_t pos = 0;
1415 size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
1416 [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
1417 auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
1418 for (const auto &str : alpnProtocols) {
1419 len = str.length();
1420 result[pos++] = len;
1421 if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
1422 NETSTACK_LOGE("strcpy_s failed");
1423 return false;
1424 }
1425 pos += len;
1426 }
1427 result[pos] = '\0';
1428
1429 NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
1430 if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
1431 int resErr = ConvertSSLError();
1432 NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
1433 MakeSSLErrorString(resErr).c_str());
1434 return false;
1435 }
1436 return true;
1437 }
1438
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)1439 void TLSSocket::TLSSocketInternal::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
1440 {
1441 remoteInfo.SetFamily(family_);
1442 remoteInfo.SetAddress(hostName_);
1443 remoteInfo.SetPort(port_);
1444 }
1445
GetTlsConfiguration() const1446 TLSConfiguration TLSSocket::TLSSocketInternal::GetTlsConfiguration() const
1447 {
1448 return configuration_;
1449 }
1450
GetCipherSuite() const1451 std::vector<std::string> TLSSocket::TLSSocketInternal::GetCipherSuite() const
1452 {
1453 if (!ssl_) {
1454 NETSTACK_LOGE("ssl in null");
1455 return {};
1456 }
1457 STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
1458 if (!sk) {
1459 NETSTACK_LOGE("get ciphers failed");
1460 return {};
1461 }
1462 CipherSuite cipherSuite;
1463 std::vector<std::string> cipherSuiteVec;
1464 for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
1465 const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
1466 cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
1467 cipherSuiteVec.push_back(cipherSuite.cipherName_);
1468 }
1469 return cipherSuiteVec;
1470 }
1471
GetRemoteCertificate() const1472 std::string TLSSocket::TLSSocketInternal::GetRemoteCertificate() const
1473 {
1474 return remoteCert_;
1475 }
1476
GetCertificate() const1477 const X509CertRawData &TLSSocket::TLSSocketInternal::GetCertificate() const
1478 {
1479 return configuration_.GetCertificate();
1480 }
1481
GetSignatureAlgorithms() const1482 std::vector<std::string> TLSSocket::TLSSocketInternal::GetSignatureAlgorithms() const
1483 {
1484 return signatureAlgorithms_;
1485 }
1486
GetProtocol() const1487 std::string TLSSocket::TLSSocketInternal::GetProtocol() const
1488 {
1489 if (!ssl_) {
1490 NETSTACK_LOGE("ssl in null");
1491 return PROTOCOL_UNKNOW;
1492 }
1493 if (configuration_.GetProtocol() == TLS_V1_3) {
1494 return PROTOCOL_TLS_V13;
1495 }
1496 return PROTOCOL_TLS_V12;
1497 }
1498
SetSharedSigals()1499 bool TLSSocket::TLSSocketInternal::SetSharedSigals()
1500 {
1501 if (!ssl_) {
1502 NETSTACK_LOGE("ssl is null");
1503 return false;
1504 }
1505 int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
1506 if (!number) {
1507 NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
1508 return false;
1509 }
1510 for (int i = 0; i < number; i++) {
1511 int hash_nid;
1512 int sign_nid;
1513 std::string sig_with_md;
1514 SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
1515 switch (sign_nid) {
1516 case EVP_PKEY_RSA:
1517 sig_with_md = SIGN_NID_RSA;
1518 break;
1519 case EVP_PKEY_RSA_PSS:
1520 sig_with_md = SIGN_NID_RSA_PSS;
1521 break;
1522 case EVP_PKEY_DSA:
1523 sig_with_md = SIGN_NID_DSA;
1524 break;
1525 case EVP_PKEY_EC:
1526 sig_with_md = SIGN_NID_ECDSA;
1527 break;
1528 case NID_ED25519:
1529 sig_with_md = SIGN_NID_ED;
1530 break;
1531 case NID_ED448:
1532 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
1533 break;
1534 default:
1535 const char *sn = OBJ_nid2sn(sign_nid);
1536 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
1537 }
1538 const char *sn_hash = OBJ_nid2sn(hash_nid);
1539 sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
1540 signatureAlgorithms_.push_back(sig_with_md);
1541 }
1542 return true;
1543 }
1544
StartTlsConnected(const TLSConnectOptions & options)1545 bool TLSSocket::TLSSocketInternal::StartTlsConnected(const TLSConnectOptions &options)
1546 {
1547 if (!CreatTlsContext()) {
1548 NETSTACK_LOGE("failed to create tls context");
1549 return false;
1550 }
1551 if (!StartShakingHands(options)) {
1552 NETSTACK_LOGE("failed to shaking hands");
1553 return false;
1554 }
1555 return true;
1556 }
1557
CreatTlsContext()1558 bool TLSSocket::TLSSocketInternal::CreatTlsContext()
1559 {
1560 tlsContextPointer_ = TLSContext::CreateConfiguration(configuration_);
1561 if (!tlsContextPointer_) {
1562 NETSTACK_LOGE("failed to create tls context pointer");
1563 return false;
1564 }
1565
1566 std::lock_guard<std::mutex> lock(mutexForSsl_);
1567 if (!(ssl_ = tlsContextPointer_->CreateSsl())) {
1568 NETSTACK_LOGE("failed to create ssl session");
1569 return false;
1570 }
1571
1572 SSL_set_fd(ssl_, socketDescriptor_);
1573 SSL_set_connect_state(ssl_);
1574 return true;
1575 }
1576
// True when `s` begins with `prefix`; an empty prefix matches every string.
static bool StartsWith(const std::string &s, const std::string &prefix)
{
    // rfind anchored at position 0 can only report a match at the very start of s.
    return s.rfind(prefix, 0) == 0;
}
1581
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> dnsNames,std::vector<std::string> ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1582 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> dnsNames, std::vector<std::string> ips,
1583 const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1584 {
1585 bool valid = false;
1586 std::string reason = UNKNOW_REASON;
1587 int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1588 if (IsIP(hostName)) {
1589 auto it = find(ips.begin(), ips.end(), hostName);
1590 if (it == ips.end()) {
1591 reason = IP + hostName + " is not in the cert's list";
1592 }
1593 result = {valid, reason};
1594 return;
1595 }
1596 std::string tempHostName = "" + hostName;
1597 if (!dnsNames.empty() || index > 0) {
1598 std::vector<std::string> hostParts = SplitHostName(tempHostName);
1599 if (!dnsNames.empty()) {
1600 valid = SeekIntersection(hostParts, dnsNames);
1601 if (!valid) {
1602 reason = HOST_NAME + tempHostName + ". is not in the cert's altnames";
1603 }
1604 } else {
1605 char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1606 X509_NAME *pSubName = nullptr;
1607 int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1608 if (len > 0) {
1609 std::vector<std::string> commonNameVec;
1610 commonNameVec.emplace_back(commonNameBuf);
1611 valid = SeekIntersection(hostParts, commonNameVec);
1612 if (!valid) {
1613 reason = HOST_NAME + tempHostName + ". is not cert's CN";
1614 }
1615 }
1616 }
1617 result = {valid, reason};
1618 return;
1619 }
1620 reason = "Cert does not contain a DNS name";
1621 result = {valid, reason};
1622 }
1623
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1624 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName,
1625 const X509 *x509Certificates)
1626 {
1627 X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1628 if (!subjectName) {
1629 return "subject name is null";
1630 }
1631 char subNameBuf[BUF_SIZE] = {0};
1632 X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1633
1634 int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1635 if (index < 0) {
1636 return "X509 get ext nid error";
1637 }
1638 X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1639 if (ext == nullptr) {
1640 return "X509 get ext error";
1641 }
1642 ASN1_OBJECT *obj = nullptr;
1643 obj = X509_EXTENSION_get_object(ext);
1644 char subAltNameBuf[BUF_SIZE] = {0};
1645 OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1646 NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1647
1648 return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1649 }
1650
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1651 std::string TLSSocket::TLSSocketInternal::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1652 const X509 *x509Certificates)
1653 {
1654 ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1655 if (!extData) {
1656 NETSTACK_LOGE("extData is nullptr");
1657 return "";
1658 }
1659
1660 std::string altNames = reinterpret_cast<char *>(extData->data);
1661 std::string hostname = " " + hostName;
1662 BIO *bio = BIO_new(BIO_s_file());
1663 if (!bio) {
1664 return "bio is null";
1665 }
1666 BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1667 ASN1_STRING_print(bio, extData);
1668 std::vector<std::string> dnsNames = {};
1669 std::vector<std::string> ips = {};
1670 constexpr int DNS_NAME_IDX = 4;
1671 constexpr int IP_NAME_IDX = 11;
1672 if (!altNames.empty()) {
1673 std::vector<std::string> splitAltNames;
1674 if (altNames.find('\"') != std::string::npos) {
1675 splitAltNames = SplitEscapedAltNames(altNames);
1676 } else {
1677 splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1678 }
1679 for (auto const &iter : splitAltNames) {
1680 if (StartsWith(iter, DNS)) {
1681 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1682 } else if (StartsWith(iter, IP_ADDRESS)) {
1683 ips.push_back(iter.substr(IP_NAME_IDX));
1684 }
1685 }
1686 }
1687 std::tuple<bool, std::string> result;
1688 CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1689 if (!std::get<0>(result)) {
1690 return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1691 }
1692 return HOST_NAME + hostname + ". is cert's CN";
1693 }
1694
LoadCaCertFromMemory(X509_STORE * store,const std::string & pemCerts)1695 static void LoadCaCertFromMemory(X509_STORE *store, const std::string &pemCerts)
1696 {
1697 if (!store || pemCerts.empty() || pemCerts.size() > static_cast<size_t>(INT_MAX)) {
1698 return;
1699 }
1700
1701 auto cbio = BIO_new_mem_buf(pemCerts.data(), static_cast<int>(pemCerts.size()));
1702 if (!cbio) {
1703 return;
1704 }
1705
1706 auto inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr);
1707 if (!inf) {
1708 BIO_free(cbio);
1709 return;
1710 }
1711
1712 /* add each entry from PEM file to x509_store */
1713 for (int i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); ++i) {
1714 auto itmp = sk_X509_INFO_value(inf, i);
1715 if (!itmp) {
1716 continue;
1717 }
1718 if (itmp->x509) {
1719 X509_STORE_add_cert(store, itmp->x509);
1720 }
1721 if (itmp->crl) {
1722 X509_STORE_add_crl(store, itmp->crl);
1723 }
1724 }
1725
1726 sk_X509_INFO_pop_free(inf, X509_INFO_free);
1727 BIO_free(cbio);
1728 }
1729
X509_to_PEM(X509 * cert)1730 static std::string X509_to_PEM(X509 *cert)
1731 {
1732 if (!cert) {
1733 return {};
1734 }
1735 BIO *bio = BIO_new(BIO_s_mem());
1736 if (!bio) {
1737 return {};
1738 }
1739 if (!PEM_write_bio_X509(bio, cert)) {
1740 BIO_free(bio);
1741 return {};
1742 }
1743
1744 char *data = nullptr;
1745 auto pemStringLength = BIO_get_mem_data(bio, &data);
1746 if (!data) {
1747 BIO_free(bio);
1748 return {};
1749 }
1750 std::string certificateInPEM(data, pemStringLength);
1751 BIO_free(bio);
1752 return certificateInPEM;
1753 }
1754
CacheCertificates(const std::string & hostName,SSL * ssl)1755 static void CacheCertificates(const std::string &hostName, SSL *ssl)
1756 {
1757 if (!ssl || hostName.empty()) {
1758 return;
1759 }
1760 auto certificatesStack = SSL_get_peer_cert_chain(ssl);
1761 if (!certificatesStack) {
1762 return;
1763 }
1764 auto numCertificates = sk_X509_num(certificatesStack);
1765 for (auto i = 0; i < numCertificates; ++i) {
1766 auto cert = sk_X509_value(certificatesStack, i);
1767 auto certificateInPEM = X509_to_PEM(cert);
1768 if (!certificateInPEM.empty()) {
1769 CaCertCache::GetInstance().Set(hostName, certificateInPEM);
1770 }
1771 }
1772 }
1773
LoadCachedCaCert(const std::string & hostName,SSL * ssl)1774 static void LoadCachedCaCert(const std::string &hostName, SSL *ssl)
1775 {
1776 if (!ssl) {
1777 return;
1778 }
1779 auto cachedPem = CaCertCache::GetInstance().Get(hostName);
1780 auto sslCtx = SSL_get_SSL_CTX(ssl);
1781 if (!sslCtx) {
1782 return;
1783 }
1784 auto x509Store = SSL_CTX_get_cert_store(sslCtx);
1785 if (!x509Store) {
1786 return;
1787 }
1788 for (const auto &pem : cachedPem) {
1789 LoadCaCertFromMemory(x509Store, pem);
1790 }
1791 }
1792
StartShakingHands(const TLSConnectOptions & options)1793 bool TLSSocket::TLSSocketInternal::StartShakingHands(const TLSConnectOptions &options)
1794 {
1795 {
1796 std::lock_guard<std::mutex> lock(mutexForSsl_);
1797 if (!ssl_) {
1798 NETSTACK_LOGE("ssl is null");
1799 return false;
1800 }
1801
1802 auto hostName = options.GetHostName();
1803 // indicates hostName is not ip address
1804 if (hostName != options.GetNetAddress().GetAddress()) {
1805 LoadCachedCaCert(hostName, ssl_);
1806 }
1807
1808 int result = SSL_connect(ssl_);
1809 if (result == -1) {
1810 char err[MAX_ERR_LEN] = {0};
1811 auto code = ERR_get_error();
1812 ERR_error_string_n(code, err, MAX_ERR_LEN);
1813 int errorStatus = TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl_, SSL_RET_CODE);
1814 NETSTACK_LOGE("SSLConnect fail %{public}d, error: %{public}s errno: %{public}d ERR_get_error %{public}s",
1815 errorStatus, MakeSSLErrorString(errorStatus).c_str(), errno, err);
1816 return false;
1817 }
1818
1819 // indicates hostName is not ip address
1820 if (hostName != options.GetNetAddress().GetAddress()) {
1821 CacheCertificates(hostName, ssl_);
1822 }
1823
1824 std::string list = SSL_get_cipher_list(ssl_, 0);
1825 NETSTACK_LOGI("cipher_list: %{public}s, Version: %{public}s, Cipher: %{public}s", list.c_str(),
1826 SSL_get_version(ssl_), SSL_get_cipher(ssl_));
1827 configuration_.SetCipherSuite(list);
1828 }
1829 if (!SetSharedSigals()) {
1830 NETSTACK_LOGE("Failed to set sharedSigalgs");
1831 }
1832 if (!GetRemoteCertificateFromPeer()) {
1833 NETSTACK_LOGE("Failed to get remote certificate");
1834 }
1835 if (!peerX509_) {
1836 NETSTACK_LOGE("peer x509Certificates is null");
1837 return false;
1838 }
1839 if (!SetRemoteCertRawData()) {
1840 NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1841 }
1842 CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1843 if (!checkServerIdentity) {
1844 CheckServerIdentityLegal(hostName_, peerX509_);
1845 } else {
1846 checkServerIdentity(hostName_, {remoteCert_});
1847 }
1848 return true;
1849 }
1850
GetRemoteCertificateFromPeer()1851 bool TLSSocket::TLSSocketInternal::GetRemoteCertificateFromPeer()
1852 {
1853 peerX509_ = SSL_get_peer_certificate(ssl_);
1854 if (peerX509_ == nullptr) {
1855 int resErr = ConvertSSLError();
1856 NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1857 MakeSSLErrorString(resErr).c_str());
1858 return false;
1859 }
1860 BIO *bio = BIO_new(BIO_s_mem());
1861 if (!bio) {
1862 NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1863 return false;
1864 }
1865 X509_print(bio, peerX509_);
1866 char data[REMOTE_CERT_LEN] = {0};
1867 if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1868 NETSTACK_LOGE("BIO_read function returns error");
1869 BIO_free(bio);
1870 return false;
1871 }
1872 BIO_free(bio);
1873 remoteCert_ = std::string(data);
1874 return true;
1875 }
1876
SetRemoteCertRawData()1877 bool TLSSocket::TLSSocketInternal::SetRemoteCertRawData()
1878 {
1879 if (peerX509_ == nullptr) {
1880 NETSTACK_LOGE("peerX509 is null");
1881 return false;
1882 }
1883 int32_t length = i2d_X509(peerX509_, nullptr);
1884 if (length <= 0) {
1885 NETSTACK_LOGE("Failed to convert peerX509 to der format");
1886 return false;
1887 }
1888 unsigned char *der = nullptr;
1889 (void)i2d_X509(peerX509_, &der);
1890 SecureData data(der, length);
1891 remoteRawData_.data = data;
1892 OPENSSL_free(der);
1893 remoteRawData_.encodingFormat = DER;
1894 return true;
1895 }
1896
GetRemoteCertRawData() const1897 const X509CertRawData &TLSSocket::TLSSocketInternal::GetRemoteCertRawData() const
1898 {
1899 return remoteRawData_;
1900 }
1901
GetSSL()1902 ssl_st *TLSSocket::TLSSocketInternal::GetSSL()
1903 {
1904 std::lock_guard<std::mutex> lock(mutexForSsl_);
1905 return ssl_;
1906 }
1907 } // namespace TlsSocket
1908 } // namespace NetStack
1909 } // namespace OHOS
1910