1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <fstream>
17 #include <gtest/gtest.h>
18 #include <iostream>
19 
20 #ifdef GTEST_API_
21 #define private public
22 #endif
23 
24 #include "net_address.h"
25 #include "netstack_log.h"
26 #include "secure_data.h"
27 #include "socket_error.h"
28 #include "socket_state_base.h"
29 #include "tls.h"
30 #include "tls_certificate.h"
31 #include "tls_configuration.h"
32 #include "tls_key.h"
33 #include "tls_socket.h"
34 #include "tls_socket_server.h"
35 
36 namespace OHOS {
37 namespace NetStack {
38 namespace TlsSocket {
39 
40 class TLSSecureOptionsBranchTest : public testing::Test {
41 public:
SetUpTestCase()42     static void SetUpTestCase() {}
43 
TearDownTestCase()44     static void TearDownTestCase() {}
45 
SetUp()46     virtual void SetUp() {}
47 
TearDown()48     virtual void TearDown() {}
49 };
50 
51 HWTEST_F(TLSSecureOptionsBranchTest, TLSSecureOptionsBranchTest001, testing::ext::TestSize.Level2)
52 {
53     TLSSecureOptions secureOption;
54     std::vector<std::string> crlChain = {};
55     secureOption.SetCrlChain(crlChain);
56     auto caChain = secureOption.GetCrlChain();
57     EXPECT_TRUE(crlChain == caChain);
58 
59     VerifyMode verifyMode = VerifyMode::ONE_WAY_MODE;
60     secureOption.SetVerifyMode(verifyMode);
61     auto mode = secureOption.GetVerifyMode();
62     EXPECT_EQ(mode, verifyMode);
63 
64     TLSConnectOptions connectOptions;
65     CheckServerIdentity checkServerIdentity;
66     connectOptions.SetCheckServerIdentity(checkServerIdentity);
67     auto identity = connectOptions.GetCheckServerIdentity();
68     EXPECT_TRUE(identity == nullptr);
69 
70     sockaddr *addr = nullptr;
71     TLSSocket server;
72     auto testString = server.MakeAddressString(addr);
73     EXPECT_TRUE(testString.empty());
74 
75     sockaddr addrInfo;
76     addrInfo.sa_family = 0;
77     testString = server.MakeAddressString(&addrInfo);
78     EXPECT_TRUE(testString.empty());
79 
80     addrInfo.sa_family = AF_INET;
81     testString = server.MakeAddressString(&addrInfo);
82     EXPECT_FALSE(testString.empty());
83 
84     addrInfo.sa_family = AF_INET6;
85     testString = server.MakeAddressString(&addrInfo);
86     EXPECT_FALSE(testString.empty());
87 
88     Socket::NetAddress address;
89     sockaddr_in6 addr6 = { 0 };
90     sockaddr_in addr4 = { 0 };
91     socklen_t len;
92     server.GetAddr(address, &addr4, &addr6, &addr, &len);
93 }
94 
95 HWTEST_F(TLSSecureOptionsBranchTest, TLSSecureOptionsBranchTest002, testing::ext::TestSize.Level2)
96 {
97     TLSSocket server;
98     sa_family_t family = 0;
99     server.MakeIpSocket(family);
100 
101     family = AF_INET;
102     server.MakeIpSocket(family);
103 
104     family = AF_INET6;
105     server.MakeIpSocket(family);
106 
107     std::string data = "";
108     Socket::SocketRemoteInfo remoteInfo;
109     server.CallOnMessageCallback(data, remoteInfo);
110     server.CallOnConnectCallback();
111     server.CallOnCloseCallback();
112 
113     int32_t err = 0;
114     BindCallback bindCallback;
115     server.CallBindCallback(err, bindCallback);
116 
117     ConnectCallback connectCallback;
118     server.CallConnectCallback(err, connectCallback);
119 
120     CloseCallback closeCallback;
121     server.CallCloseCallback(err, closeCallback);
122 
123     Socket::NetAddress address;
124     GetRemoteAddressCallback addressCallback;
125     server.CallGetRemoteAddressCallback(err, address, addressCallback);
126 
127     Socket::SocketStateBase state;
128     GetStateCallback stateCallback;
129     server.CallGetStateCallback(err, state, stateCallback);
130 
131     SetExtraOptionsCallback optionsCallback;
132     server.CallSetExtraOptionsCallback(err, optionsCallback);
133 
134     X509CertRawData cert;
135     GetCertificateCallback certificateCallback;
136     server.CallGetCertificateCallback(err, cert, certificateCallback);
137 
138     GetRemoteCertificateCallback remoteCertificateCallback;
139     server.CallGetRemoteCertificateCallback(err, cert, remoteCertificateCallback);
140 
141     std::string protocol = "";
142     GetProtocolCallback protocolCallback;
143     server.CallGetProtocolCallback(err, protocol, protocolCallback);
144 
145     OnMessageCallback onMessageCallback;
146     server.OnMessage(onMessageCallback);
147     server.OffMessage();
148     EXPECT_TRUE(server.onMessageCallback_ == nullptr);
149 }
150 
151 HWTEST_F(TLSSecureOptionsBranchTest, TLSSecureOptionsBranchTest003, testing::ext::TestSize.Level2)
152 {
153     TLSSocket server;
154     int32_t err = 0;
155     std::vector<std::string> suite = {};
156     GetCipherSuiteCallback cipherSuiteCallback;
157     server.CallGetCipherSuiteCallback(err, suite, cipherSuiteCallback);
158 
159     GetSignatureAlgorithmsCallback algorithmsCallback;
160     server.CallGetSignatureAlgorithmsCallback(err, suite, algorithmsCallback);
161 
162     Socket::NetAddress address;
163     BindCallback bindCallback;
164     server.Bind(address, bindCallback);
165 
166     GetRemoteAddressCallback addressCallback;
167     server.GetRemoteAddress(addressCallback);
168     server.GetIp4RemoteAddress(addressCallback);
169     server.GetIp6RemoteAddress(addressCallback);
170 
171     GetStateCallback stateCallback;
172     server.GetState(stateCallback);
173 
174     Socket::ExtraOptionsBase option;
175     bool ret = server.SetBaseOptions(option);
176     EXPECT_TRUE(ret);
177 
178     TlsSocket::SetExtraOptionsCallback optionsCallback;
179     Socket::TCPExtraOptions tcpExtraOptions;
180     ret = server.SetExtraOptions(tcpExtraOptions);
181     EXPECT_TRUE(ret);
182 
183     tcpExtraOptions.SetKeepAlive(true);
184     ret = server.SetExtraOptions(tcpExtraOptions);
185     EXPECT_TRUE(ret);
186 
187     tcpExtraOptions.SetOOBInline(true);
188     ret = server.SetExtraOptions(tcpExtraOptions);
189     EXPECT_TRUE(ret);
190 
191     tcpExtraOptions.SetTCPNoDelay(true);
192     ret = server.SetExtraOptions(tcpExtraOptions);
193     EXPECT_TRUE(ret);
194 
195     OnConnectCallback onConnectCallback;
196     server.OffConnect();
197     server.OnConnect(onConnectCallback);
198     EXPECT_TRUE(server.onConnectCallback_ == nullptr);
199 }
200 
201 HWTEST_F(TLSSecureOptionsBranchTest, TLSSecureOptionsBranchTest004, testing::ext::TestSize.Level2)
202 {
203     TLSSocket server;
204     server.OffError();
205 
206     TLSConfiguration tLSConfiguration;
207     TLSSocket::TLSSocketInternal internal;
208     tLSConfiguration = internal.GetTlsConfiguration();
209     std::vector<std::string> certificate;
210     tLSConfiguration.SetCaCertificate(certificate);
211     EXPECT_TRUE(tLSConfiguration.GetCaCertificate().empty());
212 
213     TLSConnectOptions connectOptions;
214     auto ret = internal.StartTlsConnected(connectOptions);
215     EXPECT_FALSE(ret);
216 
217     ret = internal.CreatTlsContext();
218     EXPECT_TRUE(ret);
219 
220     ret = internal.StartShakingHands(connectOptions);
221     EXPECT_FALSE(ret);
222 
223     ret = internal.GetRemoteCertificateFromPeer();
224     EXPECT_FALSE(ret);
225 
226     ret = internal.SetRemoteCertRawData();
227     EXPECT_FALSE(ret);
228 
229     OnCloseCallback onCloseCallback;
230     server.OnClose(onCloseCallback);
231     server.OffClose();
232     EXPECT_TRUE(server.onCloseCallback_ == nullptr);
233 }
234 
235 HWTEST_F(TLSSecureOptionsBranchTest, TLSSecureOptionsBranchTest005, testing::ext::TestSize.Level2)
236 {
237     TLSConnectOptions connectOptions;
238     TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal();
239     tlsSocketInternal->ssl_ = nullptr;
240     bool ret = tlsSocketInternal->StartShakingHands(connectOptions);
241     NETSTACK_LOGI("TLSSecureOptionsBranchTest005 StartShakingHands = %{public}s", std::to_string(ret).c_str());
242     EXPECT_FALSE(ret);
243 }
244 } // namespace TlsSocket
245 } // namespace NetStack
246 } // namespace OHOS
247