1 /* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 18 19 #include "event_manager.h" 20 #include "extra_options_base.h" 21 #include "net_address.h" 22 #include "socket_error.h" 23 #include "socket_remote_info.h" 24 #include "socket_state_base.h" 25 #include "tcp_connect_options.h" 26 #include "tcp_extra_options.h" 27 #include "tcp_send_options.h" 28 #include "tls.h" 29 #include "tls_certificate.h" 30 #include "tls_configuration.h" 31 #include "tls_context_server.h" 32 #include "tls_key.h" 33 #include "tls_socket.h" 34 #include <any> 35 #include <condition_variable> 36 #include <cstring> 37 #include <functional> 38 #include <map> 39 #include <poll.h> 40 #include <thread> 41 #include <tuple> 42 #include <unistd.h> 43 #include <vector> 44 45 namespace OHOS { 46 namespace NetStack { 47 namespace TlsSocketServer { 48 constexpr int USER_LIMIT = 10; 49 struct CacheInfo { 50 std::string data; 51 Socket::SocketRemoteInfo remoteInfo; 52 }; 53 using OnMessageCallback = 54 std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 55 using OnCloseCallback = std::function<void(const int &socketFd)>; 56 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>; 57 using ListenCallback = std::function<void(int32_t errorNumber)>; 58 class TLSServerSendOptions { 59 public: 60 /** 61 * Set the socket ID to be transmitted 62 * @param socketFd Communication descriptor 63 */ 64 void SetSocket(const int &socketFd); 65 66 /** 67 * Set the data to send 68 * @param data Send data 69 */ 70 void SetSendData(const std::string &data); 71 72 /** 73 * Get the socket ID 74 * @return Gets the communication descriptor 75 */ 76 [[nodiscard]] const int &GetSocket() const; 77 78 /** 79 * Gets the data sent 80 * @return Send data 81 */ 82 [[nodiscard]] const std::string &GetSendData() const; 83 84 private: 85 int socketFd_; 86 std::string data_; 87 }; 88 89 class TLSSocketServer { 90 public: 91 TLSSocketServer(const TLSSocketServer &) = delete; 92 TLSSocketServer(TLSSocketServer &&) = delete; 93 94 TLSSocketServer &operator=(const TLSSocketServer &) = delete; 95 TLSSocketServer &operator=(TLSSocketServer &&) = delete; 96 97 TLSSocketServer() = default; 98 ~TLSSocketServer(); 99 100 /** 101 * Create sockets, bind and listen waiting for clients to connect 102 * @param tlsListenOptions Bind the listening connection configuration 103 * @param callback callback to the caller if bind ok or not 104 */ 105 void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback); 106 107 /** 108 * Send data through an established encrypted connection 109 * @param data data sent over an established encrypted connection 110 * @return whether the data is successfully sent to the server 111 */ 112 bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback); 113 114 /** 115 * Disconnect by releasing the socket when communicating 116 * @param socketFd The socket ID of the client 117 * @param callback callback to the caller 118 */ 119 void Close(const int socketFd, const TlsSocket::CloseCallback &callback); 120 121 /** 122 * Disconnect by releasing the socket when communicating 123 * @param callback callback to the caller 124 */ 125 void Stop(const TlsSocket::CloseCallback &callback); 126 127 /** 128 * Get the peer network address 129 * @param socketFd The socket ID of the client 130 * @param callback callback to the caller 131 */ 132 void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback); 133 134 /** 135 * Get the peer network address 136 * @param socketFd The socket ID of the client 137 * @param callback callback to the caller 138 */ 139 void GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback); 140 141 /** 142 * Get the status of the current socket 143 * @param callback callback to the caller 144 */ 145 void GetState(const TlsSocket::GetStateCallback &callback); 146 147 /** 148 * Gets or sets the options associated with the current socket 149 * @param tcpExtraOptions options associated with the current socket 150 * @param callback callback to the caller 151 */ 152 bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, 153 const TlsSocket::SetExtraOptionsCallback &callback); 154 155 /** 156 * Get a local digital certificate 157 * @param callback callback to the caller 158 */ 159 void GetCertificate(const TlsSocket::GetCertificateCallback &callback); 160 161 /** 162 * Get the peer digital certificate 163 * @param socketFd The socket ID of the client 164 * @param needChain need chain 165 * @param callback callback to the caller 166 */ 167 void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback); 168 169 /** 170 * Obtain the protocol used in communication 171 * @param callback callback to the caller 172 */ 173 void GetProtocol(const TlsSocket::GetProtocolCallback &callback); 174 175 /** 176 * Obtain the cipher suite used in communication 177 * @param socketFd The socket ID of the client 178 * @param callback callback to the caller 179 */ 180 void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback); 181 182 /** 183 * Obtain the encryption algorithm used in the communication process 184 * @param socketFd The socket ID of the client 185 * @param callback callback to the caller 186 */ 187 void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback); 188 189 /** 190 * Register the callback that is called when the connection is disconnected 191 * @param onCloseCallback callback invoked when disconnected 192 */ 193 194 /** 195 * Register the callback that is called when the connection is established 196 * @param onConnectCallback callback invoked when connection is established 197 */ 198 void OnConnect(const OnConnectCallback &onConnectCallback); 199 200 /** 201 * Register the callback that is called when an error occurs 202 * @param onErrorCallback callback invoked when an error occurs 203 */ 204 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 205 206 /** 207 * Off Connect 208 */ 209 void OffConnect(); 210 211 /** 212 * Off Error 213 */ 214 void OffError(); 215 216 /** 217 * Get the socket file description of the server 218 */ 219 int GetListenSocketFd(); 220 221 /** 222 * Set the current socket file description address of the server 223 */ 224 void SetLocalAddress(const Socket::NetAddress &address); 225 226 /** 227 * Get the current socket file description address of the server 228 */ 229 Socket::NetAddress GetLocalAddress(); 230 231 public: 232 class Connection : public std::enable_shared_from_this<Connection> { 233 public: 234 ~Connection(); 235 /** 236 * Establish an encrypted accept on the specified socket 237 * @param sock socket for establishing encrypted connection 238 * @param options some options required during tls accept 239 * @return whether the encrypted accept is successfully established 240 */ 241 bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options); 242 243 /** 244 * Set the configuration items for establishing encrypted connections 245 * @param config configuration item when establishing encrypted connection 246 */ 247 void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 248 249 /** 250 * Set address information 251 */ 252 void SetAddress(const Socket::NetAddress address); 253 254 /** 255 * Set local address information 256 */ 257 void SetLocalAddress(const Socket::NetAddress address); 258 259 /** 260 * Send data through an established encrypted connection 261 * @param data data sent over an established encrypted connection 262 * @return whether the data is successfully sent to the server 263 */ 264 bool Send(const std::string &data); 265 266 /** 267 * Receive the data sent by the server through the established encrypted connection 268 * @param buffer receive the data sent by the server 269 * @param maxBufferSize the size of the data received from the server 270 * @return whether the data sent by the server is successfully received 271 */ 272 int Recv(char *buffer, int maxBufferSize); 273 274 /** 275 * Disconnect encrypted connection 276 * @return whether the encrypted connection was successfully disconnected 277 */ 278 bool Close(); 279 280 /** 281 * Set the application layer negotiation protocol in the encrypted communication process 282 * @param alpnProtocols application layer negotiation protocol 283 * @return set whether the application layer negotiation protocol is successful during encrypted communication 284 */ 285 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 286 287 /** 288 * Storage of server communication related network information 289 * @param remoteInfo communication related network information 290 */ 291 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 292 293 /** 294 * Get configuration options for encrypted communication process 295 * @return configuration options for encrypted communication processes 296 */ 297 [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const; 298 299 /** 300 * Obtain the cipher suite during encrypted communication 301 * @return crypto suite used in encrypted communication 302 */ 303 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 304 305 /** 306 * Obtain the peer certificate used in encrypted communication 307 * @return peer certificate used in encrypted communication 308 */ 309 [[nodiscard]] std::string GetRemoteCertificate() const; 310 311 /** 312 * Obtain the peer certificate used in encrypted communication 313 * @return peer certificate serialization data used in encrypted communication 314 */ 315 [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const; 316 317 /** 318 * Obtain the certificate used in encrypted communication 319 * @return certificate serialization data used in encrypted communication 320 */ 321 [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const; 322 323 /** 324 * Get the encryption algorithm used in encrypted communication 325 * @return encryption algorithm used in encrypted communication 326 */ 327 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 328 329 /** 330 * Obtain the communication protocol used in encrypted communication 331 * @return communication protocol used in encrypted communication 332 */ 333 [[nodiscard]] std::string GetProtocol() const; 334 335 /** 336 * Set the information about the shared signature algorithm supported by peers during encrypted communication 337 * @return information about peer supported shared signature algorithms 338 */ 339 [[nodiscard]] bool SetSharedSigals(); 340 341 /** 342 * Obtain the ssl used in encrypted communication 343 * @return SSL used in encrypted communication 344 */ 345 [[nodiscard]] ssl_st *GetSSL() const; 346 347 /** 348 * Get address information 349 * @return Returns the address information of the remote client 350 */ 351 [[nodiscard]] Socket::NetAddress GetAddress() const; 352 353 /** 354 * Get local address information 355 * @return Returns the address information of the local accept connect 356 */ 357 [[nodiscard]] Socket::NetAddress GetLocalAddress() const; 358 359 /** 360 * Get address information 361 * @return Returns the address information of the remote client 362 */ 363 [[nodiscard]] int GetSocketFd() const; 364 365 /** 366 * Get EventManager information 367 * @return Returns the address information of the remote client 368 */ 369 [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const; 370 371 void OnMessage(const OnMessageCallback &onMessageCallback); 372 /** 373 * Unregister the callback which is called when message is received 374 */ 375 void OffMessage(); 376 377 void CallOnMessageCallback(int32_t socketFd, const std::string &data, 378 const Socket::SocketRemoteInfo &remoteInfo); 379 380 void SetEventManager(std::shared_ptr<EventManager> eventManager); 381 382 void SetClientID(int32_t clientID); 383 384 [[nodiscard]] int GetClientID(); 385 386 void CallOnCloseCallback(const int32_t socketFd); 387 void OnClose(const OnCloseCallback &onCloseCallback); 388 OnCloseCallback onCloseCallback_; 389 390 /** 391 * Off Close 392 */ 393 void OffClose(); 394 395 /** 396 * Register the callback that is called when an error occurs 397 * @param onErrorCallback callback invoked when an error occurs 398 */ 399 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 400 /** 401 * Off Error 402 */ 403 void OffError(); 404 405 void CallOnErrorCallback(int32_t err, const std::string &errString); 406 407 class DataCache { 408 public: Get()409 CacheInfo Get() 410 { 411 std::lock_guard l(mutex_); 412 CacheInfo cache = cacheDeque_.front(); 413 cacheDeque_.pop_front(); 414 return cache; 415 } Set(const CacheInfo & data)416 void Set(const CacheInfo &data) 417 { 418 std::lock_guard l(mutex_); 419 cacheDeque_.emplace_back(data); 420 } IsEmpty()421 bool IsEmpty() 422 { 423 std::lock_guard l(mutex_); 424 return cacheDeque_.empty(); 425 } 426 427 private: 428 std::deque<CacheInfo> cacheDeque_; 429 std::mutex mutex_; 430 }; 431 432 TlsSocket::OnErrorCallback onErrorCallback_; 433 434 private: 435 bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options); 436 bool CreatTlsContext(); 437 bool StartShakingHands(const TlsSocket::TLSConnectOptions &options); 438 bool GetRemoteCertificateFromPeer(); 439 bool SetRemoteCertRawData(); 440 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 441 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 442 const X509 *x509Certificates); 443 444 private: 445 ssl_st *ssl_ = nullptr; 446 X509 *peerX509_ = nullptr; 447 int32_t socketFd_ = 0; 448 449 TlsSocket::TLSContextServer tlsContext_; 450 TlsSocket::TLSConfiguration connectionConfiguration_; 451 Socket::NetAddress address_; 452 Socket::NetAddress localAddress_; 453 TlsSocket::X509CertRawData remoteRawData_; 454 455 std::string hostName_; 456 std::string remoteCert_; 457 std::string keyPass_; 458 459 std::vector<std::string> signatureAlgorithms_; 460 std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr; 461 462 std::shared_ptr<EventManager> eventManager_ = nullptr; 463 int32_t clientID_ = 0; 464 OnMessageCallback onMessageCallback_; 465 std::shared_ptr<DataCache> dataCache_ = std::make_shared<DataCache>(); 466 }; 467 468 private: 469 void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 470 int RecvRemoteInfo(int socketFd, int index); 471 void RemoveConnect(int socketFd); 472 void AddConnect(int socketFd, std::shared_ptr<Connection> connection); 473 void CallListenCallback(int32_t err, ListenCallback callback); 474 void CallOnErrorCallback(int32_t err, const std::string &errString); 475 476 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback); 477 void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager); 478 void CallSendCallback(int32_t err, TlsSocket::SendCallback callback); 479 bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback); 480 void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback); 481 void MakeIpSocket(sa_family_t family); 482 void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 483 socklen_t *len); 484 static constexpr const size_t MAX_ERROR_LEN = 128; 485 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 486 487 void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions); 488 489 private: 490 std::mutex mutex_; 491 std::mutex connectMutex_; 492 int listenSocketFd_ = -1; 493 Socket::NetAddress address_; 494 Socket::NetAddress localAddress_; 495 496 std::map<int, std::shared_ptr<Connection>> clientIdConnections_; 497 TlsSocket::TLSConfiguration TLSServerConfiguration_; 498 499 OnConnectCallback onConnectCallback_; 500 TlsSocket::OnErrorCallback onErrorCallback_; 501 502 bool GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress); 503 void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId); 504 void DropFdFromPollList(int &fd_index); 505 void InitPollList(int &listendFd); 506 507 struct pollfd fds_[USER_LIMIT + 1]; 508 509 bool isRunning_; 510 511 public: 512 std::shared_ptr<Connection> GetConnectionByClientID(int clientid); 513 int GetConnectionClientCount(); 514 515 std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager); 516 void CloseConnectionByEventManager(EventManager *eventManager); 517 void DeleteConnectionByEventManager(EventManager *eventManager); 518 void SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID, 519 int connectFD, std::shared_ptr<Connection> &connection); 520 }; 521 } // namespace TlsSocketServer 522 } // namespace NetStack 523 } // namespace OHOS 524 525 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 526