1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <gtest/gtest.h>
17 #include <iostream>
18 
19 #include <openssl/ssl.h>
20 
21 #define private public
22 #include "accesstoken_kit.h"
23 #include "tls_socket.h"
24 #include "socket_remote_info.h"
25 #include "token_setproc.h"
26 #include "tls.h"
27 #include "TlsTest.h"
28 
29 namespace OHOS {
30 namespace NetStack {
31 namespace TlsSocket {
32 namespace {
33 using namespace testing::ext;
34 using namespace Security::AccessToken;
35 using Security::AccessToken::AccessTokenID;
// File-local test constants. They live in the anonymous namespace, so they
// already have internal linkage without an explicit `static`.

// Credentials / protocol selection used when building TLS options.
constexpr const char *KEY_PASS = "";
constexpr const char *PROTOCOL12 = "TLSv1.2";
constexpr const char *PROTOCOL13 = "TLSv1.3";
constexpr const char *ALPN_PROTOCOL = "http/1.1";
constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256";
constexpr const char *CIPHER_SUITE = "AES256-SHA256";

// Peer endpoint used by the tests.
constexpr const char *IP_ADDRESS = "127.0.0.1";
constexpr int PORT = 7838;

// Payloads and sizes; not every test in this file references all of them.
constexpr const char *SEND_DATA = "How do you do";
constexpr const char *SEND_DATA_EMPTY = "";
constexpr size_t MAX_BUFFER_SIZE = 8192;
constexpr int SOCKET_FD = 5;
constexpr int SSL_ERROR_RETURN = -1;
49 
BaseOption()50 TLSConnectOptions BaseOption()
51 {
52     (void)SEND_DATA;
53     (void)SEND_DATA_EMPTY;
54     (void)MAX_BUFFER_SIZE;
55     (void)SOCKET_FD;
56     (void)SSL_ERROR_RETURN;
57     TLSSecureOptions secureOption;
58     SecureData structureData(PRI_KEY_FILE);
59     secureOption.SetKey(structureData);
60     std::vector<std::string> caChain;
61     caChain.push_back(CA_CRT_FILE);
62     secureOption.SetCaChain(caChain);
63     secureOption.SetCert(CLIENT_FILE);
64     secureOption.SetCipherSuite(CIPHER_SUITE);
65     secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
66     std::vector<std::string> protocol;
67     protocol.push_back(PROTOCOL13);
68     secureOption.SetProtocolChain(protocol);
69 
70     TLSConnectOptions connectOptions;
71     connectOptions.SetTlsSecureOptions(secureOption);
72     Socket::NetAddress netAddress;
73     netAddress.SetAddress(IP_ADDRESS);
74     netAddress.SetPort(0);
75     netAddress.SetFamilyBySaFamily(AF_INET);
76     connectOptions.SetNetAddress(netAddress);
77     std::vector<std::string> alpnProtocols;
78     alpnProtocols.push_back(ALPN_PROTOCOL);
79     connectOptions.SetAlpnProtocols(alpnProtocols);
80     return connectOptions;
81 }
82 
83 HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
84                                .userID = 1,
85                                .instIndex = 0,
86                                .appIDDesc = "test",
87                                .isSystemApp = true};
88 
89 PermissionDef testPermDef = {
90     .permissionName = "ohos.permission.INTERNET",
91     .bundleName = "TlsSocketBranchTest",
92     .grantMode = 1,
93     .label = "label",
94     .labelId = 1,
95     .description = "Test Tls Socket Branch",
96     .descriptionId = 1,
97     .availableLevel = APL_SYSTEM_BASIC,
98 };
99 
100 PermissionStateFull testState = {
101     .grantFlags = {2},
102     .grantStatus = {PermissionState::PERMISSION_GRANTED},
103     .isGeneral = true,
104     .permissionName = "ohos.permission.INTERNET",
105     .resDeviceID = {"local"},
106 };
107 
108 HapPolicyParams testPolicyPrams = {
109     .apl = APL_SYSTEM_BASIC,
110     .domain = "test.domain",
111     .permList = {testPermDef},
112     .permStateList = {testState},
113 };
114 } // namespace
115 
116 class AccessToken {
117 public:
AccessToken()118     AccessToken() : currentID_(GetSelfTokenID())
119     {
120         AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
121         accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
122         SetSelfTokenID(tokenIdEx.tokenIDEx);
123     }
~AccessToken()124     ~AccessToken()
125     {
126         AccessTokenKit::DeleteToken(accessID_);
127         SetSelfTokenID(currentID_);
128     }
129 
130 private:
131     AccessTokenID currentID_;
132     AccessTokenID accessID_ = 0;
133 };
134 
135 class TlsSocketBranchTest : public testing::Test {
136 public:
SetUpTestCase()137     static void SetUpTestCase() {}
138 
TearDownTestCase()139     static void TearDownTestCase() {}
140 
SetUp()141     virtual void SetUp() {}
142 
TearDown()143     virtual void TearDown() {}
144 };
145 
146 HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2)
147 {
148     TLSSecureOptions secureOption;
149     SecureData structureData(PRI_KEY_FILE);
150     secureOption.SetKey(structureData);
151 
152     SecureData keyPass(KEY_PASS);
153     secureOption.SetKeyPass(keyPass);
154     SecureData secureData = secureOption.GetKey();
155     EXPECT_EQ(structureData.Length(), strlen(PRI_KEY_FILE));
156     std::vector<std::string> caChain;
157     caChain.push_back(CA_CRT_FILE);
158     secureOption.SetCaChain(caChain);
159     std::vector<std::string> getCaChain = secureOption.GetCaChain();
160     EXPECT_NE(getCaChain.data(), nullptr);
161 
162     secureOption.SetCert(CLIENT_FILE);
163     std::string getCert = secureOption.GetCert();
164     EXPECT_NE(getCert.data(), nullptr);
165 
166     std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13};
167     secureOption.SetProtocolChain(protocolVec);
168     std::vector<std::string> getProtocol;
169     getProtocol = secureOption.GetProtocolChain();
170 
171     TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
172     TLSSecureOptions equalOption = secureOption;
173 }
174 
175 HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2)
176 {
177     TLSSecureOptions secureOption;
178     secureOption.SetUseRemoteCipherPrefer(false);
179     bool isUseRemoteCipher = secureOption.UseRemoteCipherPrefer();
180     EXPECT_FALSE(isUseRemoteCipher);
181 
182     secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
183     std::string getSignatureAlgorithm = secureOption.GetSignatureAlgorithms();
184     EXPECT_STREQ(getSignatureAlgorithm.data(), SIGNATURE_ALGORITHM);
185 
186     secureOption.SetCipherSuite(CIPHER_SUITE);
187     std::string getCipherSuite = secureOption.GetCipherSuite();
188     EXPECT_STREQ(getCipherSuite.data(), CIPHER_SUITE);
189 
190     TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
191     TLSSecureOptions equalOption = secureOption;
192 
193     TLSConnectOptions connectOptions;
194     connectOptions.SetTlsSecureOptions(secureOption);
195 }
196 
197 HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2)
198 {
199     TLSSecureOptions secureOption;
200     TLSConnectOptions connectOptions;
201     connectOptions.SetTlsSecureOptions(secureOption);
202 
203     Socket::NetAddress netAddress;
204     netAddress.SetAddress(IP_ADDRESS);
205     netAddress.SetPort(PORT);
206     connectOptions.SetNetAddress(netAddress);
207     Socket::NetAddress getNetAddress = connectOptions.GetNetAddress();
208     std::string address = getNetAddress.GetAddress();
209     EXPECT_STREQ(IP_ADDRESS, address.data());
210     int port = getNetAddress.GetPort();
211     EXPECT_EQ(port, PORT);
212     netAddress.SetFamilyBySaFamily(AF_INET6);
213     sa_family_t getFamily = netAddress.GetSaFamily();
214     EXPECT_EQ(getFamily, AF_INET6);
215 
216     std::vector<std::string> alpnProtocols;
217     alpnProtocols.push_back(ALPN_PROTOCOL);
218     connectOptions.SetAlpnProtocols(alpnProtocols);
219     std::vector<std::string> getAlpnProtocols;
220     getAlpnProtocols = connectOptions.GetAlpnProtocols();
221     EXPECT_STREQ(getAlpnProtocols[0].data(), alpnProtocols[0].data());
222 }
223 
224 HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2)
225 {
226     TLSSecureOptions secureOption;
227     SecureData structureData(PRI_KEY_FILE);
228     secureOption.SetKey(structureData);
229     std::vector<std::string> caChain;
230     caChain.push_back(CA_CRT_FILE);
231     secureOption.SetCaChain(caChain);
232     secureOption.SetCert(CLIENT_FILE);
233 
234     TLSConnectOptions connectOptions;
235     connectOptions.SetTlsSecureOptions(secureOption);
236 
237     Socket::NetAddress netAddress;
238     netAddress.SetAddress(IP_ADDRESS);
239     netAddress.SetPort(0);
240     netAddress.SetFamilyBySaFamily(AF_INET);
241     EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
242 }
243 
244 HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2)
245 {
246     TLSConnectOptions tlsConnectOptions = BaseOption();
247 
248     AccessToken token;
249     TLSSocket tlsSocket;
250     tlsSocket.OnError(
__anonc01c89aa0202(int32_t errorNumber, const std::string &errorString) 251         [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); });
__anonc01c89aa0302(int32_t errCode) 252     tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); });
253     std::string getData;
__anonc01c89aa0402(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) 254     tlsSocket.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
255         EXPECT_STREQ(getData.data(), nullptr);
256     });
257     const std::string data = "how do you do?";
258     Socket::TCPSendOptions tcpSendOptions;
259     tcpSendOptions.SetData(data);
__anonc01c89aa0502(int32_t errCode) 260     tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
261     tlsSocket.GetSignatureAlgorithms(
__anonc01c89aa0602(int32_t errCode, const std::vector<std::string> &algorithms) 262         [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
263     tlsSocket.GetCertificate(
__anonc01c89aa0702(int32_t errCode, const X509CertRawData &cert) 264         [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); });
265     tlsSocket.GetCipherSuite(
__anonc01c89aa0802(int32_t errCode, const std::vector<std::string> &suite) 266         [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
__anonc01c89aa0902(int32_t errCode, const std::string &protocol) 267     tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); });
268     tlsSocket.GetRemoteCertificate(
__anonc01c89aa0a02(int32_t errCode, const X509CertRawData &cert) 269         [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
__anonc01c89aa0b02(int32_t errCode) 270     (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); });
271 }
272 
273 HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2)
274 {
275     TLSSocket tlsSocket;
276     TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal();
277 
278     std::vector<std::string> alpnProtocols;
279     alpnProtocols.push_back(ALPN_PROTOCOL);
280     bool alpnProSslNull = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
281     EXPECT_FALSE(alpnProSslNull);
282     std::vector<std::string> getCipherSuite = tlsSocketInternal->GetCipherSuite();
283     EXPECT_EQ(getCipherSuite.size(), 0);
284     bool setSharedSigals = tlsSocketInternal->SetSharedSigals();
285     EXPECT_FALSE(setSharedSigals);
286     tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method()));
287     getCipherSuite = tlsSocketInternal->GetCipherSuite();
288     EXPECT_NE(getCipherSuite.size(), 0);
289     setSharedSigals = tlsSocketInternal->SetSharedSigals();
290     EXPECT_FALSE(setSharedSigals);
291     TLSConnectOptions connectOptions = BaseOption();
292     bool alpnPro = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
293     EXPECT_TRUE(alpnPro);
294 
295     Socket::SocketRemoteInfo remoteInfo;
296     tlsSocketInternal->hostName_ = IP_ADDRESS;
297     tlsSocketInternal->port_ = PORT;
298     tlsSocketInternal->family_ = AF_INET;
299     tlsSocketInternal->MakeRemoteInfo(remoteInfo);
300     getCipherSuite = tlsSocketInternal->GetCipherSuite();
301     EXPECT_NE(getCipherSuite.size(), 0);
302 
303     std::string getRemoteCert = tlsSocketInternal->GetRemoteCertificate();
304     EXPECT_EQ(getRemoteCert, "");
305 
306     std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal->GetSignatureAlgorithms();
307     EXPECT_EQ(getSignatureAlgorithms.size(), 0);
308 
309     std::string getProtocol = tlsSocketInternal->GetProtocol();
310     EXPECT_NE(getProtocol, "");
311 
312     setSharedSigals = tlsSocketInternal->SetSharedSigals();
313     EXPECT_FALSE(setSharedSigals);
314 
315     ssl_st *ssl = tlsSocketInternal->GetSSL();
316     EXPECT_NE(ssl, nullptr);
317     delete tlsSocketInternal;
318 }
319 } // namespace TlsSocket
320 } // namespace NetStack
321 } // namespace OHOS
322