1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "net_address.h"
17 #include "secure_data.h"
18 #include "socket_error.h"
19 #include "socket_state_base.h"
20 #include "tls.h"
21 #include "tls_certificate.h"
22 #include "tls_configuration.h"
23 #include "tls_key.h"
24 #include "tls_socket.h"
25 #include "tls_utils_test.h"
26 
27 namespace OHOS {
28 namespace NetStack {
29 namespace TlsSocket {
MockCertChainNetAddress(Socket::NetAddress & address)30 void MockCertChainNetAddress(Socket::NetAddress &address)
31 {
32     address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
33     address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
34     address.SetFamilyBySaFamily(AF_INET);
35 }
36 
MockCertChainParamOptions(Socket::NetAddress & address,TLSSecureOptions & secureOption,TLSConnectOptions & options)37 void MockCertChainParamOptions(Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
38 {
39     secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM_CHAIN)));
40     secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT_CHAIN));
41 
42     MockCertChainNetAddress(address);
43     options.SetNetAddress(address);
44     options.SetTlsSecureOptions(secureOption);
45 }
46 
SetCertChainHwTestShortParam(TLSSocket & server)47 void SetCertChainHwTestShortParam(TLSSocket &server)
48 {
49     TLSConnectOptions options;
50     TLSSecureOptions secureOption;
51     std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
52         TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
53     secureOption.SetCaChain(caVec);
54     Socket::NetAddress address;
55     MockCertChainParamOptions(address, secureOption, options);
56 
57     server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
58     server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
59 }
60 
SetCertChainHwTestLongParam(TLSSocket & server)61 void SetCertChainHwTestLongParam(TLSSocket &server)
62 {
63     Socket::NetAddress address;
64     std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
65         TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
66     TLSSecureOptions secureOption;
67     secureOption.SetCaChain(caVec);
68     std::string protocolV13 = "TLSv1.3";
69     std::vector<std::string> protocolVec = { protocolV13 };
70     secureOption.SetProtocolChain(protocolVec);
71     secureOption.SetCipherSuite("AES256-SHA256");
72 
73     TLSConnectOptions options;
74     MockCertChainParamOptions(address, secureOption, options);
75 
76     server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
77     server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
78 }
79 
80 HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
81 {
82     if (!TlsUtilsTest::CheckCaPathChainExistence("bindInterface")) {
83         return;
84     }
85 
86     TLSSocket testServer;
87     Socket::NetAddress address;
88     MockCertChainNetAddress(address);
__anon6010aead0502(int32_t errCode) 89     testServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
90 }
91 
92 HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
93 {
94     if (!TlsUtilsTest::CheckCaPathChainExistence("connectInterface")) {
95         return;
96     }
97     TLSSocket certChainServer;
98     SetCertChainHwTestShortParam(certChainServer);
99 
100     const std::string data = "how do you do? this is connectInterface";
101     Socket::TCPSendOptions tcpSendOptions;
102     tcpSendOptions.SetData(data);
__anon6010aead0602(int32_t errCode) 103     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
104     sleep(2);
105 
__anon6010aead0702(int32_t errCode) 106     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
107     sleep(2);
108 }
109 
110 HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
111 {
112     if (!TlsUtilsTest::CheckCaPathChainExistence("closeInterface")) {
113         return;
114     }
115     TLSSocket certChainServer;
116     SetCertChainHwTestShortParam(certChainServer);
117 
118     const std::string data = "how do you do? this is closeInterface";
119     Socket::TCPSendOptions tcpSendOptions;
120     tcpSendOptions.SetData(data);
121 
__anon6010aead0802(int32_t errCode) 122     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
123     sleep(2);
124 
__anon6010aead0902(int32_t errCode) 125     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
126 }
127 
128 HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
129 {
130     if (!TlsUtilsTest::CheckCaPathChainExistence("sendInterface")) {
131         return;
132     }
133     TLSSocket certChainServer;
134     SetCertChainHwTestShortParam(certChainServer);
135 
136     const std::string data = "how do you do? this is sendInterface";
137     Socket::TCPSendOptions tcpSendOptions;
138     tcpSendOptions.SetData(data);
139 
__anon6010aead0a02(int32_t errCode) 140     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
141     sleep(2);
142 
__anon6010aead0b02(int32_t errCode) 143     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
144 }
145 
146 HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
147 {
148     if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
149         return;
150     }
151     TLSSocket certChainServer;
152     TLSConnectOptions options;
153     TLSSecureOptions secureOption;
154     Socket::NetAddress address;
155     std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
156         TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
157     secureOption.SetCaChain(caVec);
158     MockCertChainParamOptions(address, secureOption, options);
159 
__anon6010aead0c02(int32_t errCode) 160     certChainServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anon6010aead0d02(int32_t errCode) 161     certChainServer.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
162 
163     Socket::NetAddress netAddress;
__anon6010aead0e02(int32_t errCode, const Socket::NetAddress &address) 164     certChainServer.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
165         EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
166         netAddress.SetPort(address.GetPort());
167         netAddress.SetFamilyBySaFamily(address.GetSaFamily());
168         netAddress.SetAddress(address.GetAddress());
169     });
170     EXPECT_STREQ(netAddress.GetAddress().c_str(), TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
171     EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
172     EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
173 
174     const std::string data = "how do you do? this is getRemoteAddressInterface";
175     Socket::TCPSendOptions tcpSendOptions;
176     tcpSendOptions.SetData(data);
177 
__anon6010aead0f02(int32_t errCode) 178     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
179 
__anon6010aead1002(int32_t errCode) 180     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
181 }
182 
183 HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
184 {
185     if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
186         return;
187     }
188     TLSSocket certChainServer;
189     SetCertChainHwTestShortParam(certChainServer);
190 
191     Socket::SocketStateBase TlsSocketstate;
__anon6010aead1102(int32_t errCode, const Socket::SocketStateBase &state) 192     certChainServer.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
193         EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
194         TlsSocketstate = state;
195     });
196     std::cout << "TlsSocketCertChainTest TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
197     EXPECT_TRUE(TlsSocketstate.IsBound());
198     EXPECT_TRUE(!TlsSocketstate.IsClose());
199     EXPECT_TRUE(TlsSocketstate.IsConnected());
200 
201     const std::string tlsSocketCertChainTestData = "how do you do? this is getStateInterface";
202     Socket::TCPSendOptions tcpSendOptions;
203     tcpSendOptions.SetData(tlsSocketCertChainTestData);
__anon6010aead1202(int32_t errCode) 204     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
205 
206     sleep(2);
207 
__anon6010aead1302(int32_t errCode) 208     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
209 }
210 
211 HWTEST_F(TlsSocketTest, getCertificateInterface, testing::ext::TestSize.Level2)
212 {
213     if (!TlsUtilsTest::CheckCaPathChainExistence("getCertificateInterface")) {
214         return;
215     }
216     TLSSocket certChainServer;
217     SetCertChainHwTestShortParam(certChainServer);
218     Socket::TCPSendOptions tcpSendOptions;
219     const std::string data = "how do you do? This is UT test getCertificateInterface";
220 
221     tcpSendOptions.SetData(data);
__anon6010aead1402(int32_t errCode) 222     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
223 
224     certChainServer.GetCertificate(
__anon6010aead1502(int32_t errCode, const X509CertRawData &cert) 225         [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
226 
227     sleep(2);
__anon6010aead1602(int32_t errCode) 228     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
229 }
230 
231 HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
232 {
233     if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteCertificateInterface")) {
234         return;
235     }
236     TLSSocket certChainServer;
237     SetCertChainHwTestShortParam(certChainServer);
238     Socket::TCPSendOptions tcpSendOptions;
239     const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
240     tcpSendOptions.SetData(data);
241 
__anon6010aead1702(int32_t errCode) 242     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
243 
244     certChainServer.GetRemoteCertificate(
__anon6010aead1802(int32_t errCode, const X509CertRawData &cert) 245         [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
246 
247     sleep(2);
__anon6010aead1902(int32_t errCode) 248     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
249 }
250 
251 HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
252 {
253     if (!TlsUtilsTest::CheckCaPathChainExistence("protocolInterface")) {
254         return;
255     }
256     TLSSocket certChainServer;
257     SetCertChainHwTestLongParam(certChainServer);
258 
259     const std::string data = "how do you do? this is protocolInterface.";
260     Socket::TCPSendOptions tcpSendOptions;
261     tcpSendOptions.SetData(data);
262 
__anon6010aead1a02(int32_t errCode) 263     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
264     std::string getProtocolVal = "";
__anon6010aead1b02(int32_t errCode, const std::string &protocol) 265     certChainServer.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
266         EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
267         getProtocolVal = protocol;
268     });
269     EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
270     sleep(2);
271 
__anon6010aead1c02(int32_t errCode) 272     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
273 }
274 
275 HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
276 {
277     if (!TlsUtilsTest::CheckCaPathChainExistence("getCipherSuiteInterface")) {
278         return;
279     }
280     TLSSocket certChainServer;
281     SetCertChainHwTestLongParam(certChainServer);
282 
283     bool successFlag = false;
284     const std::string data = "how do you do? This is getCipherSuiteInterface";
285     Socket::TCPSendOptions testTcpSendOptions;
286     testTcpSendOptions.SetData(data);
__anon6010aead1d02(int32_t errCode) 287     certChainServer.Send(testTcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
288 
289     std::vector<std::string> testCipherSuite;
__anon6010aead1e02(int32_t errCode, const std::vector<std::string> &suite) 290     certChainServer.GetCipherSuite([&testCipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
291         EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
292         testCipherSuite = suite;
293     });
294 
295     for (auto const &iter : testCipherSuite) {
296         if (iter == "AES256-SHA256") {
297             successFlag = true;
298         }
299     }
300 
301     EXPECT_TRUE(successFlag);
302     sleep(2);
303 
__anon6010aead1f02(int32_t errCode) 304     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
305 }
306 
307 HWTEST_F(TlsSocketTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
308 {
309     if (!TlsUtilsTest::CheckCaPathChainExistence("getSignatureAlgorithmsInterface")) {
310         return;
311     }
312 
313     TLSSocket certChainServer;
314     TLSSecureOptions secureOption;
315     std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
316     secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
317     std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
318         TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
319     secureOption.SetCaChain(caVec);
320     std::string protocolV13 = "TLSv1.3";
321     std::vector<std::string> protocolVec = {protocolV13};
322     secureOption.SetProtocolChain(protocolVec);
323     Socket::NetAddress address;
324     TLSConnectOptions options;
325     MockCertChainParamOptions(address, secureOption, options);
326 
327     bool successFlag = false;
__anon6010aead2002(int32_t errCode) 328     certChainServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
__anon6010aead2102(int32_t errCode) 329     certChainServer.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
330 
331     const std::string data = "how do you do? this is getSignatureAlgorithmsInterface";
332     Socket::TCPSendOptions testOptions;
333     testOptions.SetData(data);
__anon6010aead2202(int32_t errCode) 334     certChainServer.Send(testOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
335 
336     std::vector<std::string> testSignatureAlgorithms;
337     certChainServer.GetSignatureAlgorithms(
__anon6010aead2302(int32_t errCode, const std::vector<std::string> &algorithms) 338         [&testSignatureAlgorithms](int32_t errCode, const std::vector<std::string> &algorithms) {
339             if (errCode == TLSSOCKET_SUCCESS) {
340                 testSignatureAlgorithms = algorithms;
341             }
342         });
343 
344     for (auto const &iter : testSignatureAlgorithms) {
345         if (iter == "ECDSA+SHA256") {
346             successFlag = true;
347         }
348     }
349     EXPECT_TRUE(successFlag);
350     sleep(2);
__anon6010aead2402(int32_t errCode) 351     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
352 }
353 
354 HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
355 {
356     if (!TlsUtilsTest::CheckCaPathChainExistence("tlsSocketOnMessageData")) {
357         return;
358     }
359     std::string getData = "server->client";
360     TLSSocket certChainServer;
361     SetCertChainHwTestLongParam(certChainServer);
__anon6010aead2502(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) 362     certChainServer.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
363         if (data == getData) {
364             EXPECT_TRUE(true);
365         } else {
366             EXPECT_TRUE(false);
367         }
368     });
369 
370     const std::string data = "how do you do? this is tlsSocketOnMessageData";
371     Socket::TCPSendOptions tcpSendOptions;
372     tcpSendOptions.SetData(data);
__anon6010aead2602(int32_t errCode) 373     certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
374 
375     sleep(2);
__anon6010aead2702(int32_t errCode) 376     (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
377 }
378 } // namespace TlsSocket
379 } // namespace NetStack
380 } // namespace OHOS
381