/*
 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tls_socket_server.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstring>
#include <memory>
#include <netinet/tcp.h>
#include <numeric>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <regex>
#include <securec.h>
#include <string>
#include <sys/ioctl.h>
#include <tuple>
#include <vector>

#include "base_context.h"
#include "netstack_common_utils.h"
#include "netstack_log.h"
#include "tls.h"

namespace OHOS {
namespace NetStack {
namespace TlsSocketServer {
#if UNITTEST
#else
namespace {
#endif // UNITTEST
constexpr size_t MAX_ERR_LENGTH = 1024;
constexpr int SSL_RET_CODE = 0;
constexpr int BUF_SIZE = 2048;
constexpr int POLL_WAIT_TIME = 2000;
constexpr int OFFSET = 2;
constexpr int SSL_ERROR_RETURN = -1;
constexpr int REMOTE_CERT_LEN = 8192;
constexpr int COMMON_NAME_BUF_SIZE = 256;
constexpr int LISETEN_COUNT = 516;
constexpr const char *SPLIT_HOST_NAME = ".";
constexpr const char *SPLIT_ALT_NAMES = ",";
constexpr const char *DNS = "DNS:";
constexpr const char *HOST_NAME = "hostname: ";
constexpr const char *IP_ADDRESS = "IP Address:";
constexpr const char *SIGN_NID_RSA = "RSA+";
constexpr const char *SIGN_NID_RSA_PSS = "RSA-PSS+";
constexpr const char *SIGN_NID_DSA = "DSA+";
constexpr const char *SIGN_NID_ECDSA = "ECDSA+";
constexpr const char *SIGN_NID_ED = "Ed25519+";
constexpr const char *SIGN_NID_ED_FOUR_FOUR_EIGHT = "Ed448+";
constexpr const char *SIGN_NID_UNDEF_ADD = "UNDEF+";
constexpr const char *PROTOCOL_UNKNOW = "UNKNOW_PROTOCOL";
constexpr const char *SIGN_NID_UNDEF = "UNDEF"; constexpr const char *OPERATOR_PLUS_SIGN = "+"; constexpr const char *UNKNOW_REASON = "Unknown reason"; constexpr const char *IP = "IP: "; static constexpr const char *TLS_SOCKET_SERVER_READ = "OS_NET_TSAccRD"; const std::regex JSON_STRING_PATTERN{R"(/^"(?:[^"\\\u0000-\u001f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*"/)"}; const std::regex PATTERN{ "((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|" "2[0-4][0-9]|[01]?[0-9][0-9]?)"}; int g_userCounter = 0; bool IsIP(const std::string &ip) { std::regex pattern(PATTERN); std::smatch res; return regex_match(ip, res, pattern); } std::vector<std::string> SplitHostName(std::string &hostName) { transform(hostName.begin(), hostName.end(), hostName.begin(), ::tolower); return CommonUtils::Split(hostName, SPLIT_HOST_NAME); } bool SeekIntersection(std::vector<std::string> &vecA, std::vector<std::string> &vecB) { std::vector<std::string> result; set_intersection(vecA.begin(), vecA.end(), vecB.begin(), vecB.end(), inserter(result, result.begin())); return !result.empty(); } int ConvertErrno() { return TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE + errno; } int ConvertSSLError(ssl_st *ssl) { if (!ssl) { return TlsSocket::TLS_ERR_SSL_NULL; } return TlsSocket::TlsSocketError::TLS_ERR_SSL_BASE + SSL_get_error(ssl, SSL_RET_CODE); } std::string MakeErrnoString() { return strerror(errno); } std::string MakeSSLErrorString(int error) { char err[MAX_ERR_LENGTH] = {0}; ERR_error_string_n(error - TlsSocket::TlsSocketError::TLS_ERR_SYS_BASE, err, sizeof(err)); return err; } std::vector<std::string> SplitEscapedAltNames(std::string &altNames) { std::vector<std::string> result; std::string currentToken; size_t offset = 0; while (offset != altNames.length()) { auto nextSep = altNames.find_first_of(", "); auto nextQuote = altNames.find_first_of('\"'); if (nextQuote != std::string::npos && (nextSep != std::string::npos || nextQuote < nextSep)) { currentToken += altNames.substr(offset, nextQuote); 
std::regex jsonStringPattern(JSON_STRING_PATTERN); std::smatch match; std::string altNameSubStr = altNames.substr(nextQuote); bool ret = regex_match(altNameSubStr, match, jsonStringPattern); if (!ret) { return {""}; } currentToken += result[0]; offset = nextQuote + result[0].length(); } else if (nextSep != std::string::npos) { currentToken += altNames.substr(offset, nextSep); result.push_back(currentToken); currentToken = ""; offset = nextSep + OFFSET; } else { currentToken += altNames.substr(offset); offset = altNames.length(); } } result.push_back(currentToken); return result; } #if UNITTEST #else } // namespace #endif void TLSServerSendOptions::SetSocket(const int &socketFd) { socketFd_ = socketFd; } void TLSServerSendOptions::SetSendData(const std::string &data) { data_ = data; } const int &TLSServerSendOptions::GetSocket() const { return socketFd_; } const std::string &TLSServerSendOptions::GetSendData() const { return data_; } TLSSocketServer::~TLSSocketServer() { isRunning_ = false; clientIdConnections_.clear(); if (listenSocketFd_ != -1) { shutdown(listenSocketFd_, SHUT_RDWR); close(listenSocketFd_); listenSocketFd_ = -1; } } void TLSSocketServer::Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback) { if (!CommonUtils::HasInternetPermission()) { CallListenCallback(PERMISSION_DENIED_CODE, callback); return; } NETSTACK_LOGE("Listen 1 %{public}d", listenSocketFd_); if (listenSocketFd_ >= 0) { CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback); return; } NETSTACK_LOGE("Listen 2 %{public}d", listenSocketFd_); if (ExecBind(tlsListenOptions.GetNetAddress(), callback)) { NETSTACK_LOGE("Listen 3 %{public}d", listenSocketFd_); ExecAccept(tlsListenOptions, callback); } else { shutdown(listenSocketFd_, SHUT_RDWR); close(listenSocketFd_); listenSocketFd_ = -1; } PollThread(tlsListenOptions); } bool TLSSocketServer::ExecBind(const Socket::NetAddress &address, const ListenCallback &callback) { 
MakeIpSocket(address.GetSaFamily()); if (listenSocketFd_ < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("make tcp socket failed errno is %{public}d %{public}s", errno, MakeErrnoString().c_str()); CallOnErrorCallback(resErr, MakeErrnoString()); CallListenCallback(resErr, callback); return false; } sockaddr_in addr4 = {0}; sockaddr_in6 addr6 = {0}; sockaddr *addr = nullptr; socklen_t len; GetAddr(address, &addr4, &addr6, &addr, &len); if (addr == nullptr) { NETSTACK_LOGE("TLSSocket::Bind Address Is Invalid"); CallOnErrorCallback(-1, "Address Is Invalid"); CallListenCallback(ConvertErrno(), callback); return false; } if (bind(listenSocketFd_, addr, len) < 0) { if (errno != EADDRINUSE) { NETSTACK_LOGE("bind error is %{public}s %{public}d", strerror(errno), errno); CallOnErrorCallback(-1, "Address binding failed"); CallListenCallback(ConvertErrno(), callback); return false; } if (addr->sa_family == AF_INET) { NETSTACK_LOGI("distribute a random port"); addr4.sin_port = 0; /* distribute a random port */ } else if (addr->sa_family == AF_INET6) { NETSTACK_LOGI("distribute a random port"); addr6.sin6_port = 0; /* distribute a random port */ } if (bind(listenSocketFd_, addr, len) < 0) { NETSTACK_LOGE("rebind error is %{public}s %{public}d", strerror(errno), errno); CallOnErrorCallback(-1, "Duplicate binding address failed"); CallListenCallback(ConvertErrno(), callback); return false; } NETSTACK_LOGI("rebind success"); } NETSTACK_LOGI("bind success"); address_ = address; return true; } void TLSSocketServer::ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback) { if (listenSocketFd_ < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("accept error is %{public}s %{public}d", MakeErrnoString().c_str(), errno); CallOnErrorCallback(resErr, MakeErrnoString()); callback(resErr); return; } SetLocalTlsConfiguration(tlsAcceptOptions); int ret = 0; NETSTACK_LOGE( "accept error is listenSocketFd_= %{public}d LISETEN_COUNT =%{public}d 
.GetVerifyMode() = %{public}d ", listenSocketFd_, LISETEN_COUNT, tlsAcceptOptions.GetTlsSecureOptions().GetVerifyMode()); ret = listen(listenSocketFd_, LISETEN_COUNT); if (ret < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("tcp server listen error"); CallOnErrorCallback(resErr, MakeErrnoString()); callback(resErr); return; } CallListenCallback(TlsSocket::TLSSOCKET_SUCCESS, callback); } bool TLSSocketServer::Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback) { int socketFd = data.GetSocket(); std::string info = data.GetSendData(); auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); return false; } auto connect = connect_iterator->second; auto res = connect->Send(info); if (!res) { int resErr = ConvertSSLError(connect->GetSSL()); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); CallSendCallback(resErr, callback); return false; } CallSendCallback(TlsSocket::TLSSOCKET_SUCCESS, callback); return res; } void TLSSocketServer::CallSendCallback(int32_t err, TlsSocket::SendCallback callback) { TlsSocket::SendCallback CallBackfunc = nullptr; { std::lock_guard<std::mutex> lock(mutex_); if (callback) { CallBackfunc = callback; } } if (CallBackfunc) { CallBackfunc(err); } } void TLSSocketServer::Close(const int socketFd, const TlsSocket::CloseCallback &callback) { { std::lock_guard<std::mutex> its_lock(connectMutex_); for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) { if (it->first == socketFd) { auto res = it->second->Close(); if (!res) { int resErr = ConvertSSLError(it->second->GetSSL()); NETSTACK_LOGE("close error is %{public}s %{public}d", MakeSSLErrorString(resErr).c_str(), resErr); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); callback(resErr); 
return; } callback(TlsSocket::TLSSOCKET_SUCCESS); return; } else { ++it; } } } NETSTACK_LOGE("socket = %{public}d There is no corresponding socketFd", socketFd); CallOnErrorCallback(-1, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL); } void TLSSocketServer::Stop(const TlsSocket::CloseCallback &callback) { std::lock_guard<std::mutex> its_lock(connectMutex_); for (const auto &c : clientIdConnections_) { c.second->Close(); } clientIdConnections_.clear(); close(listenSocketFd_); listenSocketFd_ = -1; callback(TlsSocket::TLSSOCKET_SUCCESS); } void TLSSocketServer::GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback) { auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL, {}); return; } auto connect = connect_iterator->second; auto address = connect->GetAddress(); callback(TlsSocket::TLSSOCKET_SUCCESS, address); } void TLSSocketServer::GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback) { auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL, {}); return; } auto connect = connect_iterator->second; auto localAddress = connect->GetLocalAddress(); callback(TlsSocket::TLSSOCKET_SUCCESS, localAddress); } void TLSSocketServer::GetState(const TlsSocket::GetStateCallback &callback) { int opt; socklen_t optLen = sizeof(int); int r = getsockopt(listenSocketFd_, SOL_SOCKET, SO_TYPE, &opt, &optLen); if 
(r < 0) { Socket::SocketStateBase state; state.SetIsClose(true); CallGetStateCallback(ConvertErrno(), state, callback); return; } sockaddr sockAddr = {0}; socklen_t len = sizeof(sockaddr); Socket::SocketStateBase state; int ret = getsockname(listenSocketFd_, &sockAddr, &len); state.SetIsBound(ret == 0); ret = getpeername(listenSocketFd_, &sockAddr, &len); if (ret != 0) { NETSTACK_LOGE("getpeername failed"); } state.SetIsConnected(GetConnectionClientCount() > 0); CallGetStateCallback(TlsSocket::TLSSOCKET_SUCCESS, state, callback); } void TLSSocketServer::CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback) { TlsSocket::GetStateCallback CallBackfunc = nullptr; { std::lock_guard<std::mutex> lock(mutex_); if (callback) { CallBackfunc = callback; } } if (CallBackfunc) { CallBackfunc(err, state); } } bool TLSSocketServer::SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const TlsSocket::SetExtraOptionsCallback &callback) { if (tcpExtraOptions.IsKeepAlive()) { int keepalive = 1; if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) < 0) { return false; } } if (tcpExtraOptions.IsOOBInline()) { int oobInline = 1; if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_OOBINLINE, &oobInline, sizeof(oobInline)) < 0) { return false; } } if (tcpExtraOptions.IsTCPNoDelay()) { int tcpNoDelay = 1; if (setsockopt(listenSocketFd_, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay)) < 0) { return false; } } linger soLinger = {0}; soLinger.l_onoff = tcpExtraOptions.socketLinger.IsOn(); soLinger.l_linger = (int)tcpExtraOptions.socketLinger.GetLinger(); if (setsockopt(listenSocketFd_, SOL_SOCKET, SO_LINGER, &soLinger, sizeof(soLinger)) < 0) { return false; } return true; } void TLSSocketServer::SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config) { TLSServerConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass()); 
TLSServerConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert()); TLSServerConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain()); TLSServerConfiguration_.SetVerifyMode(config.GetTlsSecureOptions().GetVerifyMode()); const auto protocolVec = config.GetTlsSecureOptions().GetProtocolChain(); if (!protocolVec.empty()) { TLSServerConfiguration_.SetProtocol(protocolVec); } } void TLSSocketServer::GetCertificate(const TlsSocket::GetCertificateCallback &callback) { const auto &cert = TLSServerConfiguration_.GetCertificate(); NETSTACK_LOGI("cert der is %{public}d", cert.encodingFormat); if (!cert.data.Length()) { CallOnErrorCallback(-1, "cert not data Length"); callback(-1, {}); return; } callback(TlsSocket::TLSSOCKET_SUCCESS, cert); } void TLSSocketServer::GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback) { auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL, {}); return; } auto connect = connect_iterator->second; const auto &remoteCert = connect->GetRemoteCertRawData(); if (!remoteCert.data.Length()) { int resErr = ConvertSSLError(connect->GetSSL()); NETSTACK_LOGE("GetRemoteCertificate errno %{public}d, %{public}s", resErr, MakeSSLErrorString(resErr).c_str()); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); callback(resErr, {}); return; } callback(TlsSocket::TLSSOCKET_SUCCESS, remoteCert); } void TLSSocketServer::GetProtocol(const TlsSocket::GetProtocolCallback &callback) { if (TLSServerConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) { callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V13); return; } callback(TlsSocket::TLSSOCKET_SUCCESS, TlsSocket::PROTOCOL_TLS_V12); } void 
TLSSocketServer::GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback) { auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL, {}); return; } auto connect = connect_iterator->second; auto cipherSuite = connect->GetCipherSuite(); if (cipherSuite.empty()) { NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno); int resErr = ConvertSSLError(connect->GetSSL()); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); callback(resErr, cipherSuite); return; } callback(TlsSocket::TLSSOCKET_SUCCESS, cipherSuite); } void TLSSocketServer::GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback) { auto connect_iterator = clientIdConnections_.find(socketFd); if (connect_iterator == clientIdConnections_.end()) { NETSTACK_LOGE("socket = %{public}d The connection has been disconnected", socketFd); CallOnErrorCallback(TlsSocket::TLS_ERR_SYS_EINVAL, "The send failed with no corresponding socketFd"); callback(TlsSocket::TLS_ERR_SYS_EINVAL, {}); return; } auto connect = connect_iterator->second; auto signatureAlgorithms = connect->GetSignatureAlgorithms(); if (signatureAlgorithms.empty()) { NETSTACK_LOGE("GetCipherSuite errno %{public}d", errno); int resErr = ConvertSSLError(connect->GetSSL()); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); callback(resErr, signatureAlgorithms); return; } callback(TlsSocket::TLSSOCKET_SUCCESS, signatureAlgorithms); } void TLSSocketServer::Connection::OnMessage(const OnMessageCallback &onMessageCallback) { onMessageCallback_ = onMessageCallback; } void TLSSocketServer::Connection::OnClose(const OnCloseCallback &onCloseCallback) { onCloseCallback_ = onCloseCallback; } void 
TLSSocketServer::OnConnect(const OnConnectCallback &onConnectCallback) { std::lock_guard<std::mutex> lock(mutex_); onConnectCallback_ = onConnectCallback; } void TLSSocketServer::OnError(const TlsSocket::OnErrorCallback &onErrorCallback) { std::lock_guard<std::mutex> lock(mutex_); onErrorCallback_ = onErrorCallback; } void TLSSocketServer::Connection::OffMessage() { if (onMessageCallback_) { onMessageCallback_ = nullptr; } } void TLSSocketServer::OffConnect() { std::lock_guard<std::mutex> lock(mutex_); if (onConnectCallback_) { onConnectCallback_ = nullptr; } } void TLSSocketServer::Connection::OnError(const TlsSocket::OnErrorCallback &onErrorCallback) { onErrorCallback_ = onErrorCallback; } void TLSSocketServer::Connection::OffClose() { if (onCloseCallback_) { onCloseCallback_ = nullptr; } } void TLSSocketServer::Connection::OffError() { onErrorCallback_ = nullptr; } void TLSSocketServer::Connection::CallOnErrorCallback(int32_t err, const std::string &errString) { TlsSocket::OnErrorCallback CallBackfunc = nullptr; { if (onErrorCallback_) { CallBackfunc = onErrorCallback_; } } if (CallBackfunc) { CallBackfunc(err, errString); } } void TLSSocketServer::OffError() { std::lock_guard<std::mutex> lock(mutex_); if (onErrorCallback_) { onErrorCallback_ = nullptr; } } void TLSSocketServer::MakeIpSocket(sa_family_t family) { if (family != AF_INET && family != AF_INET6) { return; } int sock = socket(family, SOCK_STREAM, IPPROTO_IP); if (sock < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("Create socket failed (%{public}d:%{public}s)", errno, MakeErrnoString().c_str()); CallOnErrorCallback(resErr, MakeErrnoString()); return; } listenSocketFd_ = sock; } void TLSSocketServer::CallOnErrorCallback(int32_t err, const std::string &errString) { TlsSocket::OnErrorCallback CallBackfunc = nullptr; { std::lock_guard<std::mutex> lock(mutex_); if (onErrorCallback_) { CallBackfunc = onErrorCallback_; } } if (CallBackfunc) { CallBackfunc(err, errString); } } void 
TLSSocketServer::GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, socklen_t *len) { if (!addr6 || !addr4 || !len) { return; } sa_family_t family = address.GetSaFamily(); if (family == AF_INET) { addr4->sin_family = AF_INET; addr4->sin_port = htons(address.GetPort()); addr4->sin_addr.s_addr = inet_addr(address.GetAddress().c_str()); *addr = reinterpret_cast<sockaddr *>(addr4); *len = sizeof(sockaddr_in); } else if (family == AF_INET6) { addr6->sin6_family = AF_INET6; addr6->sin6_port = htons(address.GetPort()); inet_pton(AF_INET6, address.GetAddress().c_str(), &addr6->sin6_addr); *addr = reinterpret_cast<sockaddr *>(addr6); *len = sizeof(sockaddr_in6); } } int TLSSocketServer::GetListenSocketFd() { return listenSocketFd_; } void TLSSocketServer::SetLocalAddress(const Socket::NetAddress &address) { localAddress_ = address; } Socket::NetAddress TLSSocketServer::GetLocalAddress() { return localAddress_; } std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientID(int clientid) { std::shared_ptr<Connection> ptrConnection = nullptr; auto it = clientIdConnections_.find(clientid); if (it != clientIdConnections_.end()) { ptrConnection = it->second; } return ptrConnection; } int TLSSocketServer::GetConnectionClientCount() { return g_userCounter; } void TLSSocketServer::CallListenCallback(int32_t err, ListenCallback callback) { ListenCallback CallBackfunc = nullptr; { std::lock_guard<std::mutex> lock(mutex_); if (callback) { CallBackfunc = callback; } } if (CallBackfunc) { CallBackfunc(err); } } void TLSSocketServer::Connection::SetAddress(const Socket::NetAddress address) { address_ = address; } void TLSSocketServer::Connection::SetLocalAddress(const Socket::NetAddress address) { localAddress_ = address; } const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetRemoteCertRawData() const { return remoteRawData_; } TLSSocketServer::Connection::~Connection() { Close(); } bool 
TLSSocketServer::Connection::TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options) { SetTlsConfiguration(options); std::string cipherSuite = options.GetTlsSecureOptions().GetCipherSuite(); if (!cipherSuite.empty()) { connectionConfiguration_.SetCipherSuite(cipherSuite); } std::string signatureAlgorithms = options.GetTlsSecureOptions().GetSignatureAlgorithms(); if (!signatureAlgorithms.empty()) { connectionConfiguration_.SetSignatureAlgorithms(signatureAlgorithms); } const auto protocolVec = options.GetTlsSecureOptions().GetProtocolChain(); if (!protocolVec.empty()) { connectionConfiguration_.SetProtocol(protocolVec); } connectionConfiguration_.SetVerifyMode(options.GetTlsSecureOptions().GetVerifyMode()); socketFd_ = sock; return StartTlsAccept(options); } void TLSSocketServer::Connection::SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config) { connectionConfiguration_.SetPrivateKey(config.GetTlsSecureOptions().GetKey(), config.GetTlsSecureOptions().GetKeyPass()); connectionConfiguration_.SetLocalCertificate(config.GetTlsSecureOptions().GetCert()); connectionConfiguration_.SetCaCertificate(config.GetTlsSecureOptions().GetCaChain()); connectionConfiguration_.SetNetAddress(config.GetNetAddress()); } bool TLSSocketServer::Connection::Send(const std::string &data) { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return false; } if (data.empty()) { NETSTACK_LOGI("data is empty"); return true; } int len = SSL_write(ssl_, data.c_str(), data.length()); if (len < 0) { int resErr = ConvertSSLError(GetSSL()); NETSTACK_LOGE("data '%{public}s' send failed!The error code is %{public}d, The error message is'%{public}s'", data.c_str(), resErr, MakeSSLErrorString(resErr).c_str()); return false; } NETSTACK_LOGD("data '%{public}s' Sent successfully,sent in total %{public}d bytes!", data.c_str(), len); return true; } int TLSSocketServer::Connection::Recv(char *buffer, int maxBufferSize) { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return SSL_ERROR_RETURN; } 
return SSL_read(ssl_, buffer, maxBufferSize); } bool TLSSocketServer::Connection::Close() { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return false; } int result = SSL_shutdown(ssl_); if (result < 0) { int resErr = ConvertSSLError(GetSSL()); NETSTACK_LOGE("Error in shutdown, errno is %{public}d, error info is %{public}s", resErr, MakeSSLErrorString(resErr).c_str()); } SSL_free(ssl_); ssl_ = nullptr; if (socketFd_ != -1) { shutdown(socketFd_, SHUT_RDWR); close(socketFd_); socketFd_ = -1; } if (!tlsContextServerPointer_) { NETSTACK_LOGE("Tls context pointer is null"); return false; } tlsContextServerPointer_->CloseCtx(); return true; } bool TLSSocketServer::Connection::SetAlpnProtocols(const std::vector<std::string> &alpnProtocols) { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return false; } size_t pos = 0; size_t len = std::accumulate(alpnProtocols.begin(), alpnProtocols.end(), static_cast<size_t>(0), [](size_t init, const std::string &alpnProt) { return init + alpnProt.length(); }); auto result = std::make_unique<unsigned char[]>(alpnProtocols.size() + len); for (const auto &str : alpnProtocols) { len = str.length(); result[pos++] = len; if (!strcpy_s(reinterpret_cast<char *>(&result[pos]), len, str.c_str())) { NETSTACK_LOGE("strcpy_s failed"); return false; } pos += len; } result[pos] = '\0'; NETSTACK_LOGD("alpnProtocols after splicing %{public}s", result.get()); if (SSL_set_alpn_protos(ssl_, result.get(), pos)) { int resErr = ConvertSSLError(GetSSL()); NETSTACK_LOGE("Failed to set negotiable protocol list, errno is %{public}d, error info is %{public}s", resErr, MakeSSLErrorString(resErr).c_str()); return false; } return true; } void TLSSocketServer::Connection::MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo) { remoteInfo.SetAddress(address_.GetAddress()); remoteInfo.SetPort(address_.GetPort()); remoteInfo.SetFamily(address_.GetSaFamily()); } TlsSocket::TLSConfiguration TLSSocketServer::Connection::GetTlsConfiguration() const { return 
connectionConfiguration_; } std::vector<std::string> TLSSocketServer::Connection::GetCipherSuite() const { if (!ssl_) { NETSTACK_LOGE("ssl in null"); return {}; } STACK_OF(SSL_CIPHER) *sk = SSL_get_ciphers(ssl_); if (!sk) { NETSTACK_LOGE("get ciphers failed"); return {}; } TlsSocket::CipherSuite cipherSuite; std::vector<std::string> cipherSuiteVec; for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) { const SSL_CIPHER *c = sk_SSL_CIPHER_value(sk, i); cipherSuite.cipherName_ = SSL_CIPHER_get_name(c); cipherSuiteVec.push_back(cipherSuite.cipherName_); } return cipherSuiteVec; } std::string TLSSocketServer::Connection::GetRemoteCertificate() const { return remoteCert_; } const TlsSocket::X509CertRawData &TLSSocketServer::Connection::GetCertificate() const { return connectionConfiguration_.GetCertificate(); } std::vector<std::string> TLSSocketServer::Connection::GetSignatureAlgorithms() const { return signatureAlgorithms_; } std::string TLSSocketServer::Connection::GetProtocol() const { if (!ssl_) { NETSTACK_LOGE("ssl in null"); return PROTOCOL_UNKNOW; } if (connectionConfiguration_.GetProtocol() == TlsSocket::TLS_V1_3) { return TlsSocket::PROTOCOL_TLS_V13; } return TlsSocket::PROTOCOL_TLS_V12; } bool TLSSocketServer::Connection::SetSharedSigals() { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return false; } int number = SSL_get_shared_sigalgs(ssl_, 0, nullptr, nullptr, nullptr, nullptr, nullptr); if (!number) { NETSTACK_LOGE("SSL_get_shared_sigalgs return value error"); return false; } for (int i = 0; i < number; i++) { int hash_nid; int sign_nid; std::string sig_with_md; SSL_get_shared_sigalgs(ssl_, i, &sign_nid, &hash_nid, nullptr, nullptr, nullptr); switch (sign_nid) { case EVP_PKEY_RSA: sig_with_md = SIGN_NID_RSA; break; case EVP_PKEY_RSA_PSS: sig_with_md = SIGN_NID_RSA_PSS; break; case EVP_PKEY_DSA: sig_with_md = SIGN_NID_DSA; break; case EVP_PKEY_EC: sig_with_md = SIGN_NID_ECDSA; break; case NID_ED25519: sig_with_md = SIGN_NID_ED; break; case NID_ED448: sig_with_md = 
SIGN_NID_ED_FOUR_FOUR_EIGHT; break; default: const char *sn = OBJ_nid2sn(sign_nid); sig_with_md = (sn != nullptr) ? (std::string(sn) + OPERATOR_PLUS_SIGN) : SIGN_NID_UNDEF_ADD; } const char *sn_hash = OBJ_nid2sn(hash_nid); sig_with_md += (sn_hash != nullptr) ? std::string(sn_hash) : SIGN_NID_UNDEF; signatureAlgorithms_.push_back(sig_with_md); } return true; } ssl_st *TLSSocketServer::Connection::GetSSL() const { return ssl_; } Socket::NetAddress TLSSocketServer::Connection::GetAddress() const { return address_; } Socket::NetAddress TLSSocketServer::Connection::GetLocalAddress() const { return localAddress_; } int TLSSocketServer::Connection::GetSocketFd() const { return socketFd_; } std::shared_ptr<EventManager> TLSSocketServer::Connection::GetEventManager() const { return eventManager_; } void TLSSocketServer::Connection::SetEventManager(std::shared_ptr<EventManager> eventManager) { eventManager_ = eventManager; } void TLSSocketServer::Connection::SetClientID(int32_t clientID) { clientID_ = clientID; } int TLSSocketServer::Connection::GetClientID() { return clientID_; } bool TLSSocketServer::Connection::StartTlsAccept(const TlsSocket::TLSConnectOptions &options) { if (!CreatTlsContext()) { NETSTACK_LOGE("failed to create tls context"); return false; } if (!StartShakingHands(options)) { NETSTACK_LOGE("failed to shaking hands"); return false; } return true; } bool TLSSocketServer::Connection::CreatTlsContext() { tlsContextServerPointer_ = TlsSocket::TLSContextServer::CreateConfiguration(connectionConfiguration_); if (!tlsContextServerPointer_) { NETSTACK_LOGE("failed to create tls context pointer"); return false; } if (!(ssl_ = tlsContextServerPointer_->CreateSsl())) { NETSTACK_LOGE("failed to create ssl session"); return false; } SSL_set_fd(ssl_, socketFd_); SSL_set_accept_state(ssl_); return true; } bool TLSSocketServer::Connection::StartShakingHands(const TlsSocket::TLSConnectOptions &options) { if (!ssl_) { NETSTACK_LOGE("ssl is null"); return false; } int 
result = SSL_accept(ssl_); if (result == -1) { int errorStatus = ConvertSSLError(ssl_); NETSTACK_LOGE("SSL connect is error, errno is %{public}d, error info is %{public}s", errorStatus, MakeSSLErrorString(errorStatus).c_str()); return false; } std::vector<std::string> SslProtocolVer({SSL_get_version(ssl_)}); connectionConfiguration_.SetProtocol({SslProtocolVer}); std::string list = SSL_get_cipher_list(ssl_, 0); NETSTACK_LOGI("SSL_get_cipher_list: %{public}s", list.c_str()); connectionConfiguration_.SetCipherSuite(list); if (!SetSharedSigals()) { NETSTACK_LOGE("Failed to set sharedSigalgs"); } if (!GetRemoteCertificateFromPeer()) { NETSTACK_LOGE("Failed to get remote certificate"); } if (peerX509_ != nullptr) { NETSTACK_LOGE("peer x509Certificates is null"); if (!SetRemoteCertRawData()) { NETSTACK_LOGE("Failed to set remote x509 certificata Serialization data"); } TlsSocket::CheckServerIdentity checkServerIdentity = options.GetCheckServerIdentity(); if (!checkServerIdentity) { CheckServerIdentityLegal(hostName_, peerX509_); } else { checkServerIdentity(hostName_, {remoteCert_}); } } return true; } bool TLSSocketServer::Connection::GetRemoteCertificateFromPeer() { peerX509_ = SSL_get_peer_certificate(ssl_); if (SSL_get_verify_result(ssl_) == X509_V_OK) { NETSTACK_LOGE("SSL_get_verify_result ==X509_V_OK"); } if (peerX509_ == nullptr) { int resErr = ConvertSSLError(GetSSL()); NETSTACK_LOGE("open fail errno, errno is %{public}d, error info is %{public}s", resErr, MakeSSLErrorString(resErr).c_str()); return false; } BIO *bio = BIO_new(BIO_s_mem()); if (!bio) { NETSTACK_LOGE("TlsSocket::SetRemoteCertificate bio is null"); return false; } X509_print(bio, peerX509_); char data[REMOTE_CERT_LEN] = {0}; if (!BIO_read(bio, data, REMOTE_CERT_LEN)) { NETSTACK_LOGE("BIO_read function returns error"); BIO_free(bio); return false; } BIO_free(bio); remoteCert_ = std::string(data); return true; } bool TLSSocketServer::Connection::SetRemoteCertRawData() { if (peerX509_ == nullptr) { 
NETSTACK_LOGE("peerX509 is null"); return false; } int32_t length = i2d_X509(peerX509_, nullptr); if (length <= 0) { NETSTACK_LOGE("Failed to convert peerX509 to der format"); return false; } unsigned char *der = nullptr; (void)i2d_X509(peerX509_, &der); TlsSocket::SecureData data(der, length); remoteRawData_.data = data; OPENSSL_free(der); remoteRawData_.encodingFormat = TlsSocket::EncodingFormat::DER; return true; } static bool StartsWith(const std::string &s, const std::string &prefix) { return s.size() >= prefix.size() && s.compare(0, prefix.size(), prefix) == 0; } void CheckIpAndDnsName(const std::string &hostName, std::vector<std::string> &dnsNames, std::vector<std::string> &ips, const X509 *x509Certificates, std::tuple<bool, std::string> &result) { bool valid = false; std::string reason = UNKNOW_REASON; int index = X509_get_ext_by_NID(x509Certificates, NID_commonName, -1); if (IsIP(hostName)) { auto it = find(ips.begin(), ips.end(), hostName); if (it == ips.end()) { reason = IP + hostName + " is not in the cert's list"; } result = {valid, reason}; return; } std::string tempHostName = "" + hostName; if (!dnsNames.empty() || index > 0) { std::vector<std::string> hostParts = SplitHostName(tempHostName); std::string tmpStr = ""; if (!dnsNames.empty()) { valid = SeekIntersection(hostParts, dnsNames); tmpStr = ". is not in the cert's altnames"; } else { char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0}; X509_NAME *pSubName = nullptr; int len = X509_NAME_get_text_by_NID(pSubName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE); if (len > 0) { std::vector<std::string> commonNameVec; commonNameVec.emplace_back(commonNameBuf); valid = SeekIntersection(hostParts, commonNameVec); tmpStr = ". 
is not cert's CN"; } } if (!valid) { reason = HOST_NAME + tempHostName + tmpStr; } result = {valid, reason}; return; } reason = "Cert does not contain a DNS name"; result = {valid, reason}; } std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates) { X509_NAME *subjectName = X509_get_subject_name(x509Certificates); if (!subjectName) { return "subject name is null"; } char subNameBuf[BUF_SIZE] = {0}; X509_NAME_oneline(subjectName, subNameBuf, BUF_SIZE); int index = X509_get_ext_by_NID(x509Certificates, NID_subject_alt_name, -1); if (index < 0) { return "X509 get ext nid error"; } X509_EXTENSION *ext = X509_get_ext(x509Certificates, index); if (ext == nullptr) { return "X509 get ext error"; } ASN1_OBJECT *obj = nullptr; obj = X509_EXTENSION_get_object(ext); char subAltNameBuf[BUF_SIZE] = {0}; OBJ_obj2txt(subAltNameBuf, BUF_SIZE, obj, 0); NETSTACK_LOGD("extions obj : %{public}s\n", subAltNameBuf); return CheckServerIdentityLegal(hostName, ext, x509Certificates); } std::string TLSSocketServer::Connection::CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, const X509 *x509Certificates) { ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext); if (!extData) { NETSTACK_LOGE("extData is nullptr"); return ""; } std::string altNames = reinterpret_cast<char *>(extData->data); std::string hostname = "" + hostName; BIO *bio = BIO_new(BIO_s_file()); if (!bio) { return "bio is null"; } BIO_set_fp(bio, stdout, BIO_NOCLOSE); ASN1_STRING_print(bio, extData); std::vector<std::string> dnsNames = {}; std::vector<std::string> ips = {}; constexpr int DNS_NAME_IDX = 4; constexpr int IP_NAME_IDX = 11; if (!altNames.empty()) { std::vector<std::string> splitAltNames; if (altNames.find('\"') != std::string::npos) { splitAltNames = SplitEscapedAltNames(altNames); } else { splitAltNames = CommonUtils::Split(altNames, SPLIT_ALT_NAMES); } for (auto const &iter : splitAltNames) { if (StartsWith(iter, 
DNS)) { dnsNames.push_back(iter.substr(DNS_NAME_IDX)); } else if (StartsWith(iter, IP_ADDRESS)) { ips.push_back(iter.substr(IP_NAME_IDX)); } } } std::tuple<bool, std::string> result; CheckIpAndDnsName(hostName, dnsNames, ips, x509Certificates, result); if (!std::get<0>(result)) { return "Hostname/IP does not match certificate's altnames: " + std::get<1>(result); } return HOST_NAME + hostname + ". is cert's CN"; } void TLSSocketServer::RemoveConnect(int socketFd) { std::shared_ptr<Connection> ptrConnection = nullptr; { std::lock_guard<std::mutex> its_lock(connectMutex_); for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) { if (it->second->GetSocketFd() == socketFd) { ptrConnection = it->second; break; } else { ++it; } } } if (ptrConnection != nullptr) { ptrConnection->CallOnCloseCallback(static_cast<unsigned int>(socketFd)); ptrConnection->Close(); } } int TLSSocketServer::RecvRemoteInfo(int socketFd, int index) { { std::lock_guard<std::mutex> its_lock(connectMutex_); for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end();) { if (it->second->GetSocketFd() == socketFd) { char buffer[MAX_BUFFER_SIZE]; if (memset_s(buffer, MAX_BUFFER_SIZE, 0, MAX_BUFFER_SIZE) != EOK) { NETSTACK_LOGE("memcpy_s failed"); break; } int len = it->second->Recv(buffer, MAX_BUFFER_SIZE); NETSTACK_LOGE("revc message is size is %{public}d buffer is %{public}s ", len, buffer); if (len > 0) { Socket::SocketRemoteInfo remoteInfo; remoteInfo.SetSize(strlen(buffer)); it->second->MakeRemoteInfo(remoteInfo); it->second->CallOnMessageCallback(socketFd, buffer, remoteInfo); return len; } #if defined(CROSS_PLATFORM) if (len == 0 && errno == 0) { NETSTACK_LOGI("A client left"); } #endif break; } else { ++it; } } } RemoveConnect(socketFd); DropFdFromPollList(index); return -1; } void TLSSocketServer::Connection::CallOnMessageCallback(int32_t socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) { OnMessageCallback CallBackfunc = 
nullptr; { if (onMessageCallback_) { CallBackfunc = onMessageCallback_; } } if (CallBackfunc) { while (!dataCache_->IsEmpty()) { CacheInfo cache = dataCache_->Get(); CallBackfunc(socketFd, cache.data, cache.remoteInfo); } CallBackfunc(socketFd, data, remoteInfo); } else { CacheInfo cache = {data, remoteInfo}; dataCache_->Set(cache); } } void TLSSocketServer::AddConnect(int socketFd, std::shared_ptr<Connection> connection) { std::lock_guard<std::mutex> its_lock(connectMutex_); clientIdConnections_[connection->GetClientID()] = connection; } void TLSSocketServer::Connection::CallOnCloseCallback(const int32_t socketFd) { OnCloseCallback CallBackfunc = nullptr; { if (onCloseCallback_) { CallBackfunc = onCloseCallback_; } } if (CallBackfunc) { CallBackfunc(socketFd); } } void TLSSocketServer::CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager) { OnConnectCallback CallBackfunc = nullptr; { std::lock_guard<std::mutex> lock(mutex_); if (onConnectCallback_) { CallBackfunc = onConnectCallback_; } } if (CallBackfunc) { CallBackfunc(socketFd, eventManager); } else { NETSTACK_LOGE("CallOnConnectCallback fun === null"); } } bool TLSSocketServer::GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress) { struct sockaddr_storage addr{}; socklen_t addrLen = sizeof(addr); if (getsockname(acceptSockFD, (struct sockaddr *)&addr, &addrLen) < 0) { if (acceptSockFD > 0) { close(acceptSockFD); CallOnErrorCallback(errno, strerror(errno)); return false; } } char ipStr[INET6_ADDRSTRLEN] = {0}; if (addr.ss_family == AF_INET) { auto *addr_in = (struct sockaddr_in *)&addr; inet_ntop(AF_INET, &addr_in->sin_addr, ipStr, sizeof(ipStr)); localAddress.SetFamilyBySaFamily(AF_INET); localAddress.SetRawAddress(ipStr); localAddress.SetPort(ntohs(addr_in->sin_port)); } else if (addr.ss_family == AF_INET6) { auto *addr_in6 = (struct sockaddr_in6 *)&addr; inet_ntop(AF_INET6, &addr_in6->sin6_addr, ipStr, sizeof(ipStr)); 
localAddress.SetFamilyBySaFamily(AF_INET6); localAddress.SetRawAddress(ipStr); localAddress.SetPort(ntohs(addr_in6->sin6_port)); } return true; } void TLSSocketServer::ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID) { #if !defined(CROSS_PLATFORM) struct sockaddr_in clientAddress; socklen_t clientAddrLength = sizeof(clientAddress); int connectFD = accept(listenSocketFd_, (struct sockaddr *)&clientAddress, &clientAddrLength); if (connectFD < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("Server accept new client ERROR"); CallOnErrorCallback(resErr, MakeErrnoString()); return; } NETSTACK_LOGI("Server accept new client SUCCESS"); std::shared_ptr<Connection> connection = std::make_shared<Connection>(); Socket::NetAddress netAddress; Socket::NetAddress localAddress; char clientIp[INET6_ADDRSTRLEN] = {0}; inet_ntop(address_.GetSaFamily(), &clientAddress.sin_addr, clientIp, INET_ADDRSTRLEN); int clientPort = ntohs(clientAddress.sin_port); netAddress.SetRawAddress(clientIp); netAddress.SetPort(clientPort); netAddress.SetFamilyBySaFamily(address_.GetSaFamily()); connection->SetAddress(netAddress); if (GetTlsConnectionLocalAddress(connectFD, localAddress)) { connection->SetLocalAddress(localAddress); } SetTlsConnectionSecureOptions(tlsListenOptions, clientID, connectFD, connection); #endif } void TLSSocketServer::SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID, int connectFD, std::shared_ptr<Connection> &connection) { connection->SetClientID(clientID); auto res = connection->TlsAcceptToHost(connectFD, tlsListenOptions); if (!res) { int resErr = ConvertSSLError(connection->GetSSL()); CallOnErrorCallback(resErr, MakeSSLErrorString(resErr)); return; } if (g_userCounter >= USER_LIMIT) { const std::string info = "Too many users!"; connection->Send(info); connection->Close(); NETSTACK_LOGE("Too many users"); close(connectFD); CallOnErrorCallback(-1, "Too many users"); return; } g_userCounter++; 
fds_[g_userCounter].fd = connectFD; #if defined(CROSS_PLATFORM) fds_[g_userCounter].events = POLLIN | POLLERR; #else fds_[g_userCounter].events = POLLIN | POLLRDHUP | POLLERR; #endif fds_[g_userCounter].revents = 0; AddConnect(connectFD, connection); auto ptrEventManager = std::make_shared<EventManager>(); EventManager::SetValid(ptrEventManager.get()); ptrEventManager->SetData(this); connection->SetEventManager(ptrEventManager); CallOnConnectCallback(clientID, ptrEventManager); NETSTACK_LOGI("New client come in, fd is %{public}d", connectFD); } void TLSSocketServer::InitPollList(int &listendFd) { for (int i = 1; i <= USER_LIMIT; ++i) { fds_[i].fd = -1; fds_[i].events = 0; } fds_[0].fd = listendFd; fds_[0].events = POLLIN | POLLERR; fds_[0].revents = 0; } void TLSSocketServer::DropFdFromPollList(int &fd_index) { if (g_userCounter < 0) { NETSTACK_LOGE("g_userCounter = %{public}d", g_userCounter); return; } fds_[fd_index].fd = fds_[g_userCounter].fd; fds_[g_userCounter].fd = -1; fds_[g_userCounter].events = 0; fd_index--; g_userCounter--; NETSTACK_LOGE("CallOnConnectCallback g_userCounter = %{public}d", g_userCounter); } void TLSSocketServer::PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions) { int on = 1; isRunning_ = true; ioctl(listenSocketFd_, FIONBIO, (char *)&on); NETSTACK_LOGE("PollThread start working %{public}d", isRunning_); std::thread thread_([this, tlsOption = tlsListenOptions]() { #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM) pthread_setname_np(TLS_SOCKET_SERVER_READ); #else pthread_setname_np(pthread_self(), TLS_SOCKET_SERVER_READ); #endif InitPollList(listenSocketFd_); int clientId = 0; while (isRunning_) { int ret = poll(fds_, g_userCounter + 1, POLL_WAIT_TIME); if (ret < 0) { int resErr = ConvertErrno(); NETSTACK_LOGE("Poll ERROR"); CallOnErrorCallback(resErr, MakeErrnoString()); break; } if (ret == 0) { continue; } for (int i = 0; i < g_userCounter + 1; ++i) { if ((fds_[i].fd == listenSocketFd_) && 
(static_cast<uint16_t>(fds_[i].revents) & POLLIN)) { ProcessTcpAccept(tlsOption, ++clientId); #if !defined(CROSS_PLATFORM) } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLRDHUP) || (static_cast<uint16_t>(fds_[i].revents) & POLLERR)) { #else } else if ((static_cast<uint16_t>(fds_[i].revents) & POLLERR)) { #endif RemoveConnect(fds_[i].fd); DropFdFromPollList(i); NETSTACK_LOGI("A client left"); } else if (static_cast<uint16_t>(fds_[i].revents) & POLLIN) { RecvRemoteInfo(fds_[i].fd, i); } } } }); thread_.detach(); } std::shared_ptr<TLSSocketServer::Connection> TLSSocketServer::GetConnectionByClientEventManager( const EventManager *eventManager) { std::lock_guard<std::mutex> its_lock(connectMutex_); auto it = std::find_if(clientIdConnections_.begin(), clientIdConnections_.end(), [eventManager](const auto& pair) { return pair.second->GetEventManager().get() == eventManager; }); if (it == clientIdConnections_.end()) { return nullptr; } return it->second; } void TLSSocketServer::CloseConnectionByEventManager(EventManager *eventManager) { std::shared_ptr<Connection> ptrConnection = GetConnectionByClientEventManager(eventManager); if (ptrConnection != nullptr) { ptrConnection->Close(); } } void TLSSocketServer::DeleteConnectionByEventManager(EventManager *eventManager) { std::lock_guard<std::mutex> its_lock(connectMutex_); for (auto it = clientIdConnections_.begin(); it != clientIdConnections_.end(); ++it) { if (it->second->GetEventManager().get() == eventManager) { it = clientIdConnections_.erase(it); break; } } } } // namespace TlsSocketServer } // namespace NetStack } // namespace OHOS