1 /* 2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <thread> 25 #include <tuple> 26 #include <unistd.h> 27 #include <vector> 28 29 #include "extra_options_base.h" 30 #include "net_address.h" 31 #include "socket_error.h" 32 #include "socket_remote_info.h" 33 #include "socket_state_base.h" 34 #include "tcp_connect_options.h" 35 #include "tcp_extra_options.h" 36 #include "tcp_send_options.h" 37 #include "tls.h" 38 #include "tls_certificate.h" 39 #include "tls_configuration.h" 40 #include "tls_context.h" 41 #include "tls_key.h" 42 43 namespace OHOS { 44 namespace NetStack { 45 namespace TlsSocket { 46 47 using BindCallback = std::function<void(int32_t errorNumber)>; 48 using ConnectCallback = std::function<void(int32_t errorNumber)>; 49 using SendCallback = std::function<void(int32_t errorNumber)>; 50 using CloseCallback = std::function<void(int32_t errorNumber)>; 51 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 52 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 53 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>; 54 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 55 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 56 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 57 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 58 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 59 using GetSignatureAlgorithmsCallback = 60 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 61 62 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 63 using OnConnectCallback = std::function<void(void)>; 64 using OnCloseCallback = std::function<void(void)>; 65 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 66 67 using CheckServerIdentity = 68 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 69 70 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 71 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 72 73 constexpr size_t MAX_ERR_LEN = 1024; 74 75 /** 76 * Parameters required during communication 77 */ 78 class TLSSecureOptions { 79 public: 80 TLSSecureOptions() = default; 81 ~TLSSecureOptions() = default; 82 83 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 84 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 85 /** 86 * Set root CA Chain to verify the server cert 87 * @param caChain root certificate chain used to validate server certificates 88 */ 89 void SetCaChain(const std::vector<std::string> &caChain); 90 91 /** 92 * Set digital certificate for server verification 93 * @param cert digital certificate sent to the server to verify validity 94 */ 95 void SetCert(const std::string &cert); 96 97 /** 98 * Set key to decrypt server data 99 * @param keyChain key used to decrypt server data 100 */ 101 void SetKey(const SecureData &key); 102 103 /** 104 * Set the password to read the private key 105 * @param keyPass read the password of the private key 106 */ 107 void SetKeyPass(const SecureData &keyPass); 108 109 /** 110 * Set the protocol used in communication 111 * @param protocolChain protocol version number used 112 */ 113 void SetProtocolChain(const std::vector<std::string> &protocolChain); 114 115 /** 116 * Whether the peer cipher suite is preferred for communication 117 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 118 */ 119 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 120 121 /** 122 * Encryption algorithm used in communication 123 * @param signatureAlgorithms encryption algorithm e.g: rsa 124 */ 125 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 126 127 /** 128 * Crypto suite used in communication 129 * @param cipherSuite cipher suite e.g:AES256-SHA256 130 */ 131 void SetCipherSuite(const std::string &cipherSuite); 132 133 /** 134 * Set a revoked certificate 135 * @param crlChain certificate Revocation List 136 */ 137 void SetCrlChain(const std::vector<std::string> &crlChain); 138 139 /** 140 * Get root CA Chain to verify the server cert 141 * @return root CA chain 142 */ 143 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 144 145 /** 146 * Obtain a certificate to send to the server for checking 147 * @return digital certificate obtained 148 */ 149 [[nodiscard]] const std::string &GetCert() const; 150 151 /** 152 * Obtain the private key in the communication process 153 * @return private key during communication 154 */ 155 [[nodiscard]] const SecureData &GetKey() const; 156 157 /** 158 * Get the password to read the private key 159 * @return read the password of the private key 160 */ 161 [[nodiscard]] const SecureData &GetKeyPass() const; 162 163 /** 164 * Get the protocol of the communication process 165 * @return protocol of communication process 166 */ 167 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 168 169 /** 170 * Is the remote cipher suite being used for communication 171 * @return is use Remote Cipher Prefer 172 */ 173 [[nodiscard]] bool UseRemoteCipherPrefer() const; 174 175 /** 176 * Obtain the encryption algorithm used in the communication process 177 * @return encryption algorithm used in communication 178 */ 179 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 180 181 /** 182 * Obtain the cipher suite used in communication 183 * @return crypto suite used in communication 184 */ 185 [[nodiscard]] const std::string &GetCipherSuite() const; 186 187 /** 188 * Get revoked certificate chain 189 * @return revoked certificate chain 190 */ 191 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 192 193 void SetVerifyMode(VerifyMode verifyMode); 194 195 [[nodiscard]] VerifyMode GetVerifyMode() const; 196 197 private: 198 std::vector<std::string> caChain_; 199 std::string cert_; 200 SecureData key_; 201 SecureData keyPass_; 202 std::vector<std::string> protocolChain_; 203 bool useRemoteCipherPrefer_ = false; 204 std::string signatureAlgorithms_; 205 std::string cipherSuite_; 206 std::vector<std::string> crlChain_; 207 VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE; 208 }; 209 210 /** 211 * Some options required during tls connection 212 */ 213 class TLSConnectOptions { 214 public: 215 friend class TLSSocketExec; 216 /** 217 * Communication parameters required for connection establishment 218 * @param address communication parameters during connection 219 */ 220 void SetNetAddress(const Socket::NetAddress &address); 221 222 /** 223 * Parameters required during communication 224 * @param tlsSecureOptions certificate and other relevant parameters 225 */ 226 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 227 228 /** 229 * Set the callback function to check the validity of the server 230 * @param checkServerIdentity callback function passed in by API caller 231 */ 232 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 233 234 /** 235 * Set application layer protocol negotiation 236 * @param alpnProtocols application layer protocol negotiation 237 */ 238 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 239 240 /** 241 * Set whether to skip remote validation 242 * @param skipRemoteValidation flag to choose whether to skip validation 243 */ 244 void SetSkipRemoteValidation(bool skipRemoteValidation); 245 246 /** 247 * Obtain the network address of the communication process 248 * @return network address 249 */ 250 [[nodiscard]] Socket::NetAddress GetNetAddress() const; 251 252 /** 253 * Obtain the parameters required in the communication process 254 * @return certificate and other relevant parameters 255 */ 256 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 257 258 /** 259 * Get the check server ID callback function passed in by the API caller 260 * @return check the server identity callback function 261 */ 262 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 263 264 /** 265 * Obtain the application layer protocol negotiation in the communication process 266 * @return application layer protocol negotiation 267 */ 268 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 269 270 /** 271 * Get the choice of whether to skip remote validaion 272 * @return skipRemoteValidaion result 273 */ 274 [[nodiscard]] bool GetSkipRemoteValidation() const; 275 276 void SetHostName(const std::string &hostName); 277 [[nodiscard]] std::string GetHostName() const; 278 279 private: 280 Socket::NetAddress address_; 281 TLSSecureOptions tlsSecureOptions_; 282 CheckServerIdentity checkServerIdentity_; 283 std::vector<std::string> alpnProtocols_; 284 bool skipRemoteValidation_ = false; 285 std::string hostName_; 286 }; 287 288 /** 289 * TLS socket interface class 290 */ 291 class TLSSocket { 292 public: 293 TLSSocket(const TLSSocket &) = delete; 294 TLSSocket(TLSSocket &&) = delete; 295 296 TLSSocket &operator=(const TLSSocket &) = delete; 297 TLSSocket &operator=(TLSSocket &&) = delete; 298 299 TLSSocket() = default; 300 ~TLSSocket() = default; 301 TLSSocket(int sockFd)302 explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {} 303 304 /** 305 * Create a socket and bind to the address specified by address 306 * @param address ip address 307 * @param callback callback to the caller if bind ok or not 308 */ 309 void Bind(Socket::NetAddress &address, const BindCallback &callback); 310 311 /** 312 * Establish a secure connection based on the created socket 313 * @param tlsConnectOptions some options required during tls connection 314 * @param callback callback to the caller if connect ok or not 315 */ 316 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 317 318 /** 319 * Send data based on the created socket 320 * @param tcpSendOptions some options required during tcp data transmission 321 * @param callback callback to the caller if send ok or not 322 */ 323 void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback); 324 325 /** 326 * Disconnect by releasing the socket when communicating 327 * @param callback callback to the caller 328 */ 329 void Close(const CloseCallback &callback); 330 331 /** 332 * Get the peer network address 333 * @param callback callback to the caller 334 */ 335 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 336 337 /** 338 * Get the status of the current socket 339 * @param callback callback to the caller 340 */ 341 void GetState(const GetStateCallback &callback); 342 343 /** 344 * Gets or sets the options associated with the current socket 345 * @param tcpExtraOptions options associated with the current socket 346 * @param callback callback to the caller 347 */ 348 void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 349 350 /** 351 * Get a local digital certificate 352 * @param callback callback to the caller 353 */ 354 void GetCertificate(const GetCertificateCallback &callback); 355 356 /** 357 * Get the peer digital certificate 358 * @param needChain need chain 359 * @param callback callback to the caller 360 */ 361 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 362 363 /** 364 * Obtain the protocol used in communication 365 * @param callback callback to the caller 366 */ 367 void GetProtocol(const GetProtocolCallback &callback); 368 369 /** 370 * Obtain the cipher suite used in communication 371 * @param callback callback to the caller 372 */ 373 void GetCipherSuite(const GetCipherSuiteCallback &callback); 374 375 /** 376 * Obtain the encryption algorithm used in the communication process 377 * @param callback callback to the caller 378 */ 379 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 380 381 /** 382 * Register a callback which is called when message is received 383 * @param onMessageCallback callback which is called when message is received 384 */ 385 void OnMessage(const OnMessageCallback &onMessageCallback); 386 387 /** 388 * Register the callback that is called when the connection is established 389 * @param onConnectCallback callback invoked when connection is established 390 */ 391 void OnConnect(const OnConnectCallback &onConnectCallback); 392 393 /** 394 * Register the callback that is called when the connection is disconnected 395 * @param onCloseCallback callback invoked when disconnected 396 */ 397 void OnClose(const OnCloseCallback &onCloseCallback); 398 399 /** 400 * Register the callback that is called when an error occurs 401 * @param onErrorCallback callback invoked when an error occurs 402 */ 403 void OnError(const OnErrorCallback &onErrorCallback); 404 405 /** 406 * Unregister the callback which is called when message is received 407 */ 408 void OffMessage(); 409 410 /** 411 * Off Connect 412 */ 413 void OffConnect(); 414 415 /** 416 * Off Close 417 */ 418 void OffClose(); 419 420 /** 421 * Off Error 422 */ 423 void OffError(); 424 425 /** 426 * Get the socket file description of the server 427 */ 428 int GetSocketFd(); 429 430 /** 431 * Set the current socket file description address of the server 432 */ 433 void SetLocalAddress(const Socket::NetAddress &address); 434 435 /** 436 * Get the current socket file description address of the server 437 */ 438 Socket::NetAddress GetLocalAddress(); 439 440 bool GetCloseState(); 441 442 void SetCloseState(bool flag); 443 444 std::mutex &GetCloseLock(); 445 private: 446 class TLSSocketInternal final { 447 public: 448 TLSSocketInternal() = default; 449 ~TLSSocketInternal() = default; 450 451 /** 452 * Establish an encrypted connection on the specified socket 453 * @param sock socket for establishing encrypted connection 454 * @param options some options required during tls connection 455 * @param isExtSock socket fd is originated from external source when constructing tls socket 456 * @return whether the encrypted connection is successfully established 457 */ 458 bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock); 459 460 /** 461 * Set the configuration items for establishing encrypted connections 462 * @param config configuration item when establishing encrypted connection 463 */ 464 void SetTlsConfiguration(const TLSConnectOptions &config); 465 466 /** 467 * Send data through an established encrypted connection 468 * @param data data sent over an established encrypted connection 469 * @return whether the data is successfully sent to the server 470 */ 471 bool Send(const std::string &data); 472 473 /** 474 * Receive the data sent by the server through the established encrypted connection 475 * @param buffer receive the data sent by the server 476 * @param maxBufferSize the size of the data received from the server 477 * @return whether the data sent by the server is successfully received 478 */ 479 int Recv(char *buffer, int maxBufferSize); 480 481 /** 482 * Disconnect encrypted connection 483 * @return whether the encrypted connection was successfully disconnected 484 */ 485 bool Close(); 486 487 /** 488 * Set the application layer negotiation protocol in the encrypted communication process 489 * @param alpnProtocols application layer negotiation protocol 490 * @return set whether the application layer negotiation protocol is successful during encrypted communication 491 */ 492 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 493 494 /** 495 * Storage of server communication related network information 496 * @param remoteInfo communication related network information 497 */ 498 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 499 500 /** 501 * convert the code to ssl error code 502 * @return the value for ssl error code. 503 */ 504 int ConvertSSLError(void); 505 506 /** 507 * Get configuration options for encrypted communication process 508 * @return configuration options for encrypted communication processes 509 */ 510 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 511 512 /** 513 * Obtain the cipher suite during encrypted communication 514 * @return crypto suite used in encrypted communication 515 */ 516 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 517 518 /** 519 * Obtain the peer certificate used in encrypted communication 520 * @return peer certificate used in encrypted communication 521 */ 522 [[nodiscard]] std::string GetRemoteCertificate() const; 523 524 /** 525 * Obtain the peer certificate used in encrypted communication 526 * @return peer certificate serialization data used in encrypted communication 527 */ 528 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 529 530 /** 531 * Obtain the certificate used in encrypted communication 532 * @return certificate serialization data used in encrypted communication 533 */ 534 [[nodiscard]] const X509CertRawData &GetCertificate() const; 535 536 /** 537 * Get the encryption algorithm used in encrypted communication 538 * @return encryption algorithm used in encrypted communication 539 */ 540 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 541 542 /** 543 * Obtain the communication protocol used in encrypted communication 544 * @return communication protocol used in encrypted communication 545 */ 546 [[nodiscard]] std::string GetProtocol() const; 547 548 /** 549 * Set the information about the shared signature algorithm supported by peers during encrypted communication 550 * @return information about peer supported shared signature algorithms 551 */ 552 [[nodiscard]] bool SetSharedSigals(); 553 554 /** 555 * Obtain the ssl used in encrypted communication 556 * @return SSL used in encrypted communication 557 */ 558 [[nodiscard]] ssl_st *GetSSL(); 559 560 private: 561 bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd); 562 bool StartTlsConnected(const TLSConnectOptions &options); 563 bool CreatTlsContext(); 564 bool StartShakingHands(const TLSConnectOptions &options); 565 bool GetRemoteCertificateFromPeer(); 566 bool SetRemoteCertRawData(); 567 bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize); 568 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 569 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 570 const X509 *x509Certificates); 571 572 private: 573 std::mutex mutexForSsl_; 574 ssl_st *ssl_ = nullptr; 575 X509 *peerX509_ = nullptr; 576 uint16_t port_ = 0; 577 sa_family_t family_ = 0; 578 int32_t socketDescriptor_ = 0; 579 580 TLSContext tlsContext_; 581 TLSConfiguration configuration_; 582 Socket::NetAddress address_; 583 X509CertRawData remoteRawData_; 584 585 std::string hostName_; 586 std::string remoteCert_; 587 588 std::vector<std::string> signatureAlgorithms_; 589 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 590 }; 591 592 private: 593 TLSSocketInternal tlsSocketInternal_; 594 595 static std::string MakeAddressString(sockaddr *addr); 596 597 static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 598 socklen_t *len); 599 600 void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo); 601 void CallOnConnectCallback(); 602 void CallOnCloseCallback(); 603 void CallOnErrorCallback(int32_t err, const std::string &errString); 604 605 void CallBindCallback(int32_t err, BindCallback callback); 606 void CallConnectCallback(int32_t err, ConnectCallback callback); 607 void CallSendCallback(int32_t err, SendCallback callback); 608 void CallCloseCallback(int32_t err, CloseCallback callback); 609 void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address, 610 GetRemoteAddressCallback callback); 611 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback); 612 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 613 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 614 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 615 GetRemoteCertificateCallback callback); 616 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 617 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 618 GetCipherSuiteCallback callback); 619 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 620 GetSignatureAlgorithmsCallback callback); 621 622 int ReadMessage(); 623 void StartReadMessage(); 624 625 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 626 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 627 628 [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const; 629 [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const; 630 631 void MakeIpSocket(sa_family_t family); 632 633 template<class T> DealCallback(int32_t err,T & callback)634 void DealCallback(int32_t err, T &callback) 635 { 636 T func = nullptr; 637 { 638 std::lock_guard<std::mutex> lock(mutex_); 639 if (callback) { 640 func = callback; 641 } 642 } 643 644 if (func) { 645 func(err); 646 } 647 } 648 649 private: 650 static constexpr const size_t MAX_ERROR_LEN = 128; 651 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 652 653 OnMessageCallback onMessageCallback_; 654 OnConnectCallback onConnectCallback_; 655 OnCloseCallback onCloseCallback_; 656 OnErrorCallback onErrorCallback_; 657 658 std::mutex mutex_; 659 std::mutex recvMutex_; 660 std::mutex cvMutex_; 661 bool isRunning_ = false; 662 bool isRunOver_ = true; 663 std::condition_variable cvSslFree_; 664 int sockFd_ = -1; 665 bool isExtSock_ = false; 666 Socket::NetAddress localAddress_; 667 bool isClosed = false; 668 std::mutex mutexForClose_; 669 }; 670 } // namespace TlsSocket 671 } // namespace NetStack 672 } // namespace OHOS 673 674 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 675