1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <algorithm>
#include <fstream>
#include <gtest/gtest.h>
#include <iostream>
#include <openssl/rsa.h>
#include <openssl/ssl.h>
#include <sstream>
#include <string>
#include <string_view>
#include <unistd.h>
#include <vector>
26 
27 #include "net_address.h"
28 #include "secure_data.h"
29 #include "socket_error.h"
30 #include "socket_state_base.h"
31 #include "tls.h"
32 #include "tls_certificate.h"
33 #include "tls_configuration.h"
34 #include "tls_key.h"
35 #include "tls_socket_server.h"
36 #include "tls_socket.h"
37 
38 namespace OHOS {
39 namespace NetStack {
40 namespace TlsSocketServer {
41 namespace {
// Paths of on-device test inputs: the CA certificate whose presence gates every test case,
// and two text files holding an IP address and a port number respectively.
// They are string literals, so each view is backed by a NUL-terminated array and .data() can be
// handed to C APIs such as access().
constexpr std::string_view CA_DER = "/data/ClientCert/ca.crt";
constexpr std::string_view IP_ADDRESS = "/data/Ip/address.txt";
constexpr std::string_view PORT = "/data/Ip/port.txt";
45 
CheckCaFileExistence(const char * function)46 inline bool CheckCaFileExistence(const char *function)
47 {
48     if (access(CA_DER.data(), 0)) {
49         std::cout << "CA file does not exist! (" << function << ")";
50         return false;
51     }
52     return true;
53 }
54 
/**
 * Reads the entire content of a text file.
 *
 * @param fileName path of the file to read.
 * @return the file content verbatim, or an empty string if the file cannot be opened.
 */
std::string ChangeToFile(std::string_view fileName)
{
    // A std::string_view is not guaranteed to be NUL-terminated, and std::ifstream has no portable
    // overload taking one (libstdc++ accepts only const char*, std::string or filesystem::path),
    // so materialize an owning std::string for the file name.
    std::ifstream file{std::string(fileName)};
    if (!file.is_open()) {
        return {};
    }
    std::stringstream ss;
    ss << file.rdbuf();
    // The stream is closed by its destructor.
    return ss.str();
}
65 
66 
/**
 * Normalizes an IP address read from a text file by removing trailing whitespace.
 *
 * Only trailing spaces, tabs, CR and LF are stripped, instead of blindly dropping the last
 * character. A file without a trailing newline therefore keeps its last digit, and a CRLF file
 * does not leave a stray '\r' behind.
 *
 * @param ip raw text, e.g. the content of the IP address file.
 * @return @p ip without trailing whitespace; empty if @p ip is empty or all whitespace.
 */
std::string GetIp(std::string ip)
{
    const std::string::size_type lastKept = ip.find_last_not_of(" \t\r\n");
    ip.erase(lastKept == std::string::npos ? 0 : lastKept + 1);
    return ip;
}
71 
72 } // namespace
73 class TlsSocketServerTest : public testing::Test {
74 public:
SetUpTestCase()75     static void SetUpTestCase() {}
76 
TearDownTestCase()77     static void TearDownTestCase() {}
78 
SetUp()79     virtual void SetUp() {}
80 
TearDown()81     virtual void TearDown() {}
82 };
83 
84 HWTEST_F(TlsSocketServerTest, ListenInterface, testing::ext::TestSize.Level2)
85 {
86     if (!CheckCaFileExistence("ListenInterface")) {
87         return;
88     }
89     TLSSocketServer server;
90     TlsSocket::TLSConnectOptions tlsListenOptions;
91 
__anon487d28ab0202(int32_t errCode) 92     server.Listen(tlsListenOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
93 }
94 
95 HWTEST_F(TlsSocketServerTest, sendInterface, testing::ext::TestSize.Level2)
96 {
97     if (!CheckCaFileExistence("sendInterface")) {
98         return;
99     }
100 
101     TLSSocketServer server;
102 
103     TLSServerSendOptions tlsServerSendOptions;
104 
105     const std::string data = "how do you do? this is sendInterface";
106     tlsServerSendOptions.SetSendData(data);
__anon487d28ab0302(int32_t errCode) 107     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
108 }
109 
110 HWTEST_F(TlsSocketServerTest, closeInterface, testing::ext::TestSize.Level2)
111 {
112     if (!CheckCaFileExistence("closeInterface")) {
113         return;
114     }
115 
116     TLSSocketServer server;
117 
118     const std::string data = "how do you do? this is closeInterface";
119     TLSServerSendOptions tlsServerSendOptions;
120     tlsServerSendOptions.SetSendData(data);
121     int socketFd =  tlsServerSendOptions.GetSocket();
122 
__anon487d28ab0402(int32_t errCode) 123     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
124     sleep(2);
125 
__anon487d28ab0502(int32_t errCode) 126     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
127 }
128 
129 HWTEST_F(TlsSocketServerTest, stopInterface, testing::ext::TestSize.Level2)
130 {
131     if (!CheckCaFileExistence("stopInterface")) {
132         return;
133     }
134 
135     TLSSocketServer server;
136 
137     TLSServerSendOptions tlsServerSendOptions;
138     int socketFd =  tlsServerSendOptions.GetSocket();
139 
140 
141     const std::string data = "how do you do? this is stopInterface";
142     tlsServerSendOptions.SetSendData(data);
__anon487d28ab0602(int32_t errCode) 143     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
144     sleep(2);
145 
146 
__anon487d28ab0702(int32_t errCode) 147     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
148     sleep(2);
149 
150 
__anon487d28ab0802(int32_t errCode) 151     server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
152 }
153 
154 HWTEST_F(TlsSocketServerTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
155 {
156     if (!CheckCaFileExistence("getRemoteAddressInterface")) {
157         return;
158     }
159 
160     TLSSocketServer server;
161 
162     TLSServerSendOptions tlsServerSendOptions;
163     int socketFd = tlsServerSendOptions.GetSocket();
164     Socket::NetAddress address;
165 
166     address.SetAddress(GetIp(ChangeToFile(IP_ADDRESS)));
167     address.SetPort(std::atoi(ChangeToFile(PORT).c_str()));
168     address.SetFamilyBySaFamily(AF_INET);
169 
170     Socket::NetAddress netAddress;
171     server.GetRemoteAddress(socketFd, [&netAddress](int32_t errCode,
__anon487d28ab0902(int32_t errCode, const Socket::NetAddress &address) 172         const Socket::NetAddress &address) {
173     EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS);
174     netAddress.SetAddress(address.GetAddress());
175     netAddress.SetPort(address.GetPort());
176     netAddress.SetFamilyBySaFamily(address.GetSaFamily());
177     });
178 
179     const std::string data = "how do you do? this is getRemoteAddressInterface";
180     tlsServerSendOptions.SetSendData(data);
__anon487d28ab0a02(int32_t errCode) 181     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
182     sleep(2);
183 
__anon487d28ab0b02(int32_t errCode) 184     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
185     sleep(2);
186 
__anon487d28ab0c02(int32_t errCode) 187     server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
188 }
189 
190 HWTEST_F(TlsSocketServerTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
191 {
192     if (!CheckCaFileExistence("getRemoteCertificateInterface")) {
193         return;
194     }
195 
196     TLSSocketServer server;
197 
198     TLSServerSendOptions tlsServerSendOptions;
199     int socketFd = tlsServerSendOptions.GetSocket();
200 
201 
202     const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
203     tlsServerSendOptions.SetSendData(data);
__anon487d28ab0d02(int32_t errCode) 204     server.Send(tlsServerSendOptions, [](int32_t errCode) {
205         EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
206     sleep(2);
207 
__anon487d28ab0e02(int32_t errCode, const TlsSocket::X509CertRawData &cert) 208     server.GetRemoteCertificate(socketFd, [](int32_t errCode, const TlsSocket::X509CertRawData &cert) {
209         EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
210 
__anon487d28ab0f02(int32_t errCode) 211     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
212     sleep(2);
213 
__anon487d28ab1002(int32_t errCode) 214     server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
215 }
216 
217 HWTEST_F(TlsSocketServerTest, getCertificateInterface, testing::ext::TestSize.Level2)
218 {
219     if (!CheckCaFileExistence("getCertificateInterface")) {
220         return;
221     }
222     TLSSocketServer server;
223 
224     const std::string data = "how do you do? This is UT test getCertificateInterface";
225     TLSServerSendOptions tlsServerSendOptions;
226     tlsServerSendOptions.SetSendData(data);
227     int socketFd = tlsServerSendOptions.GetSocket();
__anon487d28ab1102(int32_t errCode) 228     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
229 
230     server.GetCertificate(
__anon487d28ab1202(int32_t errCode, const TlsSocket::X509CertRawData &cert) 231         [](int32_t errCode, const TlsSocket::X509CertRawData &cert) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
232 
233     sleep(2);
__anon487d28ab1302(int32_t errCode) 234     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
235 }
236 
237 HWTEST_F(TlsSocketServerTest, protocolInterface, testing::ext::TestSize.Level2)
238 {
239     if (!CheckCaFileExistence("protocolInterface")) {
240         return;
241     }
242     TLSSocketServer server;
243 
244     const std::string data = "how do you do? this is protocolInterface";
245     TLSServerSendOptions tlsServerSendOptions;
246     tlsServerSendOptions.SetSendData(data);
247 
248     int socketFd = tlsServerSendOptions.GetSocket();
__anon487d28ab1402(int32_t errCode) 249     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
250     std::string getProtocolVal;
__anon487d28ab1502(int32_t errCode, const std::string &protocol) 251     server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
252         EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS);
253         getProtocolVal = protocol;
254     });
255     EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
256 
257     Socket::SocketStateBase stateBase;
__anon487d28ab1602(int32_t errCode, Socket::SocketStateBase state) 258     server.GetState([&stateBase](int32_t errCode, Socket::SocketStateBase state) {
259         if (TlsSocket::TLSSOCKET_SUCCESS) {
260             EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS);
261             stateBase.SetIsBound(state.IsBound());
262             stateBase.SetIsClose(state.IsClose());
263             stateBase.SetIsConnected(state.IsConnected());
264         }
265     });
266     EXPECT_TRUE(stateBase.IsConnected());
267     sleep(2);
268 
__anon487d28ab1702(int32_t errCode) 269     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
270 }
271 
272 HWTEST_F(TlsSocketServerTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
273 {
274     if (!CheckCaFileExistence("getSignatureAlgorithmsInterface")) {
275         return;
276     }
277 
278     TLSSocketServer server;
279     TlsSocket::TLSSecureOptions secureOption;
280 
281     const std::string data = "how do you do? this is getSigntureAlgorithmsInterface";
282     TLSServerSendOptions tlsServerSendOptions;
283     tlsServerSendOptions.SetSendData(data);
284 
285     int socketFd = tlsServerSendOptions.GetSocket();
__anon487d28ab1802(int32_t errCode) 286     server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
287     sleep(2);
288 
289     bool testFlag = false;
290     std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
291     secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
292     std::vector<std::string> testSignatureAlgorithms;
293     server.GetSignatureAlgorithms(socketFd, [&testSignatureAlgorithms](int32_t errCode,
__anon487d28ab1902(int32_t errCode, const std::vector<std::string> &algorithms) 294         const std::vector<std::string> &algorithms) {
295         if (errCode == TlsSocket::TLSSOCKET_SUCCESS) {
296             testSignatureAlgorithms = algorithms;
297         }
298     });
299     for (auto const &iter : testSignatureAlgorithms) {
300         if (iter == "ECDSA+SHA256") {
301             testFlag = true;
302         }
303     }
304     EXPECT_TRUE(testFlag);
305     sleep(2);
306 
307 
__anon487d28ab1a02(int32_t errCode) 308     (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
309 }
310 
311 
312 } //TlsSocketServer
313 } //NetStack
314 } //OHOS
315