1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <gtest/gtest.h>
17 #include <iostream>
18
19 #include <openssl/ssl.h>
20
21 #define private public
22 #include "accesstoken_kit.h"
23 #include "tls_socket.h"
24 #include "socket_remote_info.h"
25 #include "token_setproc.h"
26 #include "tls.h"
27 #include "TlsTest.h"
28
29 namespace OHOS {
30 namespace NetStack {
31 namespace TlsSocket {
32 namespace {
33 using namespace testing::ext;
34 using namespace Security::AccessToken;
35 using Security::AccessToken::AccessTokenID;
36 static constexpr const char *KEY_PASS = "";
37 static constexpr const char *PROTOCOL12 = "TLSv1.2";
38 static constexpr const char *PROTOCOL13 = "TLSv1.3";
39 static constexpr const char *IP_ADDRESS = "127.0.0.1";
40 static constexpr const char *ALPN_PROTOCOL = "http/1.1";
41 static constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256";
42 static constexpr const char *CIPHER_SUITE = "AES256-SHA256";
43 static constexpr const char *SEND_DATA = "How do you do";
44 static constexpr const char *SEND_DATA_EMPTY = "";
45 static constexpr const size_t MAX_BUFFER_SIZE = 8192;
46 const int PORT = 7838;
47 const int SOCKET_FD = 5;
48 const int SSL_ERROR_RETURN = -1;
49
BaseOption()50 TLSConnectOptions BaseOption()
51 {
52 (void)SEND_DATA;
53 (void)SEND_DATA_EMPTY;
54 (void)MAX_BUFFER_SIZE;
55 (void)SOCKET_FD;
56 (void)SSL_ERROR_RETURN;
57 TLSSecureOptions secureOption;
58 SecureData structureData(PRI_KEY_FILE);
59 secureOption.SetKey(structureData);
60 std::vector<std::string> caChain;
61 caChain.push_back(CA_CRT_FILE);
62 secureOption.SetCaChain(caChain);
63 secureOption.SetCert(CLIENT_FILE);
64 secureOption.SetCipherSuite(CIPHER_SUITE);
65 secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
66 std::vector<std::string> protocol;
67 protocol.push_back(PROTOCOL13);
68 secureOption.SetProtocolChain(protocol);
69
70 TLSConnectOptions connectOptions;
71 connectOptions.SetTlsSecureOptions(secureOption);
72 Socket::NetAddress netAddress;
73 netAddress.SetAddress(IP_ADDRESS);
74 netAddress.SetPort(0);
75 netAddress.SetFamilyBySaFamily(AF_INET);
76 connectOptions.SetNetAddress(netAddress);
77 std::vector<std::string> alpnProtocols;
78 alpnProtocols.push_back(ALPN_PROTOCOL);
79 connectOptions.SetAlpnProtocols(alpnProtocols);
80 return connectOptions;
81 }
82
83 HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
84 .userID = 1,
85 .instIndex = 0,
86 .appIDDesc = "test",
87 .isSystemApp = true};
88
89 PermissionDef testPermDef = {
90 .permissionName = "ohos.permission.INTERNET",
91 .bundleName = "TlsSocketBranchTest",
92 .grantMode = 1,
93 .label = "label",
94 .labelId = 1,
95 .description = "Test Tls Socket Branch",
96 .descriptionId = 1,
97 .availableLevel = APL_SYSTEM_BASIC,
98 };
99
100 PermissionStateFull testState = {
101 .grantFlags = {2},
102 .grantStatus = {PermissionState::PERMISSION_GRANTED},
103 .isGeneral = true,
104 .permissionName = "ohos.permission.INTERNET",
105 .resDeviceID = {"local"},
106 };
107
108 HapPolicyParams testPolicyPrams = {
109 .apl = APL_SYSTEM_BASIC,
110 .domain = "test.domain",
111 .permList = {testPermDef},
112 .permStateList = {testState},
113 };
114 } // namespace
115
116 class AccessToken {
117 public:
AccessToken()118 AccessToken() : currentID_(GetSelfTokenID())
119 {
120 AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
121 accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
122 SetSelfTokenID(tokenIdEx.tokenIDEx);
123 }
~AccessToken()124 ~AccessToken()
125 {
126 AccessTokenKit::DeleteToken(accessID_);
127 SetSelfTokenID(currentID_);
128 }
129
130 private:
131 AccessTokenID currentID_;
132 AccessTokenID accessID_ = 0;
133 };
134
135 class TlsSocketBranchTest : public testing::Test {
136 public:
SetUpTestCase()137 static void SetUpTestCase() {}
138
TearDownTestCase()139 static void TearDownTestCase() {}
140
SetUp()141 virtual void SetUp() {}
142
TearDown()143 virtual void TearDown() {}
144 };
145
146 HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2)
147 {
148 TLSSecureOptions secureOption;
149 SecureData structureData(PRI_KEY_FILE);
150 secureOption.SetKey(structureData);
151
152 SecureData keyPass(KEY_PASS);
153 secureOption.SetKeyPass(keyPass);
154 SecureData secureData = secureOption.GetKey();
155 EXPECT_EQ(structureData.Length(), strlen(PRI_KEY_FILE));
156 std::vector<std::string> caChain;
157 caChain.push_back(CA_CRT_FILE);
158 secureOption.SetCaChain(caChain);
159 std::vector<std::string> getCaChain = secureOption.GetCaChain();
160 EXPECT_NE(getCaChain.data(), nullptr);
161
162 secureOption.SetCert(CLIENT_FILE);
163 std::string getCert = secureOption.GetCert();
164 EXPECT_NE(getCert.data(), nullptr);
165
166 std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13};
167 secureOption.SetProtocolChain(protocolVec);
168 std::vector<std::string> getProtocol;
169 getProtocol = secureOption.GetProtocolChain();
170
171 TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
172 TLSSecureOptions equalOption = secureOption;
173 }
174
175 HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2)
176 {
177 TLSSecureOptions secureOption;
178 secureOption.SetUseRemoteCipherPrefer(false);
179 bool isUseRemoteCipher = secureOption.UseRemoteCipherPrefer();
180 EXPECT_FALSE(isUseRemoteCipher);
181
182 secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
183 std::string getSignatureAlgorithm = secureOption.GetSignatureAlgorithms();
184 EXPECT_STREQ(getSignatureAlgorithm.data(), SIGNATURE_ALGORITHM);
185
186 secureOption.SetCipherSuite(CIPHER_SUITE);
187 std::string getCipherSuite = secureOption.GetCipherSuite();
188 EXPECT_STREQ(getCipherSuite.data(), CIPHER_SUITE);
189
190 TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
191 TLSSecureOptions equalOption = secureOption;
192
193 TLSConnectOptions connectOptions;
194 connectOptions.SetTlsSecureOptions(secureOption);
195 }
196
197 HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2)
198 {
199 TLSSecureOptions secureOption;
200 TLSConnectOptions connectOptions;
201 connectOptions.SetTlsSecureOptions(secureOption);
202
203 Socket::NetAddress netAddress;
204 netAddress.SetAddress(IP_ADDRESS);
205 netAddress.SetPort(PORT);
206 connectOptions.SetNetAddress(netAddress);
207 Socket::NetAddress getNetAddress = connectOptions.GetNetAddress();
208 std::string address = getNetAddress.GetAddress();
209 EXPECT_STREQ(IP_ADDRESS, address.data());
210 int port = getNetAddress.GetPort();
211 EXPECT_EQ(port, PORT);
212 netAddress.SetFamilyBySaFamily(AF_INET6);
213 sa_family_t getFamily = netAddress.GetSaFamily();
214 EXPECT_EQ(getFamily, AF_INET6);
215
216 std::vector<std::string> alpnProtocols;
217 alpnProtocols.push_back(ALPN_PROTOCOL);
218 connectOptions.SetAlpnProtocols(alpnProtocols);
219 std::vector<std::string> getAlpnProtocols;
220 getAlpnProtocols = connectOptions.GetAlpnProtocols();
221 EXPECT_STREQ(getAlpnProtocols[0].data(), alpnProtocols[0].data());
222 }
223
224 HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2)
225 {
226 TLSSecureOptions secureOption;
227 SecureData structureData(PRI_KEY_FILE);
228 secureOption.SetKey(structureData);
229 std::vector<std::string> caChain;
230 caChain.push_back(CA_CRT_FILE);
231 secureOption.SetCaChain(caChain);
232 secureOption.SetCert(CLIENT_FILE);
233
234 TLSConnectOptions connectOptions;
235 connectOptions.SetTlsSecureOptions(secureOption);
236
237 Socket::NetAddress netAddress;
238 netAddress.SetAddress(IP_ADDRESS);
239 netAddress.SetPort(0);
240 netAddress.SetFamilyBySaFamily(AF_INET);
241 EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
242 }
243
244 HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2)
245 {
246 TLSConnectOptions tlsConnectOptions = BaseOption();
247
248 AccessToken token;
249 TLSSocket tlsSocket;
250 tlsSocket.OnError(
__anonc01c89aa0202(int32_t errorNumber, const std::string &errorString) 251 [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); });
__anonc01c89aa0302(int32_t errCode) 252 tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); });
253 std::string getData;
__anonc01c89aa0402(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) 254 tlsSocket.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
255 EXPECT_STREQ(getData.data(), nullptr);
256 });
257 const std::string data = "how do you do?";
258 Socket::TCPSendOptions tcpSendOptions;
259 tcpSendOptions.SetData(data);
__anonc01c89aa0502(int32_t errCode) 260 tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
261 tlsSocket.GetSignatureAlgorithms(
__anonc01c89aa0602(int32_t errCode, const std::vector<std::string> &algorithms) 262 [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
263 tlsSocket.GetCertificate(
__anonc01c89aa0702(int32_t errCode, const X509CertRawData &cert) 264 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); });
265 tlsSocket.GetCipherSuite(
__anonc01c89aa0802(int32_t errCode, const std::vector<std::string> &suite) 266 [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
__anonc01c89aa0902(int32_t errCode, const std::string &protocol) 267 tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); });
268 tlsSocket.GetRemoteCertificate(
__anonc01c89aa0a02(int32_t errCode, const X509CertRawData &cert) 269 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
__anonc01c89aa0b02(int32_t errCode) 270 (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); });
271 }
272
273 HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2)
274 {
275 TLSSocket tlsSocket;
276 TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal();
277
278 std::vector<std::string> alpnProtocols;
279 alpnProtocols.push_back(ALPN_PROTOCOL);
280 bool alpnProSslNull = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
281 EXPECT_FALSE(alpnProSslNull);
282 std::vector<std::string> getCipherSuite = tlsSocketInternal->GetCipherSuite();
283 EXPECT_EQ(getCipherSuite.size(), 0);
284 bool setSharedSigals = tlsSocketInternal->SetSharedSigals();
285 EXPECT_FALSE(setSharedSigals);
286 tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method()));
287 getCipherSuite = tlsSocketInternal->GetCipherSuite();
288 EXPECT_NE(getCipherSuite.size(), 0);
289 setSharedSigals = tlsSocketInternal->SetSharedSigals();
290 EXPECT_FALSE(setSharedSigals);
291 TLSConnectOptions connectOptions = BaseOption();
292 bool alpnPro = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
293 EXPECT_TRUE(alpnPro);
294
295 Socket::SocketRemoteInfo remoteInfo;
296 tlsSocketInternal->hostName_ = IP_ADDRESS;
297 tlsSocketInternal->port_ = PORT;
298 tlsSocketInternal->family_ = AF_INET;
299 tlsSocketInternal->MakeRemoteInfo(remoteInfo);
300 getCipherSuite = tlsSocketInternal->GetCipherSuite();
301 EXPECT_NE(getCipherSuite.size(), 0);
302
303 std::string getRemoteCert = tlsSocketInternal->GetRemoteCertificate();
304 EXPECT_EQ(getRemoteCert, "");
305
306 std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal->GetSignatureAlgorithms();
307 EXPECT_EQ(getSignatureAlgorithms.size(), 0);
308
309 std::string getProtocol = tlsSocketInternal->GetProtocol();
310 EXPECT_NE(getProtocol, "");
311
312 setSharedSigals = tlsSocketInternal->SetSharedSigals();
313 EXPECT_FALSE(setSharedSigals);
314
315 ssl_st *ssl = tlsSocketInternal->GetSSL();
316 EXPECT_NE(ssl, nullptr);
317 delete tlsSocketInternal;
318 }
319 } // namespace TlsSocket
320 } // namespace NetStack
321 } // namespace OHOS
322