1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tls_socket_server.h"
17 
18 #include <chrono>
19 #include <memory>
20 #include <netinet/tcp.h>
21 #include <numeric>
22 #include <openssl/err.h>
23 #include <openssl/ssl.h>
24 
25 #include <regex>
26 #include <securec.h>
27 #include <sys/ioctl.h>
28 
29 #include "base_context.h"
30 #include "netstack_common_utils.h"
31 #include "netstack_log.h"
32 #include "tls.h"
33 
34 namespace OHOS {
35 namespace NetStack {
36 namespace TlsSocketServer {
37 #if UNITTEST
38 #else
39 namespace {
40 #endif // UNITTEST
// Size of the buffer MakeSSLErrorString() hands to ERR_error_string_n().
constexpr size_t MAX_ERR_LENGTH = 1024;

// Passed as the `ret` argument of SSL_get_error() in ConvertSSLError().
constexpr int SSL_RET_CODE = 0;

// NOTE(review): BUF_SIZE / POLL_WAIT_TIME are not referenced in the section reviewed here;
// POLL_WAIT_TIME is presumably a poll() timeout in milliseconds — confirm.
constexpr int BUF_SIZE = 2048;
constexpr int POLL_WAIT_TIME = 2000;
// Number of characters skipped after a ", " separator in SplitEscapedAltNames().
constexpr int OFFSET = 2;
// Returned by Connection::Recv() when there is no SSL object.
constexpr int SSL_ERROR_RETURN = -1;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
// Backlog argument passed to listen() in ExecAccept() (identifier spelling kept for callers).
constexpr int LISETEN_COUNT = 516;
// Label separator used by SplitHostName().
constexpr const char *SPLIT_HOST_NAME = ".";
// NOTE(review): the string constants below are not referenced in the section reviewed here;
// by their names they presumably serve certificate-name checks and signature-algorithm
// formatting elsewhere in this file — confirm.
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *DNS = "DNS:";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *SIGN_NID_UNDEF = "UNDEF";
constexpr const char *OPERATOR_PLUS_SIGN = "+";
constexpr const char *UNKNOW_REASON = "Unknown reason";
constexpr const char *IP = "IP: ";
// NOTE(review): presumably a thread/task name for the accept/read loop — confirm.
static constexpr const char *TLS_SOCKET_SERVER_READ = "OS_NET_TSAccRD";
// JSON string-literal grammar, referenced by the original SplitEscapedAltNames().
// NOTE(review): the literal still contains JavaScript regex delimiters ("/^" ... "/"), so it
// requires a leading '/' and can never match a string that starts with '"'.
const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"};
// Dotted-quad IPv4 address (each octet 0-255); used by IsIP().
const std::regex PATTERN{
    "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|"
    "2[0-4][0-9]|[01]?[0-9][0-9]?)"};
// Returned by TLSSocketServer::GetConnectionClientCount(). This is one namespace-scope int
// shared by every server instance; NOTE(review): no synchronization is visible here.
int g_userCounter = 0;
75 
IsIP(const std::string & ip)76 bool IsIP(const std::string &ip)
77 {
78     std::regex pattern(PATTERN);
79     std::smatch res;
80     return regex_match(ip, res, pattern);
81 }
82 
SplitHostName(std::string & hostName)83 std::vector<std::string> SplitHostName(std::string &hostName)
84 {
85     transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower);
86     return CommonUtils::Split(hostName, SPLIT_HOST_NAME);
87 }
88 
/**
 * Reports whether the two vectors share at least one element.
 *
 * std::set_intersection is only correct on sorted ranges, and callers do not sort
 * their inputs (label lists keep host-name order). Membership is therefore checked
 * directly, which is correct for any ordering. The inputs are tiny label lists, so a
 * linear search is adequate. Neither vector is modified.
 *
 * @return true when some element of vecA also appears in vecB.
 */
bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB)
{
    return std::any_of(vecA.begin(), vecA.end(), [&vecB](const std::string &item) {
        return std::find(vecB.begin(), vecB.end(), item) != vecB.end();
    });
}
95 
ConvertErrno()96 int ConvertErrno()
97 {
98     return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno;
99 }
100 
ConvertSSLError(ssl_st * ssl)101 int ConvertSSLError(ssl_st *ssl)
102 {
103     if (!ssl) {
104         return TlsSocket::TLS_ERR_SSL_NULL;
105     }
106     return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE);
107 }
108 
/**
 * Returns strerror()'s description of the calling thread's current errno.
 */
std::string MakeErrnoString()
{
    const char *description = strerror(errno);
    return std::string(description);
}
113 
MakeSSLErrorString(int error)114 std::string MakeSSLErrorString(int error)
115 {
116     char err[MAX_ERR_LENGTH] = {0};
117     ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err));
118     return err;
119 }
SplitEscapedAltNames(std::string & altNames)120 std::vector<std::string> SplitEscapedAltNames(std::string &altNames)
121 {
122     std::vector<std::string> result;
123     std::string currentToken;
124     size_t offset = 0;
125     while (offset != altNames.length()) {
126         auto nextSep = altNames.find_first_of(", ");
127         auto nextQuote = altNames.find_first_of('\"');
128         if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) {
129             currentToken += altNames.substr(offset, nextQuote);
130             std::regex jsonStringPattern(JSON_STRING_PATTERN);
131             std::smatch match;
132             std::string altNameSubStr = altNames.substr(nextQuote);
133             bool ret = regex_match(altNameSubStr, match, jsonStringPattern);
134             if (!ret) {
135                 return {""};
136             }
137             currentToken += result[0];
138             offset = nextQuote + result[0].length();
139         } else if (nextSep != std::string::npos) {
140             currentToken += altNames.substr(offset, nextSep);
141             result.push_back(currentToken);
142             currentToken = "";
143             offset = nextSep + OFFSET;
144         } else {
145             currentToken += altNames.substr(offset);
146             offset = altNames.length();
147         }
148     }
149     result.push_back(currentToken);
150     return result;
151 }
152 #if UNITTEST
153 #else
154 } // namespace
155 #endif
156 
SetSocket(const int & socketFd)157 void TLSServerSendOptions::SetSocket(const int &socketFd)
158 {
159     socketFd_ = socketFd;
160 }
161 
SetSendData(const std::string & data)162 void TLSServerSendOptions::SetSendData(const std::string &data)
163 {
164     data_ = data;
165 }
166 
GetSocket() const167 const int &TLSServerSendOptions::GetSocket() const
168 {
169     return socketFd_;
170 }
171 
GetSendData() const172 const std::string &TLSServerSendOptions::GetSendData() const
173 {
174     return data_;
175 }
176 
~TLSSocketServer()177 TLSSocketServer::~TLSSocketServer()
178 {
179     isRunning_ = false;
180     clientIdConnections_.clear();
181 
182     if (listenSocketFd_ != -1) {
183         shutdown(listenSocketFd_, SHUT_RDWR);
184         close(listenSocketFd_);
185         listenSocketFd_ = -1;
186     }
187 }
188 
Listen(const TlsSocket::TLSConnectOptions & tlsListenOptions,const ListenCallback & callback)189 void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback)
190 {
191     if (!CommonUtils::HasInternetPermission()) {
192         CallListenCallback(PERMISSION_DENIED_CODE, callback);
193         return;
194     }
195     NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_);
196     if (listenSocketFd_ >= 0) {
197         CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
198         return;
199     }
200     NETSTACK_LOGE("Listen 2 %{public}d", listenSocketFd_);
201     if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) {
202         NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_);
203         ExecAccept(tlsListenOptions, callback);
204     } else {
205         shutdown(listenSocketFd_, SHUT_RDWR);
206         close(listenSocketFd_);
207         listenSocketFd_ = -1;
208     }
209 
210     PollThread(tlsListenOptions);
211 }
212 
ExecBind(const Socket::NetAddress & address,const ListenCallback & callback)213 bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback)
214 {
215     MakeIpSocket(address.GetSaFamily());
216     if (listenSocketFd_ < 0) {
217         int resErr = ConvertErrno();
218         NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str());
219         CallOnErrorCallback(resErr, MakeErrnoString());
220         CallListenCallback(resErr, callback);
221         return false;
222     }
223     sockaddr_in addr4 = {0};
224     sockaddr_in6 addr6 = {0};
225     sockaddr *addr = nullptr;
226     socklen_t len;
227     GetAddr(address, &addr4, &addr6, &addr, &len);
228     if (addr == nullptr) {
229         NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid");
230         CallOnErrorCallback(-1, "Address Is Invalid");
231         CallListenCallback(ConvertErrno(), callback);
232         return false;
233     }
234     if (bind(listenSocketFd_, addr, len) < 0) {
235         if (errno != EADDRINUSE) {
236             NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno);
237             CallOnErrorCallback(-1, "Address binding failed");
238             CallListenCallback(ConvertErrno(), callback);
239             return false;
240         }
241         if (addr->sa_family == AF_INET) {
242             NETSTACK_LOGI("distribute a random port");
243             addr4.sin_port = 0; /* distribute a random port */
244         } else if (addr->sa_family == AF_INET6) {
245             NETSTACK_LOGI("distribute a random port");
246             addr6.sin6_port = 0; /* distribute a random port */
247         }
248         if (bind(listenSocketFd_, addr, len) < 0) {
249             NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno);
250             CallOnErrorCallback(-1, "Duplicate binding address failed");
251             CallListenCallback(ConvertErrno(), callback);
252             return false;
253         }
254         NETSTACK_LOGI("rebind success");
255     }
256     NETSTACK_LOGI("bind success");
257     address_ = address;
258     return true;
259 }
260 
ExecAccept(const TlsSocket::TLSConnectOptions & tlsAcceptOptions,const ListenCallback & callback)261 void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback)
262 {
263     if (listenSocketFd_ < 0) {
264         int resErr = ConvertErrno();
265         NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno);
266         CallOnErrorCallback(resErr, MakeErrnoString());
267         callback(resErr);
268         return;
269     }
270     SetLocalTlsConfiguration(tlsAcceptOptions);
271     int ret = 0;
272 
273     NETSTACK_LOGE(
274         "accept error is listenSocketFd_=  %{public}d LISETEN_COUNT =%{public}d .GetVerifyMode()  = %{public}d ",
275         listenSocketFd_, LISETEN_COUNT, tlsAcceptOptions.GetTlsSecureOptions().GetVerifyMode());
276     ret = listen(listenSocketFd_, LISETEN_COUNT);
277     if (ret < 0) {
278         int resErr = ConvertErrno();
279         NETSTACK_LOGE("tcp server listen error");
280         CallOnErrorCallback(resErr, MakeErrnoString());
281         callback(resErr);
282         return;
283     }
284     CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
285 }
286 
Send(const TLSServerSendOptions & data,const TlsSocket::SendCallback & callback)287 bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback)
288 {
289     int socketFd = data.GetSocket();
290     std::string info = data.GetSendData();
291 
292     auto connect_iterator = clientIdConnections_.find(socketFd);
293     if (connect_iterator == clientIdConnections_.end()) {
294         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
295         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
296         return false;
297     }
298     auto connect = connect_iterator->second;
299     auto res = connect->Send(info);
300     if (!res) {
301         int resErr = ConvertSSLError(connect->GetSSL());
302         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
303         CallSendCallback(resErr, callback);
304         return false;
305     }
306     CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback);
307     return res;
308 }
309 
CallSendCallback(int32_t err,TlsSocket::SendCallback callback)310 void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback)
311 {
312     TlsSocket::SendCallback CallBackfunc = nullptr;
313     {
314         std::lock_guard<std::mutex> lock(mutex_);
315         if (callback) {
316             CallBackfunc = callback;
317         }
318     }
319 
320     if (CallBackfunc) {
321         CallBackfunc(err);
322     }
323 }
324 
Close(const int socketFd,const TlsSocket::CloseCallback & callback)325 void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback)
326 {
327     {
328         std::lock_guard<std::mutex> its_lock(connectMutex_);
329         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
330             if (it->first == socketFd) {
331                 auto res = it->second->Close();
332                 if (!res) {
333                     int resErr = ConvertSSLError(it->second->GetSSL());
334                     NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr);
335                     CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
336                     callback(resErr);
337                     return;
338                 }
339                 callback(TlsSocket::TLSSOCKET_SUCCESS);
340                 return;
341             } else {
342                 ++it;
343             }
344         }
345     }
346     NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd);
347     CallOnErrorCallback(-1, "The send failed with no corresponding socketFd");
348     callback(TlsSocket::TLS_ERR_SYS_EINVAL);
349 }
350 
Stop(const TlsSocket::CloseCallback & callback)351 void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback)
352 {
353     std::lock_guard<std::mutex> its_lock(connectMutex_);
354     for (const auto &c : clientIdConnections_) {
355         c.second->Close();
356     }
357     clientIdConnections_.clear();
358     close(listenSocketFd_);
359     listenSocketFd_ = -1;
360     callback(TlsSocket::TLSSOCKET_SUCCESS);
361 }
362 
GetRemoteAddress(const int socketFd,const TlsSocket::GetRemoteAddressCallback & callback)363 void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback)
364 {
365     auto connect_iterator = clientIdConnections_.find(socketFd);
366     if (connect_iterator == clientIdConnections_.end()) {
367         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
368         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
369         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
370         return;
371     }
372     auto connect = connect_iterator->second;
373     auto address = connect->GetAddress();
374     callback(TlsSocket::TLSSOCKET_SUCCESS, address);
375 }
376 
GetLocalAddress(const int socketFd,const TlsSocket::GetLocalAddressCallback & callback)377 void TLSSocketServer::GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback)
378 {
379     auto connect_iterator = clientIdConnections_.find(socketFd);
380     if (connect_iterator == clientIdConnections_.end()) {
381         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
382         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
383         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
384         return;
385     }
386     auto connect = connect_iterator->second;
387     auto localAddress = connect->GetLocalAddress();
388     callback(TlsSocket::TLSSOCKET_SUCCESS, localAddress);
389 }
390 
GetState(const TlsSocket::GetStateCallback & callback)391 void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback)
392 {
393     int opt;
394     socklen_t optLen = sizeof(int);
395     int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen);
396     if (r < 0) {
397         Socket::SocketStateBase state;
398         state.SetIsClose(true);
399         CallGetStateCallback(ConvertErrno(), state, callback);
400         return;
401     }
402     sockaddr sockAddr = {0};
403     socklen_t len = sizeof(sockaddr);
404     Socket::SocketStateBase state;
405     int ret = getsockname(listenSocketFd_, &sockAddr, &len);
406     state.SetIsBound(ret == 0);
407     ret = getpeername(listenSocketFd_, &sockAddr, &len);
408     if (ret != 0) {
409         NETSTACK_LOGE("getpeername failed");
410     }
411     state.SetIsConnected(GetConnectionClientCount() > 0);
412     CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback);
413 }
414 
CallGetStateCallback(int32_t err,const Socket::SocketStateBase & state,TlsSocket::GetStateCallback callback)415 void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state,
416                                            TlsSocket::GetStateCallback callback)
417 {
418     TlsSocket::GetStateCallback CallBackfunc = nullptr;
419     {
420         std::lock_guard<std::mutex> lock(mutex_);
421         if (callback) {
422             CallBackfunc = callback;
423         }
424     }
425 
426     if (CallBackfunc) {
427         CallBackfunc(err, state);
428     }
429 }
SetExtraOptions(const Socket::TCPExtraOptions & tcpExtraOptions,const TlsSocket::SetExtraOptionsCallback & callback)430 bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
431                                       const TlsSocket::SetExtraOptionsCallback &callback)
432 {
433     if (tcpExtraOptions.IsKeepAlive()) {
434         int keepalive = 1;
435         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) {
436             return false;
437         }
438     }
439 
440     if (tcpExtraOptions.IsOOBInline()) {
441         int oobInline = 1;
442         if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) {
443             return false;
444         }
445     }
446 
447     if (tcpExtraOptions.IsTCPNoDelay()) {
448         int tcpNoDelay = 1;
449         if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) {
450             return false;
451         }
452     }
453 
454     linger soLinger = {0};
455     soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn();
456     soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger();
457     if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) {
458         return false;
459     }
460 
461     return true;
462 }
463 
SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions & config)464 void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
465 {
466     TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
467                                           config.GetTlsSecureOptions().GetKeyPass());
468     TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
469     TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
470 
471     TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode());
472 
473     const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain();
474     if (!protocolVec.empty()) {
475         TLSServerConfiguration_.SetProtocol(protocolVec);
476     }
477 }
478 
GetCertificate(const TlsSocket::GetCertificateCallback & callback)479 void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback)
480 {
481     const auto &cert = TLSServerConfiguration_.GetCertificate();
482     NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat);
483     if (!cert.data.Length()) {
484         CallOnErrorCallback(-1, "cert not data Length");
485         callback(-1, {});
486         return;
487     }
488     callback(TlsSocket::TLSSOCKET_SUCCESS, cert);
489 }
490 
GetRemoteCertificate(const int socketFd,const TlsSocket::GetRemoteCertificateCallback & callback)491 void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback)
492 {
493     auto connect_iterator = clientIdConnections_.find(socketFd);
494     if (connect_iterator == clientIdConnections_.end()) {
495         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
496         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
497         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
498         return;
499     }
500     auto connect = connect_iterator->second;
501     const auto &remoteCert = connect->GetRemoteCertRawData();
502     if (!remoteCert.data.Length()) {
503         int resErr = ConvertSSLError(connect->GetSSL());
504         NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str());
505         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
506         callback(resErr, {});
507         return;
508     }
509     callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert);
510 }
511 
GetProtocol(const TlsSocket::GetProtocolCallback & callback)512 void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback)
513 {
514     if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
515         callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13);
516         return;
517     }
518     callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12);
519 }
520 
GetCipherSuite(const int socketFd,const TlsSocket::GetCipherSuiteCallback & callback)521 void TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback)
522 {
523     auto connect_iterator = clientIdConnections_.find(socketFd);
524     if (connect_iterator == clientIdConnections_.end()) {
525         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
526         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
527         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
528         return;
529     }
530     auto connect = connect_iterator->second;
531     auto cipherSuite = connect->GetCipherSuite();
532     if (cipherSuite.empty()) {
533         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
534         int resErr = ConvertSSLError(connect->GetSSL());
535         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
536         callback(resErr, cipherSuite);
537         return;
538     }
539     callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite);
540 }
541 
GetSignatureAlgorithms(const int socketFd,const TlsSocket::GetSignatureAlgorithmsCallback & callback)542 void TLSSocketServer::GetSignatureAlgorithms(const int socketFd,
543                                              const TlsSocket::GetSignatureAlgorithmsCallback &callback)
544 {
545     auto connect_iterator = clientIdConnections_.find(socketFd);
546     if (connect_iterator == clientIdConnections_.end()) {
547         NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd);
548         CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd");
549         callback(TlsSocket::TLS_ERR_SYS_EINVAL, {});
550         return;
551     }
552     auto connect = connect_iterator->second;
553     auto signatureAlgorithms = connect->GetSignatureAlgorithms();
554     if (signatureAlgorithms.empty()) {
555         NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno);
556         int resErr = ConvertSSLError(connect->GetSSL());
557         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
558         callback(resErr, signatureAlgorithms);
559         return;
560     }
561     callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms);
562 }
563 
OnMessage(const OnMessageCallback & onMessageCallback)564 void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback)
565 {
566     onMessageCallback_ = onMessageCallback;
567 }
568 
OnClose(const OnCloseCallback & onCloseCallback)569 void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback)
570 {
571     onCloseCallback_ = onCloseCallback;
572 }
573 
OnConnect(const OnConnectCallback & onConnectCallback)574 void TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback)
575 {
576     std::lock_guard<std::mutex> lock(mutex_);
577     onConnectCallback_ = onConnectCallback;
578 }
579 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)580 void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
581 {
582     std::lock_guard<std::mutex> lock(mutex_);
583     onErrorCallback_ = onErrorCallback;
584 }
585 
OffMessage()586 void TLSSocketServer::Connection::OffMessage()
587 {
588     if (onMessageCallback_) {
589         onMessageCallback_ = nullptr;
590     }
591 }
592 
OffConnect()593 void TLSSocketServer::OffConnect()
594 {
595     std::lock_guard<std::mutex> lock(mutex_);
596     if (onConnectCallback_) {
597         onConnectCallback_ = nullptr;
598     }
599 }
600 
OnError(const TlsSocket::OnErrorCallback & onErrorCallback)601 void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback)
602 {
603     onErrorCallback_ = onErrorCallback;
604 }
605 
OffClose()606 void TLSSocketServer::Connection::OffClose()
607 {
608     if (onCloseCallback_) {
609         onCloseCallback_ = nullptr;
610     }
611 }
612 
OffError()613 void TLSSocketServer::Connection::OffError()
614 {
615     onErrorCallback_ = nullptr;
616 }
617 
CallOnErrorCallback(int32_t err,const std::string & errString)618 void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString)
619 {
620     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
621     {
622         if (onErrorCallback_) {
623             CallBackfunc = onErrorCallback_;
624         }
625     }
626 
627     if (CallBackfunc) {
628         CallBackfunc(err, errString);
629     }
630 }
OffError()631 void TLSSocketServer::OffError()
632 {
633     std::lock_guard<std::mutex> lock(mutex_);
634     if (onErrorCallback_) {
635         onErrorCallback_ = nullptr;
636     }
637 }
638 
MakeIpSocket(sa_family_t family)639 void TLSSocketServer::MakeIpSocket(sa_family_t family)
640 {
641     if (family != AF_INET && family != AF_INET6) {
642         return;
643     }
644     int sock = socket(family, SOCK_STREAM, IPPROTO_IP);
645     if (sock < 0) {
646         int resErr = ConvertErrno();
647         NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str());
648         CallOnErrorCallback(resErr, MakeErrnoString());
649         return;
650     }
651     listenSocketFd_ = sock;
652 }
653 
CallOnErrorCallback(int32_t err,const std::string & errString)654 void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString)
655 {
656     TlsSocket::OnErrorCallback CallBackfunc = nullptr;
657     {
658         std::lock_guard<std::mutex> lock(mutex_);
659         if (onErrorCallback_) {
660             CallBackfunc = onErrorCallback_;
661         }
662     }
663 
664     if (CallBackfunc) {
665         CallBackfunc(err, errString);
666     }
667 }
GetAddr(const Socket::NetAddress & address,sockaddr_in * addr4,sockaddr_in6 * addr6,sockaddr ** addr,socklen_t * len)668 void TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6,
669                               sockaddr **addr, socklen_t *len)
670 {
671     if (!addr6 || !addr4 || !len) {
672         return;
673     }
674     sa_family_t family = address.GetSaFamily();
675     if (family == AF_INET) {
676         addr4->sin_family = AF_INET;
677         addr4->sin_port = htons(address.GetPort());
678         addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
679         *addr = reinterpret_cast<sockaddr *>(addr4);
680         *len = sizeof(sockaddr_in);
681     } else if (family == AF_INET6) {
682         addr6->sin6_family = AF_INET6;
683         addr6->sin6_port = htons(address.GetPort());
684         inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr);
685         *addr = reinterpret_cast<sockaddr *>(addr6);
686         *len = sizeof(sockaddr_in6);
687     }
688 }
689 
GetListenSocketFd()690 int TLSSocketServer::GetListenSocketFd()
691 {
692     return listenSocketFd_;
693 }
694 
SetLocalAddress(const Socket::NetAddress & address)695 void TLSSocketServer::SetLocalAddress(const Socket::NetAddress &address)
696 {
697     localAddress_ = address;
698 }
699 
GetLocalAddress()700 Socket::NetAddress TLSSocketServer::GetLocalAddress()
701 {
702     return localAddress_;
703 }
704 
GetConnectionByClientID(int clientid)705 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid)
706 {
707     std::shared_ptr<Connection> ptrConnection = nullptr;
708 
709     auto it = clientIdConnections_.find(clientid);
710     if (it != clientIdConnections_.end()) {
711         ptrConnection = it->second;
712     }
713 
714     return ptrConnection;
715 }
716 
GetConnectionClientCount()717 int TLSSocketServer::GetConnectionClientCount()
718 {
719     return g_userCounter;
720 }
721 
CallListenCallback(int32_t err,ListenCallback callback)722 void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback)
723 {
724     ListenCallback CallBackfunc = nullptr;
725     {
726         std::lock_guard<std::mutex> lock(mutex_);
727         if (callback) {
728             CallBackfunc = callback;
729         }
730     }
731 
732     if (CallBackfunc) {
733         CallBackfunc(err);
734     }
735 }
736 
SetAddress(const Socket::NetAddress address)737 void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address)
738 {
739     address_ = address;
740 }
741 
SetLocalAddress(const Socket::NetAddress address)742 void TLSSocketServer::Connection::SetLocalAddress(const Socket::NetAddress address)
743 {
744     localAddress_ = address;
745 }
746 
GetRemoteCertRawData() const747 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const
748 {
749     return remoteRawData_;
750 }
751 
~Connection()752 TLSSocketServer::Connection::~Connection()
753 {
754     Close();
755 }
756 
TlsAcceptToHost(int sock,const TlsSocket::TLSConnectOptions & options)757 bool TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options)
758 {
759     SetTlsConfiguration(options);
760     std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite();
761     if (!cipherSuite.empty()) {
762         connectionConfiguration_.SetCipherSuite(cipherSuite);
763     }
764     std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms();
765     if (!signatureAlgorithms.empty()) {
766         connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms);
767     }
768     const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain();
769     if (!protocolVec.empty()) {
770         connectionConfiguration_.SetProtocol(protocolVec);
771     }
772     connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode());
773     socketFd_ = sock;
774     return StartTlsAccept(options);
775 }
776 
SetTlsConfiguration(const TlsSocket::TLSConnectOptions & config)777 void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config)
778 {
779     connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(),
780                                            config.GetTlsSecureOptions().GetKeyPass());
781     connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert());
782     connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain());
783     connectionConfiguration_.SetNetAddress(config.GetNetAddress());
784 }
785 
Send(const std::string & data)786 bool TLSSocketServer::Connection::Send(const std::string &data)
787 {
788     if (!ssl_) {
789         NETSTACK_LOGE("ssl is null");
790         return false;
791     }
792     if (data.empty()) {
793         NETSTACK_LOGI("data is empty");
794         return true;
795     }
796     int len = SSL_write(ssl_, data.c_str(), data.length());
797     if (len < 0) {
798         int resErr = ConvertSSLError(GetSSL());
799         NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'",
800                       data.c_str(), resErr, MakeSSLErrorString(resErr).c_str());
801         return false;
802     }
803     NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len);
804     return true;
805 }
806 
Recv(char * buffer,int maxBufferSize)807 int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize)
808 {
809     if (!ssl_) {
810         NETSTACK_LOGE("ssl is null");
811         return SSL_ERROR_RETURN;
812     }
813     return SSL_read(ssl_, buffer, maxBufferSize);
814 }
815 
Close()816 bool TLSSocketServer::Connection::Close()
817 {
818     if (!ssl_) {
819         NETSTACK_LOGE("ssl is null");
820         return false;
821     }
822     int result = SSL_shutdown(ssl_);
823     if (result < 0) {
824         int resErr = ConvertSSLError(GetSSL());
825         NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr,
826                       MakeSSLErrorString(resErr).c_str());
827     }
828     SSL_free(ssl_);
829     ssl_ = nullptr;
830     if (socketFd_ != -1) {
831         shutdown(socketFd_, SHUT_RDWR);
832         close(socketFd_);
833         socketFd_ = -1;
834     }
835     if (!tlsContextServerPointer_) {
836         NETSTACK_LOGE("Tls context pointer is null");
837         return false;
838     }
839     tlsContextServerPointer_->CloseCtx();
840     return true;
841 }
842 
SetAlpnProtocols(const std::vector<std::string> & alpnProtocols)843 bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols)
844 {
845     if (!ssl_) {
846         NETSTACK_LOGE("ssl is null");
847         return false;
848     }
849     size_t pos = 0;
850     size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0),
851                                  [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); });
852     auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len);
853     for (const auto &str : alpnProtocols) {
854         len = str.length();
855         result[pos++] = len;
856         if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) {
857             NETSTACK_LOGE("strcpy_s failed");
858             return false;
859         }
860         pos += len;
861     }
862     result[pos] = '\0';
863 
864     NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get());
865     if (SSL_set_alpn_protos(ssl_, result.get(), pos)) {
866         int resErr = ConvertSSLError(GetSSL());
867         NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr,
868                       MakeSSLErrorString(resErr).c_str());
869         return false;
870     }
871     return true;
872 }
873 
MakeRemoteInfo(Socket::SocketRemoteInfo & remoteInfo)874 void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo)
875 {
876     remoteInfo.SetAddress(address_.GetAddress());
877     remoteInfo.SetPort(address_.GetPort());
878     remoteInfo.SetFamily(address_.GetSaFamily());
879 }
880 
GetTlsConfiguration() const881 TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const
882 {
883     return connectionConfiguration_;
884 }
885 
GetCipherSuite() const886 std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const
887 {
888     if (!ssl_) {
889         NETSTACK_LOGE("ssl in null");
890         return {};
891     }
892     STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_);
893     if (!sk) {
894         NETSTACK_LOGE("get ciphers failed");
895         return {};
896     }
897     TlsSocket::CipherSuite cipherSuite;
898     std::vector<std::string> cipherSuiteVec;
899     for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
900         const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i);
901         cipherSuite.cipherName_ = SSL_CIPHER_get_name(c);
902         cipherSuiteVec.push_back(cipherSuite.cipherName_);
903     }
904     return cipherSuiteVec;
905 }
906 
GetRemoteCertificate() const907 std::string TLSSocketServer::Connection::GetRemoteCertificate() const
908 {
909     return remoteCert_;
910 }
911 
GetCertificate() const912 const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const
913 {
914     return connectionConfiguration_.GetCertificate();
915 }
916 
GetSignatureAlgorithms() const917 std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const
918 {
919     return signatureAlgorithms_;
920 }
921 
GetProtocol() const922 std::string TLSSocketServer::Connection::GetProtocol() const
923 {
924     if (!ssl_) {
925         NETSTACK_LOGE("ssl in null");
926         return PROTOCOL_UNKNOW;
927     }
928     if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) {
929         return TlsSocket::PROTOCOL_TLS_V13;
930     }
931     return TlsSocket::PROTOCOL_TLS_V12;
932 }
933 
SetSharedSigals()934 bool TLSSocketServer::Connection::SetSharedSigals()
935 {
936     if (!ssl_) {
937         NETSTACK_LOGE("ssl is null");
938         return false;
939     }
940     int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr);
941     if (!number) {
942         NETSTACK_LOGE("SSL_get_shared_sigalgs return value error");
943         return false;
944     }
945     for (int i = 0; i < number; i++) {
946         int hash_nid;
947         int sign_nid;
948         std::string sig_with_md;
949         SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr);
950         switch (sign_nid) {
951             case EVP_PKEY_RSA:
952                 sig_with_md = SIGN_NID_RSA;
953                 break;
954             case EVP_PKEY_RSA_PSS:
955                 sig_with_md = SIGN_NID_RSA_PSS;
956                 break;
957             case EVP_PKEY_DSA:
958                 sig_with_md = SIGN_NID_DSA;
959                 break;
960             case EVP_PKEY_EC:
961                 sig_with_md = SIGN_NID_ECDSA;
962                 break;
963             case NID_ED25519:
964                 sig_with_md = SIGN_NID_ED;
965                 break;
966             case NID_ED448:
967                 sig_with_md = SIGN_NID_ED_FOUR_FOUR_EIGHT;
968                 break;
969             default:
970                 const char *sn = OBJ_nid2sn(sign_nid);
971                 sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD;
972         }
973         const char *sn_hash = OBJ_nid2sn(hash_nid);
974         sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF;
975         signatureAlgorithms_.push_back(sig_with_md);
976     }
977     return true;
978 }
979 
GetSSL() const980 ssl_st *TLSSocketServer::Connection::GetSSL() const
981 {
982     return ssl_;
983 }
984 
GetAddress() const985 Socket::NetAddress TLSSocketServer::Connection::GetAddress() const
986 {
987     return address_;
988 }
989 
GetLocalAddress() const990 Socket::NetAddress TLSSocketServer::Connection::GetLocalAddress() const
991 {
992     return localAddress_;
993 }
994 
GetSocketFd() const995 int TLSSocketServer::Connection::GetSocketFd() const
996 {
997     return socketFd_;
998 }
999 
GetEventManager() const1000 std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const
1001 {
1002     return eventManager_;
1003 }
1004 
SetEventManager(std::shared_ptr<EventManager> eventManager)1005 void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager)
1006 {
1007     eventManager_ = eventManager;
1008 }
1009 
SetClientID(int32_t clientID)1010 void TLSSocketServer::Connection::SetClientID(int32_t clientID)
1011 {
1012     clientID_ = clientID;
1013 }
1014 
GetClientID()1015 int TLSSocketServer::Connection::GetClientID()
1016 {
1017     return clientID_;
1018 }
1019 
StartTlsAccept(const TlsSocket::TLSConnectOptions & options)1020 bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options)
1021 {
1022     if (!CreatTlsContext()) {
1023         NETSTACK_LOGE("failed to create tls context");
1024         return false;
1025     }
1026     if (!StartShakingHands(options)) {
1027         NETSTACK_LOGE("failed to shaking hands");
1028         return false;
1029     }
1030     return true;
1031 }
1032 
CreatTlsContext()1033 bool TLSSocketServer::Connection::CreatTlsContext()
1034 {
1035     tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_);
1036     if (!tlsContextServerPointer_) {
1037         NETSTACK_LOGE("failed to create tls context pointer");
1038         return false;
1039     }
1040     if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) {
1041         NETSTACK_LOGE("failed to create ssl session");
1042         return false;
1043     }
1044     SSL_set_fd(ssl_, socketFd_);
1045     SSL_set_accept_state(ssl_);
1046     return true;
1047 }
1048 
StartShakingHands(const TlsSocket::TLSConnectOptions & options)1049 bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options)
1050 {
1051     if (!ssl_) {
1052         NETSTACK_LOGE("ssl is null");
1053         return false;
1054     }
1055     int result = SSL_accept(ssl_);
1056     if (result == -1) {
1057         int errorStatus = ConvertSSLError(ssl_);
1058         NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus,
1059                       MakeSSLErrorString(errorStatus).c_str());
1060         return false;
1061     }
1062 
1063     std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)});
1064     connectionConfiguration_.SetProtocol({SslProtocolVer});
1065 
1066     std::string list = SSL_get_cipher_list(ssl_, 0);
1067     NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str());
1068     connectionConfiguration_.SetCipherSuite(list);
1069     if (!SetSharedSigals()) {
1070         NETSTACK_LOGE("Failed to set sharedSigalgs");
1071     }
1072 
1073     if (!GetRemoteCertificateFromPeer()) {
1074         NETSTACK_LOGE("Failed to get remote certificate");
1075     }
1076     if (peerX509_ != nullptr) {
1077         NETSTACK_LOGE("peer x509Certificates is null");
1078 
1079         if (!SetRemoteCertRawData()) {
1080             NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data");
1081         }
1082         TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity();
1083         if (!checkServerIdentity) {
1084             CheckServerIdentityLegal(hostName_, peerX509_);
1085         } else {
1086             checkServerIdentity(hostName_, {remoteCert_});
1087         }
1088     }
1089     return true;
1090 }
1091 
GetRemoteCertificateFromPeer()1092 bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer()
1093 {
1094     peerX509_ = SSL_get_peer_certificate(ssl_);
1095 
1096     if (SSL_get_verify_result(ssl_) == X509_V_OK) {
1097         NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK");
1098     }
1099 
1100     if (peerX509_ == nullptr) {
1101         int resErr = ConvertSSLError(GetSSL());
1102         NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr,
1103                       MakeSSLErrorString(resErr).c_str());
1104         return false;
1105     }
1106     BIO *bio = BIO_new(BIO_s_mem());
1107     if (!bio) {
1108         NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null");
1109         return false;
1110     }
1111     X509_print(bio, peerX509_);
1112     char data[REMOTE_CERT_LEN] = {0};
1113     if (!BIO_read(bio, data, REMOTE_CERT_LEN)) {
1114         NETSTACK_LOGE("BIO_read function returns error");
1115         BIO_free(bio);
1116         return false;
1117     }
1118     BIO_free(bio);
1119     remoteCert_ = std::string(data);
1120     return true;
1121 }
1122 
SetRemoteCertRawData()1123 bool TLSSocketServer::Connection::SetRemoteCertRawData()
1124 {
1125     if (peerX509_ == nullptr) {
1126         NETSTACK_LOGE("peerX509 is null");
1127         return false;
1128     }
1129     int32_t length = i2d_X509(peerX509_, nullptr);
1130     if (length <= 0) {
1131         NETSTACK_LOGE("Failed to convert peerX509 to der format");
1132         return false;
1133     }
1134     unsigned char *der = nullptr;
1135     (void)i2d_X509(peerX509_, &der);
1136     TlsSocket::SecureData data(der, length);
1137     remoteRawData_.data = data;
1138     OPENSSL_free(der);
1139     remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER;
1140     return true;
1141 }
1142 
StartsWith(const std::string & s,const std::string & prefix)1143 static bool StartsWith(const std::string &s, const std::string &prefix)
1144 {
1145     return s.size() >= prefix.size() && s.compare(0, prefix.size(), prefix) == 0;
1146 }
CheckIpAndDnsName(const std::string & hostName,std::vector<std::string> & dnsNames,std::vector<std::string> & ips,const X509 * x509Certificates,std::tuple<bool,std::string> & result)1147 void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips,
1148                        const X509 *x509Certificates, std::tuple<bool, std::string> &result)
1149 {
1150     bool valid = false;
1151     std::string reason = UNKNOW_REASON;
1152     int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1);
1153     if (IsIP(hostName)) {
1154         auto it = find(ips.begin(), ips.end(), hostName);
1155         if (it == ips.end()) {
1156             reason = IP + hostName + " is not in the cert's list";
1157         }
1158         result = {valid, reason};
1159         return;
1160     }
1161     std::string tempHostName = "" + hostName;
1162     if (!dnsNames.empty() || index > 0) {
1163         std::vector<std::string> hostParts = SplitHostName(tempHostName);
1164         std::string tmpStr = "";
1165         if (!dnsNames.empty()) {
1166             valid = SeekIntersection(hostParts, dnsNames);
1167             tmpStr = ". is not in the cert's altnames";
1168         } else {
1169             char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
1170             X509_NAME *pSubName = nullptr;
1171             int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
1172             if (len > 0) {
1173                 std::vector<std::string> commonNameVec;
1174                 commonNameVec.emplace_back(commonNameBuf);
1175                 valid = SeekIntersection(hostParts, commonNameVec);
1176                 tmpStr = ". is not cert's CN";
1177             }
1178         }
1179         if (!valid) {
1180             reason = HOST_NAME + tempHostName + tmpStr;
1181         }
1182 
1183         result = {valid, reason};
1184         return;
1185     }
1186     reason = "Cert does not contain a DNS name";
1187     result = {valid, reason};
1188 }
1189 
CheckServerIdentityLegal(const std::string & hostName,const X509 * x509Certificates)1190 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName,
1191                                                                   const X509 *x509Certificates)
1192 {
1193     X509_NAME *subjectName = X509_get_subject_name(x509Certificates);
1194     if (!subjectName) {
1195         return "subject name is null";
1196     }
1197     char subNameBuf[BUF_SIZE] = {0};
1198     X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE);
1199     int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1);
1200     if (index < 0) {
1201         return "X509 get ext nid error";
1202     }
1203     X509_EXTENSION *ext = X509_get_ext(x509Certificates, index);
1204     if (ext == nullptr) {
1205         return "X509 get ext error";
1206     }
1207     ASN1_OBJECT *obj = nullptr;
1208     obj = X509_EXTENSION_get_object(ext);
1209     char subAltNameBuf[BUF_SIZE] = {0};
1210     OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0);
1211     NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf);
1212 
1213     return CheckServerIdentityLegal(hostName, ext, x509Certificates);
1214 }
1215 
CheckServerIdentityLegal(const std::string & hostName,X509_EXTENSION * ext,const X509 * x509Certificates)1216 std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
1217                                                                   const X509 *x509Certificates)
1218 {
1219     ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
1220     if (!extData) {
1221         NETSTACK_LOGE("extData is nullptr");
1222         return "";
1223     }
1224 
1225     std::string altNames = reinterpret_cast<char *>(extData->data);
1226     std::string hostname = "" + hostName;
1227     BIO *bio = BIO_new(BIO_s_file());
1228     if (!bio) {
1229         return "bio is null";
1230     }
1231     BIO_set_fp(bio, stdout, BIO_NOCLOSE);
1232     ASN1_STRING_print(bio, extData);
1233     std::vector<std::string> dnsNames = {};
1234     std::vector<std::string> ips = {};
1235     constexpr int DNS_NAME_IDX = 4;
1236     constexpr int IP_NAME_IDX = 11;
1237     if (!altNames.empty()) {
1238         std::vector<std::string> splitAltNames;
1239         if (altNames.find('\"') != std::string::npos) {
1240             splitAltNames = SplitEscapedAltNames(altNames);
1241         } else {
1242             splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES);
1243         }
1244         for (auto const &iter : splitAltNames) {
1245             if (StartsWith(iter, DNS)) {
1246                 dnsNames.push_back(iter.substr(DNS_NAME_IDX));
1247             } else if (StartsWith(iter, IP_ADDRESS)) {
1248                 ips.push_back(iter.substr(IP_NAME_IDX));
1249             }
1250         }
1251     }
1252     std::tuple<bool, std::string> result;
1253     CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result);
1254     if (!std::get<0>(result)) {
1255         return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result);
1256     }
1257     return HOST_NAME + hostname + ". is cert's CN";
1258 }
1259 
RemoveConnect(int socketFd)1260 void TLSSocketServer::RemoveConnect(int socketFd)
1261 {
1262     std::shared_ptr<Connection> ptrConnection = nullptr;
1263     {
1264         std::lock_guard<std::mutex> its_lock(connectMutex_);
1265 
1266         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1267             if (it->second->GetSocketFd() == socketFd) {
1268                 ptrConnection = it->second;
1269                 break;
1270             } else {
1271                 ++it;
1272             }
1273         }
1274     }
1275     if (ptrConnection != nullptr) {
1276         ptrConnection->CallOnCloseCallback(static_cast<unsigned int>(socketFd));
1277         ptrConnection->Close();
1278     }
1279 }
1280 
RecvRemoteInfo(int socketFd,int index)1281 int TLSSocketServer::RecvRemoteInfo(int socketFd, int index)
1282 {
1283     {
1284         std::lock_guard<std::mutex> its_lock(connectMutex_);
1285         for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) {
1286             if (it->second->GetSocketFd() == socketFd) {
1287                 char buffer[MAX_BUFFER_SIZE];
1288                 if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) {
1289                     NETSTACK_LOGE("memcpy_s failed");
1290                     break;
1291                 }
1292                 int len = it->second->Recv(buffer, MAX_BUFFER_SIZE);
1293                 NETSTACK_LOGE("revc message is size is  %{public}d  buffer is   %{public}s ", len, buffer);
1294                 if (len > 0) {
1295                     Socket::SocketRemoteInfo remoteInfo;
1296                     remoteInfo.SetSize(strlen(buffer));
1297                     it->second->MakeRemoteInfo(remoteInfo);
1298                     it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo);
1299                     return len;
1300                 }
1301 #if defined(CROSS_PLATFORM)
1302                 if (len == 0 &&  errno == 0) {
1303                     NETSTACK_LOGI("A client left");
1304                 }
1305 #endif
1306                 break;
1307             } else {
1308                 ++it;
1309             }
1310         }
1311     }
1312     RemoveConnect(socketFd);
1313     DropFdFromPollList(index);
1314     return -1;
1315 }
1316 
CallOnMessageCallback(int32_t socketFd,const std::string & data,const Socket::SocketRemoteInfo & remoteInfo)1317 void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data,
1318                                                         const Socket::SocketRemoteInfo &remoteInfo)
1319 {
1320     OnMessageCallback CallBackfunc = nullptr;
1321     {
1322         if (onMessageCallback_) {
1323             CallBackfunc = onMessageCallback_;
1324         }
1325     }
1326 
1327     if (CallBackfunc) {
1328         while (!dataCache_->IsEmpty()) {
1329             CacheInfo cache = dataCache_->Get();
1330             CallBackfunc(socketFd, cache.data, cache.remoteInfo);
1331         }
1332         CallBackfunc(socketFd, data, remoteInfo);
1333     } else {
1334         CacheInfo cache = {data, remoteInfo};
1335         dataCache_->Set(cache);
1336     }
1337 }
1338 
AddConnect(int socketFd,std::shared_ptr<Connection> connection)1339 void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection)
1340 {
1341     std::lock_guard<std::mutex> its_lock(connectMutex_);
1342     clientIdConnections_[connection->GetClientID()] = connection;
1343 }
1344 
CallOnCloseCallback(const int32_t socketFd)1345 void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd)
1346 {
1347     OnCloseCallback CallBackfunc = nullptr;
1348     {
1349         if (onCloseCallback_) {
1350             CallBackfunc = onCloseCallback_;
1351         }
1352     }
1353 
1354     if (CallBackfunc) {
1355         CallBackfunc(socketFd);
1356     }
1357 }
1358 
CallOnConnectCallback(const int32_t socketFd,std::shared_ptr<EventManager> eventManager)1359 void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager)
1360 {
1361     OnConnectCallback CallBackfunc = nullptr;
1362     {
1363         std::lock_guard<std::mutex> lock(mutex_);
1364         if (onConnectCallback_) {
1365             CallBackfunc = onConnectCallback_;
1366         }
1367     }
1368 
1369     if (CallBackfunc) {
1370         CallBackfunc(socketFd, eventManager);
1371     } else {
1372         NETSTACK_LOGE("CallOnConnectCallback  fun === null");
1373     }
1374 }
1375 
GetTlsConnectionLocalAddress(int acceptSockFD,Socket::NetAddress & localAddress)1376 bool TLSSocketServer::GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress)
1377 {
1378     struct sockaddr_storage addr{};
1379     socklen_t addrLen = sizeof(addr);
1380     if (getsockname(acceptSockFD, (struct sockaddr *)&addr, &addrLen) < 0) {
1381         if (acceptSockFD > 0) {
1382             close(acceptSockFD);
1383             CallOnErrorCallback(errno, strerror(errno));
1384             return false;
1385         }
1386     }
1387     char ipStr[INET6_ADDRSTRLEN] = {0};
1388     if (addr.ss_family == AF_INET) {
1389         auto *addr_in = (struct sockaddr_in *)&addr;
1390         inet_ntop(AF_INET, &addr_in->sin_addr, ipStr, sizeof(ipStr));
1391         localAddress.SetFamilyBySaFamily(AF_INET);
1392         localAddress.SetRawAddress(ipStr);
1393         localAddress.SetPort(ntohs(addr_in->sin_port));
1394     } else if (addr.ss_family == AF_INET6) {
1395         auto *addr_in6 = (struct sockaddr_in6 *)&addr;
1396         inet_ntop(AF_INET6, &addr_in6->sin6_addr, ipStr, sizeof(ipStr));
1397         localAddress.SetFamilyBySaFamily(AF_INET6);
1398         localAddress.SetRawAddress(ipStr);
1399         localAddress.SetPort(ntohs(addr_in6->sin6_port));
1400     }
1401     return true;
1402 }
1403 
ProcessTcpAccept(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID)1404 void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID)
1405 {
1406 #if !defined(CROSS_PLATFORM)
1407     struct sockaddr_in clientAddress;
1408     socklen_t clientAddrLength = sizeof(clientAddress);
1409     int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength);
1410     if (connectFD < 0) {
1411         int resErr = ConvertErrno();
1412         NETSTACK_LOGE("Server accept new client ERROR");
1413         CallOnErrorCallback(resErr, MakeErrnoString());
1414         return;
1415     }
1416     NETSTACK_LOGI("Server accept new client SUCCESS");
1417     std::shared_ptr<Connection> connection = std::make_shared<Connection>();
1418     Socket::NetAddress netAddress;
1419     Socket::NetAddress localAddress;
1420     char clientIp[INET6_ADDRSTRLEN] = {0};
1421     inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN);
1422     int clientPort = ntohs(clientAddress.sin_port);
1423     netAddress.SetRawAddress(clientIp);
1424     netAddress.SetPort(clientPort);
1425     netAddress.SetFamilyBySaFamily(address_.GetSaFamily());
1426     connection->SetAddress(netAddress);
1427     if (GetTlsConnectionLocalAddress(connectFD, localAddress)) {
1428         connection->SetLocalAddress(localAddress);
1429     }
1430     SetTlsConnectionSecureOptions(tlsListenOptions, clientID, connectFD, connection);
1431 #endif
1432 }
SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions & tlsListenOptions,int clientID,int connectFD,std::shared_ptr<Connection> & connection)1433 void TLSSocketServer::SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID,
1434                                                     int connectFD, std::shared_ptr<Connection> &connection)
1435 {
1436     connection->SetClientID(clientID);
1437     auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions);
1438     if (!res) {
1439         int resErr = ConvertSSLError(connection->GetSSL());
1440         CallOnErrorCallback(resErr, MakeSSLErrorString(resErr));
1441         return;
1442     }
1443     if (g_userCounter >= USER_LIMIT) {
1444         const std::string info = "Too many users!";
1445         connection->Send(info);
1446         connection->Close();
1447         NETSTACK_LOGE("Too many users");
1448         close(connectFD);
1449         CallOnErrorCallback(-1, "Too many users");
1450         return;
1451     }
1452     g_userCounter++;
1453     fds_[g_userCounter].fd = connectFD;
1454 #if defined(CROSS_PLATFORM)
1455     fds_[g_userCounter].events = POLLIN | POLLERR;
1456 #else
1457     fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR;
1458 #endif
1459     fds_[g_userCounter].revents = 0;
1460     AddConnect(connectFD, connection);
1461     auto ptrEventManager = std::make_shared<EventManager>();
1462     EventManager::SetValid(ptrEventManager.get());
1463     ptrEventManager->SetData(this);
1464     connection->SetEventManager(ptrEventManager);
1465     CallOnConnectCallback(clientID, ptrEventManager);
1466     NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD);
1467 }
1468 
InitPollList(int & listendFd)1469 void TLSSocketServer::InitPollList(int &listendFd)
1470 {
1471     for (int i = 1; i <= USER_LIMIT; ++i) {
1472         fds_[i].fd = -1;
1473         fds_[i].events = 0;
1474     }
1475     fds_[0].fd = listendFd;
1476     fds_[0].events = POLLIN | POLLERR;
1477     fds_[0].revents = 0;
1478 }
1479 
DropFdFromPollList(int & fd_index)1480 void TLSSocketServer::DropFdFromPollList(int &fd_index)
1481 {
1482     if (g_userCounter < 0) {
1483         NETSTACK_LOGE("g_userCounter  = %{public}d", g_userCounter);
1484         return;
1485     }
1486     fds_[fd_index].fd = fds_[g_userCounter].fd;
1487 
1488     fds_[g_userCounter].fd = -1;
1489     fds_[g_userCounter].events = 0;
1490     fd_index--;
1491     g_userCounter--;
1492     NETSTACK_LOGE("CallOnConnectCallback  g_userCounter  = %{public}d", g_userCounter);
1493 }
1494 
PollThread(const TlsSocket::TLSConnectOptions & tlsListenOptions)1495 void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions)
1496 {
1497     int on = 1;
1498     isRunning_ = true;
1499     ioctl(listenSocketFd_, FIONBIO, (char *)&on);
1500     NETSTACK_LOGE("PollThread  start working %{public}d", isRunning_);
1501     std::thread thread_([this, tlsOption = tlsListenOptions]() {
1502 #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM)
1503         pthread_setname_np(TLS_SOCKET_SERVER_READ);
1504 #else
1505         pthread_setname_np(pthread_self(), TLS_SOCKET_SERVER_READ);
1506 #endif
1507         InitPollList(listenSocketFd_);
1508         int clientId = 0;
1509         while (isRunning_) {
1510             int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME);
1511             if (ret < 0) {
1512                 int resErr = ConvertErrno();
1513                 NETSTACK_LOGE("Poll ERROR");
1514                 CallOnErrorCallback(resErr, MakeErrnoString());
1515                 break;
1516             }
1517             if (ret == 0) {
1518                 continue;
1519             }
1520             for (int i = 0; i < g_userCounter + 1; ++i) {
1521                 if ((fds_[i].fd == listenSocketFd_) && (static_cast<uint16_t>(fds_[i].revents) & POLLIN)) {
1522                     ProcessTcpAccept(tlsOption, ++clientId);
1523 #if !defined(CROSS_PLATFORM)
1524                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) ||
1525                            (static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1526 #else
1527                 } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLERR)) {
1528 #endif
1529                     RemoveConnect(fds_[i].fd);
1530                     DropFdFromPollList(i);
1531                     NETSTACK_LOGI("A client left");
1532                 } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) {
1533                     RecvRemoteInfo(fds_[i].fd, i);
1534                 }
1535             }
1536         }
1537     });
1538     thread_.detach();
1539 }
1540 
GetConnectionByClientEventManager(const EventManager * eventManager)1541 std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager(
1542     const EventManager *eventManager)
1543 {
1544     std::lock_guard<std::mutex> its_lock(connectMutex_);
1545     auto it = std::find_if(clientIdConnections_.begin(), clientIdConnections_.end(), [eventManager](const auto& pair) {
1546         return pair.second->GetEventManager().get() == eventManager;
1547     });
1548     if (it == clientIdConnections_.end()) {
1549         return nullptr;
1550     }
1551     return it->second;
1552 }
1553 
CloseConnectionByEventManager(EventManager * eventManager)1554 void TLSSocketServer::CloseConnectionByEventManager(EventManager *eventManager)
1555 {
1556     std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager);
1557 
1558     if (ptrConnection != nullptr) {
1559         ptrConnection->Close();
1560     }
1561 }
1562 
DeleteConnectionByEventManager(EventManager * eventManager)1563 void TLSSocketServer::DeleteConnectionByEventManager(EventManager *eventManager)
1564 {
1565     std::lock_guard<std::mutex> its_lock(connectMutex_);
1566     for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) {
1567         if (it->second->GetEventManager().get() == eventManager) {
1568             it = clientIdConnections_.erase(it);
1569             break;
1570         }
1571     }
1572 }
1573 } // namespace TlsSocketServer
1574 } // namespace NetStack
1575 } // namespace OHOS
1576