1 /*
2  * Copyright (C) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <string_view>
#include <thread>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#ifdef GTEST_API_
#define private public
#define protected public
#endif

#include "mdns_client.h"
#include "mdns_client_resume.h"
#include "mdns_common.h"
#include "mdns_event_stub.h"
#include "mock_i_discovery_callback_test.h"
#include "net_conn_client.h"
#include "netmanager_ext_test_security.h"
#include "netmgr_ext_log_wrapper.h"
#include "refbase.h"
33 
34 namespace OHOS {
35 namespace NetManagerStandard {
36 using namespace testing::ext;
37 
// Ports of the two test services (DEMO_PORT for the primary one, DEMO_PORT1 for the second).
constexpr int DEMO_PORT = 12345;
constexpr int DEMO_PORT1 = 23456;
// NOTE(review): despite the _MS suffix these are plain integers, not durations. They have been used both
// as expected callback counts and as std::chrono::seconds arguments; check the call site for the unit.
constexpr int TIME_ONE_MS = 1;
constexpr int TIME_TWO_MS = 2;
constexpr int TIME_FOUR_MS = 4;
constexpr int TIME_FIVE_MS = 5;
// Instance names of the two test services.
constexpr const char *DEMO_NAME = "ala";
constexpr const char *DEMO_NAME1 = "ala1";
// Service type shared by the test services.
constexpr const char *DEMO_TYPE = "_hellomdns._tcp";

// TXT record for the test services: a "key" entry holding the bytes of "value" and a "null" entry whose
// value is a single NUL byte.
static const TxtRecord g_txt{{"key", {'v', 'a', 'l', 'u', 'e'}}, {"null", {'\0'}}};
49 
// Kinds of mDNS callback events. NOTE(review): not referenced by any test in this file.
enum class EventType { UNKNOWN, REGISTER, FOUND, LOST, RESOLVE };
57 
// Shared state between the mDNS test callbacks and DoTestForMdnsClient: each callback updates its
// counter while holding g_mtx and then notifies g_cv; the test waits on g_cv for the counts it expects.
// `static` keeps these file-local so their generic names cannot clash at link time with same-named
// globals in other test sources built into the same unittest binary.
static std::mutex g_mtx;
static std::condition_variable g_cv;
static int g_register = 0; // HandleRegisterResult calls
static int g_found = 0;    // HandleServiceFound calls
static int g_lost = 0;     // HandleServiceLost calls
static int g_resolve = 0;  // HandleResolveResult calls
64 
65 class MDnsTestRegistrationCallback : public RegistrationCallbackStub {
66 public:
MDnsTestRegistrationCallback(const MDnsServiceInfo & info)67     explicit MDnsTestRegistrationCallback(const MDnsServiceInfo &info) : expected_(info) {}
68     virtual ~MDnsTestRegistrationCallback() = default;
HandleRegister(const MDnsServiceInfo & info,int32_t retCode)69     void HandleRegister(const MDnsServiceInfo &info, int32_t retCode) override {}
HandleUnRegister(const MDnsServiceInfo & info,int32_t retCode)70     void HandleUnRegister(const MDnsServiceInfo &info, int32_t retCode) override {}
HandleRegisterResult(const MDnsServiceInfo & info,int32_t retCode)71     void HandleRegisterResult(const MDnsServiceInfo &info, int32_t retCode) override
72     {
73         g_mtx.lock();
74         EXPECT_EQ(retCode, NETMANAGER_EXT_SUCCESS);
75         std::cerr << "registered instance " << info.name + MDNS_DOMAIN_SPLITER_STR + info.type << "\n";
76         EXPECT_EQ(expected_.name, info.name);
77         EXPECT_EQ(expected_.type, info.type);
78         EXPECT_EQ(expected_.port, info.port);
79         g_register++;
80         g_mtx.unlock();
81         g_cv.notify_one();
82     }
83     MDnsServiceInfo expected_;
84 };
85 
86 class MDnsTestDiscoveryCallback : public DiscoveryCallbackStub {
87 public:
MDnsTestDiscoveryCallback(const std::vector<MDnsServiceInfo> & info)88     explicit MDnsTestDiscoveryCallback(const std::vector<MDnsServiceInfo> &info) : expected_(info) {}
89     virtual ~MDnsTestDiscoveryCallback() = default;
HandleStartDiscover(const MDnsServiceInfo & info,int32_t retCode)90     void HandleStartDiscover(const MDnsServiceInfo &info, int32_t retCode) override {}
HandleStopDiscover(const MDnsServiceInfo & info,int32_t retCode)91     void HandleStopDiscover(const MDnsServiceInfo &info, int32_t retCode) override {}
HandleServiceFound(const MDnsServiceInfo & info,int32_t retCode)92     void HandleServiceFound(const MDnsServiceInfo &info, int32_t retCode) override
93     {
94         g_mtx.lock();
95         EXPECT_EQ(retCode, NETMANAGER_EXT_SUCCESS);
96         std::cerr << "found instance " << info.name + MDNS_DOMAIN_SPLITER_STR + info.type << "\n";
97         EXPECT_TRUE(std::find_if(expected_.begin(), expected_.end(),
98                                  [&](auto const &x) { return x.name == info.name; }) != expected_.end());
99         EXPECT_TRUE(std::find_if(expected_.begin(), expected_.end(),
100                                  [&](auto const &x) { return x.type == info.type; }) != expected_.end());
101         g_found++;
102         g_mtx.unlock();
103         g_cv.notify_one();
104     }
105 
HandleServiceLost(const MDnsServiceInfo & info,int32_t retCode)106     void HandleServiceLost(const MDnsServiceInfo &info, int32_t retCode) override
107     {
108         g_mtx.lock();
109         EXPECT_EQ(retCode, NETMANAGER_EXT_SUCCESS);
110         std::cerr << "lost instance " << info.name + MDNS_DOMAIN_SPLITER_STR + info.type << "\n";
111         EXPECT_TRUE(std::find_if(expected_.begin(), expected_.end(),
112                                  [&](auto const &x) { return x.name == info.name; }) != expected_.end());
113         EXPECT_TRUE(std::find_if(expected_.begin(), expected_.end(),
114                                  [&](auto const &x) { return x.type == info.type; }) != expected_.end());
115         g_lost++;
116         g_mtx.unlock();
117         g_cv.notify_one();
118     }
119     std::vector<MDnsServiceInfo> expected_;
120 };
121 
122 class MDnsTestResolveCallback : public ResolveCallbackStub {
123 public:
MDnsTestResolveCallback(const MDnsServiceInfo & info)124     explicit MDnsTestResolveCallback(const MDnsServiceInfo &info) : expected_(info) {}
125     virtual ~MDnsTestResolveCallback() = default;
HandleResolveResult(const MDnsServiceInfo & info,int32_t retCode)126     void HandleResolveResult(const MDnsServiceInfo &info, int32_t retCode) override
127     {
128         g_mtx.lock();
129         EXPECT_EQ(retCode, NETMANAGER_EXT_SUCCESS);
130         std::cerr << "resolved instance " << info.addr + MDNS_HOSTPORT_SPLITER_STR + std::to_string(info.port) << "\n";
131         EXPECT_EQ(expected_.name, info.name);
132         EXPECT_EQ(expected_.type, info.type);
133         EXPECT_EQ(expected_.port, info.port);
134         EXPECT_EQ(expected_.txtRecord, info.txtRecord);
135         g_resolve++;
136         g_mtx.unlock();
137         g_cv.notify_one();
138     }
139     MDnsServiceInfo expected_;
140 };
141 
142 class MDnsClientResumeTest : public testing::Test {
143 public:
144     static void SetUpTestCase();
145     static void TearDownTestCase();
146     void SetUp() override;
147     void TearDown() override;
148 };
149 
SetUpTestCase()150 void MDnsClientResumeTest::SetUpTestCase() {}
151 
TearDownTestCase()152 void MDnsClientResumeTest::TearDownTestCase() {}
153 
SetUp()154 void MDnsClientResumeTest::SetUp() {}
155 
TearDown()156 void MDnsClientResumeTest::TearDown() {}
157 
158 class MDnsClientTest : public testing::Test {
159 public:
160     static void SetUpTestCase();
161     static void TearDownTestCase();
162     void SetUp() override;
163     void TearDown() override;
164 };
165 
SetUpTestCase()166 void MDnsClientTest::SetUpTestCase() {}
167 
TearDownTestCase()168 void MDnsClientTest::TearDownTestCase() {}
169 
SetUp()170 void MDnsClientTest::SetUp() {}
171 
TearDown()172 void MDnsClientTest::TearDown() {}
173 
174 class MDnsServerTest : public testing::Test {
175 public:
176     static void SetUpTestCase();
177     static void TearDownTestCase();
178     void SetUp() override;
179     void TearDown() override;
180 };
181 
SetUpTestCase()182 void MDnsServerTest::SetUpTestCase() {}
183 
TearDownTestCase()184 void MDnsServerTest::TearDownTestCase() {}
185 
SetUp()186 void MDnsServerTest::SetUp() {}
187 
TearDown()188 void MDnsServerTest::TearDown() {}
189 
190 
// Inputs for DoTestForMdnsClient: two services to register ("info" and its "Back" counterpart) plus the
// registration, discovery and resolve callbacks used for them.
struct MdnsClientTestParams {
    MDnsServiceInfo info;                                 // primary service
    MDnsServiceInfo infoBack;                             // second service, registered alongside info
    sptr<MDnsTestRegistrationCallback> registration;      // passed to RegisterService for info
    sptr<MDnsTestRegistrationCallback> registrationBack;  // passed to RegisterService for infoBack
    sptr<MDnsTestDiscoveryCallback> discovery;            // first discovery callback for info.type
    sptr<MDnsTestDiscoveryCallback> discoveryBack;        // second discovery callback for info.type
    sptr<MDnsTestResolveCallback> resolve;                // passed to ResolveService for info
    sptr<MDnsTestResolveCallback> resolveBack;            // passed to ResolveService for infoBack
};
201 
DoTestForMdnsClient(MdnsClientTestParams param)202 void DoTestForMdnsClient(MdnsClientTestParams param)
203 {
204     NetManagerExtAccessToken token;
205     bool flag = false;
206     NetConnClient::GetInstance().HasDefaultNet(flag);
207     if (!flag) {
208         return;
209     }
210     std::unique_lock<std::mutex> lock(g_mtx);
211     DelayedSingleton<MDnsClient>::GetInstance()->RegisterService(param.info, param.registration);
212     DelayedSingleton<MDnsClient>::GetInstance()->RegisterService(param.infoBack, param.registrationBack);
213     if (!g_cv.wait_for(lock, std::chrono::seconds(TIME_FIVE_MS), []() { return g_register == TIME_TWO_MS; })) {
214         FAIL();
215     }
216     DelayedSingleton<MDnsClient>::GetInstance()->StartDiscoverService(param.info.type, param.discovery);
217     DelayedSingleton<MDnsClient>::GetInstance()->StartDiscoverService(param.info.type, param.discoveryBack);
218     if (!g_cv.wait_for(lock, std::chrono::seconds(TIME_FIVE_MS), []() { return g_found >= TIME_FOUR_MS; })) {
219         FAIL();
220     }
221     DelayedSingleton<MDnsClient>::GetInstance()->ResolveService(param.info, param.resolve);
222     if (!g_cv.wait_for(lock, std::chrono::seconds(TIME_FIVE_MS), []() { return g_resolve >= TIME_ONE_MS; })) {
223         FAIL();
224     }
225     DelayedSingleton<MDnsClient>::GetInstance()->ResolveService(param.infoBack, param.resolveBack);
226     if (!g_cv.wait_for(lock, std::chrono::seconds(TIME_FIVE_MS), []() { return g_resolve >= TIME_TWO_MS; })) {
227         FAIL();
228     }
229     DelayedSingleton<MDnsClient>::GetInstance()->StopDiscoverService(param.discovery);
230     DelayedSingleton<MDnsClient>::GetInstance()->StopDiscoverService(param.discoveryBack);
231     if (!g_cv.wait_for(lock, std::chrono::seconds(TIME_FIVE_MS), []() { return g_lost >= TIME_ONE_MS; })) {
232         FAIL();
233     }
234     DelayedSingleton<MDnsClient>::GetInstance()->UnRegisterService(param.registration);
235     DelayedSingleton<MDnsClient>::GetInstance()->UnRegisterService(param.registrationBack);
236 
237     std::this_thread::sleep_for(std::chrono::seconds(TIME_ONE_MS));
238 
239     DelayedSingleton<MDnsClient>::GetInstance()->RestartResume();
240 }
241 
242 HWTEST_F(MDnsClientResumeTest, ResumeTest001, TestSize.Level1)
243 {
244     MDnsServiceInfo info;
245     MDnsServiceInfo infoBack;
246     info.name = DEMO_NAME;
247     info.type = DEMO_TYPE;
248     info.port = DEMO_PORT;
249     info.SetAttrMap(g_txt);
250 
251     sptr<MDnsTestRegistrationCallback> registration(new (std::nothrow) MDnsTestRegistrationCallback(info));
252     sptr<MDnsTestDiscoveryCallback> discovery(new (std::nothrow) MDnsTestDiscoveryCallback({info, infoBack}));
253     ASSERT_NE(registration, nullptr);
254     ASSERT_NE(discovery, nullptr);
255 
256     int32_t ret = MDnsClientResume::GetInstance().SaveRegisterService(info, registration);
257     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
258 
259     ret = MDnsClientResume::GetInstance().SaveRegisterService(info, registration);
260     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
261 
262     ret = MDnsClientResume::GetInstance().SaveStartDiscoverService(info.type, discovery);
263     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
264 
265     ret = MDnsClientResume::GetInstance().SaveStartDiscoverService(info.type, discovery);
266     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
267 
268     ret = MDnsClientResume::GetInstance().RemoveRegisterService(registration);
269     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
270 
271     ret = MDnsClientResume::GetInstance().RemoveRegisterService(registration);
272     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
273 
274     ret = MDnsClientResume::GetInstance().RemoveStopDiscoverService(discovery);
275     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
276 
277     ret = MDnsClientResume::GetInstance().RemoveStopDiscoverService(discovery);
278     EXPECT_EQ(ret, NETMANAGER_EXT_SUCCESS);
279 
280     RegisterServiceMap *rsm = MDnsClientResume::GetInstance().GetRegisterServiceMap();
281     ASSERT_NE(rsm, nullptr);
282 
283     DiscoverServiceMap *dsm = MDnsClientResume::GetInstance().GetStartDiscoverServiceMap();
284     ASSERT_NE(dsm, nullptr);
285 }
286 
287 /**
288  * @tc.name: ServiceTest001
289  * @tc.desc: Test mDNS register and found.
290  * @tc.type: FUNC
291  */
292 HWTEST_F(MDnsClientTest, ClientTest001, TestSize.Level1)
293 {
294     MDnsServiceInfo info;
295     MDnsServiceInfo infoBack;
296     info.name = DEMO_NAME;
297     info.type = DEMO_TYPE;
298     info.port = DEMO_PORT;
299     info.SetAttrMap(g_txt);
300     infoBack = info;
301     infoBack.name = DEMO_NAME1;
302     infoBack.port = DEMO_PORT1;
303 
304     auto client = DelayedSingleton<MDnsClient>::GetInstance();
305     sptr<MDnsTestRegistrationCallback> registration(new (std::nothrow) MDnsTestRegistrationCallback(info));
306     sptr<MDnsTestRegistrationCallback> registrationBack(new (std::nothrow) MDnsTestRegistrationCallback(infoBack));
307     sptr<MDnsTestDiscoveryCallback> discovery(new (std::nothrow) MDnsTestDiscoveryCallback({info, infoBack}));
308     sptr<MDnsTestDiscoveryCallback> discoveryBack(new (std::nothrow) MDnsTestDiscoveryCallback({info, infoBack}));
309     sptr<MDnsTestResolveCallback> resolve(new (std::nothrow) MDnsTestResolveCallback(info));
310     sptr<MDnsTestResolveCallback> resolveBack(new (std::nothrow) MDnsTestResolveCallback(infoBack));
311     ASSERT_NE(registration, nullptr);
312     ASSERT_NE(registrationBack, nullptr);
313     ASSERT_NE(discovery, nullptr);
314     ASSERT_NE(discoveryBack, nullptr);
315     ASSERT_NE(resolve, nullptr);
316     ASSERT_NE(resolveBack, nullptr);
317 
318     MdnsClientTestParams mdnsClientTestParams;
319     mdnsClientTestParams.info = info;
320     mdnsClientTestParams.infoBack = infoBack;
321     mdnsClientTestParams.registration = registration;
322     mdnsClientTestParams.registrationBack = registrationBack;
323     mdnsClientTestParams.discovery = discovery;
324     mdnsClientTestParams.discoveryBack = discoveryBack;
325     mdnsClientTestParams.resolve = resolve;
326     mdnsClientTestParams.resolveBack = resolveBack;
327     DoTestForMdnsClient(mdnsClientTestParams);
328 }
329 
330 HWTEST_F(MDnsServerTest, ServerTest, TestSize.Level1)
331 {
332     MDnsServiceInfo info;
333     info.name = DEMO_NAME;
334     info.type = DEMO_TYPE;
335     info.port = DEMO_PORT;
336     TxtRecord txt{};
337     info.SetAttrMap(txt);
338     info.SetAttrMap(g_txt);
339     auto retMap = info.GetAttrMap();
340     EXPECT_NE(retMap.size(), 0);
341 
342     MessageParcel parcel;
343     auto retMar = info.Marshalling(parcel);
344     EXPECT_EQ(retMar, true);
345 
346     auto serviceInfo = info.Unmarshalling(parcel);
347     retMar = info.Marshalling(parcel, serviceInfo);
348     EXPECT_EQ(retMar, true);
349 }
350 
351 /**
352  * @tc.name: MDnsCommonTest001
353  * @tc.desc: Test MDnsServerTest
354  * @tc.type: FUNC
355  */
356 HWTEST_F(MDnsServerTest, MDnsCommonTest001, TestSize.Level1)
357 {
358     std::string testStr = "abbcccddddcccbba";
359     for (size_t i = 0; i < testStr.size(); ++i) {
360     EXPECT_TRUE(EndsWith(testStr, testStr.substr(i)));
361     }
362 
363     for (size_t i = 0; i < testStr.size(); ++i) {
364     EXPECT_TRUE(StartsWith(testStr, testStr.substr(0, testStr.size() - i)));
365     }
366 
367     auto lhs = Split(testStr, 'c');
368     auto rhs = std::vector<std::string_view>{
369         "abb",
370         "dddd",
371         "bba",
372     };
373     EXPECT_EQ(lhs, rhs);
374 }
375 
376 /**
377  * @tc.name: MDnsCommonTest002
378  * @tc.desc: Test MDnsServerTest
379  * @tc.type: FUNC
380  */
381 HWTEST_F(MDnsServerTest, MDnsCommonTest002, TestSize.Level1)
382 {
383     constexpr size_t isNameIndex = 1;
384     constexpr size_t isTypeIndex = 2;
385     constexpr size_t isInstanceIndex = 3;
386     constexpr size_t isDomainIndex = 4;
387     std::vector<std::tuple<std::string, bool, bool, bool, bool>> test = {
388         {"abbcccddddcccbba", true,  false, false, true },
389         {"",                 false, false, false, true },
390         {"a.b",              false, false, false, true },
391         {"_xxx.tcp",         false, false, false, true },
392         {"xxx._tcp",         false, false, false, true },
393         {"xxx.yyy",          false, false, false, true },
394         {"xxx.yyy",          false, false, false, true },
395         {"_xxx._yyy",        false, false, false, true },
396         {"hello._ipp._tcp",  false, false, true,  true },
397         {"_x._y._tcp",       false, false, true,  true },
398         {"_ipp._tcp",        false, true,  false, true },
399         {"_http._tcp",       false, true,  false, true },
400     };
401 
402     for (auto line : test) {
403         EXPECT_EQ(IsNameValid(std::get<0>(line)), std::get<isNameIndex>(line));
404         EXPECT_EQ(IsTypeValid(std::get<0>(line)), std::get<isTypeIndex>(line));
405         EXPECT_EQ(IsInstanceValid(std::get<0>(line)), std::get<isInstanceIndex>(line));
406         EXPECT_EQ(IsDomainValid(std::get<0>(line)), std::get<isDomainIndex>(line));
407     }
408 
409     EXPECT_TRUE(IsPortValid(22));
410     EXPECT_TRUE(IsPortValid(65535));
411     EXPECT_TRUE(IsPortValid(0));
412     EXPECT_FALSE(IsPortValid(-1));
413     EXPECT_FALSE(IsPortValid(65536));
414 }
415 
416 /**
417  * @tc.name: MDnsCommonTest003
418  * @tc.desc: Test MDnsServerTest
419  * @tc.type: FUNC
420  */
421 HWTEST_F(MDnsServerTest, MDnsCommonTest003, TestSize.Level1)
422 {
423     std::string instance = "hello._ipp._tcp";
424     std::string instance1 = "_x._y._tcp";
425     std::string name;
426     std::string type;
427     ExtractNameAndType(instance, name, type);
428     EXPECT_EQ(name, "hello");
429     EXPECT_EQ(type, "_ipp._tcp");
430 
431     ExtractNameAndType(instance1, name, type);
432     EXPECT_EQ(name, "_x");
433     EXPECT_EQ(type, "_y._tcp");
434 }
435 
436 HWTEST_F(MDnsServerTest, MDnsServerBranchTest001, TestSize.Level1)
437 {
438     std::string serviceType = "test";
439     sptr<IDiscoveryCallback> callback = new (std::nothrow) MockIDiscoveryCallbackTest();
440     EXPECT_TRUE(callback != nullptr);
441     if (callback == nullptr) {
442         return;
443     }
444     auto ret = DelayedSingleton<MDnsClient>::GetInstance()->StartDiscoverService(serviceType, callback);
445     EXPECT_EQ(ret, NET_MDNS_ERR_ILLEGAL_ARGUMENT);
446 
447     ret = DelayedSingleton<MDnsClient>::GetInstance()->StopDiscoverService(callback);
448     EXPECT_EQ(ret, NET_MDNS_ERR_CALLBACK_NOT_FOUND);
449 
450     callback = nullptr;
451     ret = DelayedSingleton<MDnsClient>::GetInstance()->StartDiscoverService(serviceType, callback);
452     EXPECT_EQ(ret, NET_MDNS_ERR_ILLEGAL_ARGUMENT);
453 
454     ret = DelayedSingleton<MDnsClient>::GetInstance()->StopDiscoverService(callback);
455     EXPECT_EQ(ret, NET_MDNS_ERR_ILLEGAL_ARGUMENT);
456 }
457 } // namespace NetManagerStandard
458 } // namespace OHOS