1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_protocol_impl.h"
17 
#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>

#include <cstddef>
#include <iostream>
#include <random>
25 
26 #include "mdns_manager.h"
27 #include "mdns_packet_parser.h"
28 #include "net_conn_client.h"
29 #include "netmgr_ext_log_wrapper.h"
30 
31 #include "securec.h"
32 
33 namespace OHOS {
34 namespace NetManagerStandard {
35 
// Minimum spacing, in milliseconds, between two executions of the periodic Browse() task.
constexpr uint32_t DEFAULT_INTEVAL_MS = 2000;
// A LIVE browse result whose refrehTime is older than this many milliseconds (measured against
// lastRunTime) is probed for reachability and reported lost if it is unreachable.
constexpr uint32_t DEFAULT_LOST_MS = 10000;
// TTL put on records built by AppendRecord and by Announce when not announcing a removal
// (DNS resource-record TTLs are expressed in seconds).
constexpr uint32_t DEFAULT_TTL = 120;
// High bit of the rrclass field, OR-ed into RRCLASS_IN on outgoing records
// (the mDNS cache-flush bit, RFC 6762 section 10.2).
constexpr uint16_t MDNS_FLUSH_CACHE_BIT = 0x8000;

// Progress levels tracked by ProcessQuestionRecord while building a response: the highest kind of
// record answered so far. ProcessQuestion appends the host address as an additional record only
// while the phase is still below PHASE_DOMAIN.
constexpr int PHASE_PTR = 1;
constexpr int PHASE_SRV = 2;
constexpr int PHASE_DOMAIN = 3;
44 
/**
 * Formats an address stored in a std::any as printable text.
 *
 * @param addr holds either an in_addr (IPv4) or an in6_addr (IPv6).
 * @return the textual address; an empty string when addr holds neither type or inet_ntop fails.
 */
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const char *converted = text;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    // For an unsupported payload `text` is still all zeros, which yields "".
    return (converted == nullptr) ? std::string{} : std::string(text);
}
59 
/**
 * Current wall-clock time (std::chrono::system_clock) in milliseconds since the clock's epoch.
 */
int64_t MilliSecondsSinceEpoch()
{
    using std::chrono::milliseconds;
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<milliseconds>(sinceEpoch).count();
}
65 
MDnsProtocolImpl()66 MDnsProtocolImpl::MDnsProtocolImpl()
67 {
68     Init();
69 }
70 
Init()71 void MDnsProtocolImpl::Init()
72 {
73     NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
74     listener_.Stop();
75     listener_.CloseAllSocket();
76 
77     if (config_.configAllIface) {
78         listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
79     } else {
80         listener_.OpenSocketForDefault(config_.ipv6Support);
81     }
82     listener_.SetReceiveHandler(
83         [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
84     listener_.SetFinishedHandler([this](int sock) {
85         std::lock_guard<std::recursive_mutex> guard(mutex_);
86         RunTaskQueue(taskQueue_);
87     });
88     listener_.Start();
89 
90     taskQueue_.clear();
91     taskOnChange_.clear();
92     AddTask([this]() { return Browse(); }, false);
93 }
94 
Browse()95 bool MDnsProtocolImpl::Browse()
96 {
97     if (lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) {
98         return false;
99     }
100     lastRunTime = MilliSecondsSinceEpoch();
101     std::lock_guard<std::recursive_mutex> guard(mutex_);
102     for (auto &&[key, res] : browserMap_) {
103         NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
104         if (nameCbMap_.find(key) != nameCbMap_.end() &&
105             !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
106             continue;
107         }
108         handleOfflineService(key, res);
109         MDnsPayloadParser parser;
110         MDnsMessage msg{};
111         msg.questions.emplace_back(DNSProto::Question{
112             .name = key,
113             .qtype = DNSProto::RRTYPE_PTR,
114             .qclass = DNSProto::RRCLASS_IN,
115         });
116         listener_.MulticastAll(parser.ToBytes(msg));
117     }
118     return false;
119 }
120 
ConnectControl(int32_t sockfd,sockaddr * serverAddr)121 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
122 {
123     uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
124     fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
125     int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
126     if ((ret < 0) && (errno != EINPROGRESS)) {
127         NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
128         return NETMANAGER_EXT_ERR_INTERNAL;
129     }
130     if (ret == 0) {
131         fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
132         NETMGR_EXT_LOG_I("connect success.");
133         return NETMANAGER_EXT_SUCCESS;
134     }
135 
136     fd_set rset;
137     FD_ZERO(&rset);
138     FD_SET(sockfd, &rset);
139     fd_set wset = rset;
140     timeval tval {1, 0};
141     ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
142     if (ret < 0) { // select error.
143         NETMGR_EXT_LOG_E("select error: %{public}d", errno);
144         return NETMANAGER_EXT_ERR_INTERNAL;
145     }
146     if (ret == 0) { // timeout
147         NETMGR_EXT_LOG_E("connect timeout...");
148         return NETMANAGER_EXT_ERR_INTERNAL;
149     }
150     if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
151         NETMGR_EXT_LOG_E("select error: sockfd not set");
152         return NETMANAGER_EXT_ERR_INTERNAL;
153     }
154 
155     int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
156     socklen_t len = sizeof(result);
157     if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
158         NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
159         return NETMANAGER_EXT_ERR_INTERNAL;
160     }
161     if (result != 0) { // connect failed.
162         NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
163         return NETMANAGER_EXT_ERR_INTERNAL;
164     }
165     fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
166     NETMGR_EXT_LOG_I("lost but connect success.");
167     return NETMANAGER_EXT_SUCCESS;
168 }
169 
IsConnectivity(const std::string & ip,int32_t port)170 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
171 {
172     if (ip.empty()) {
173         NETMGR_EXT_LOG_E("ip is empty");
174         return false;
175     }
176 
177     int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
178     if (sockfd < 0) {
179         NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
180         return false;
181     }
182 
183     struct sockaddr_in serverAddr;
184     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
185         NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
186         close(sockfd);
187         return false;
188     }
189 
190     serverAddr.sin_family = AF_INET;
191     serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
192     serverAddr.sin_port = htons(port);
193     if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
194         NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
195         close(sockfd);
196         return false;
197     }
198 
199     close(sockfd);
200     return true;
201 }
202 
handleOfflineService(const std::string & key,std::vector<Result> & res)203 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
204 {
205     NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
206     for (auto it = res.begin(); it != res.end();) {
207         if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
208             std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
209             if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
210                 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
211                 it++;
212                 continue;
213             }
214 
215             it->state = State::DEAD;
216             if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
217                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
218                 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
219             }
220             it = res.erase(it);
221             cacheMap_.erase(fullName);
222         } else {
223             it++;
224         }
225     }
226 }
227 
SetConfig(const MDnsConfig & config)228 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
229 {
230     config_ = config;
231 }
232 
GetConfig() const233 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
234 {
235     return config_;
236 }
237 
Decorated(const std::string & name) const238 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
239 {
240     return name + config_.topDomain;
241 }
242 
Register(const Result & info)243 int32_t MDnsProtocolImpl::Register(const Result &info)
244 {
245     NETMGR_EXT_LOG_D("mdns_log Register");
246     if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
247         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
248     }
249     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
250     if (!IsDomainValid(name)) {
251         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
252     }
253     {
254         std::lock_guard<std::recursive_mutex> guard(mutex_);
255         if (srvMap_.find(name) != srvMap_.end()) {
256             return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
257         }
258         srvMap_.emplace(name, info);
259     }
260     return Announce(info, false);
261 }
262 
UnRegister(const std::string & key)263 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
264 {
265     NETMGR_EXT_LOG_D("mdns_log UnRegister");
266     std::string name = Decorated(key);
267     std::lock_guard<std::recursive_mutex> guard(mutex_);
268     if (srvMap_.find(name) != srvMap_.end()) {
269         Announce(srvMap_[name], true);
270         srvMap_.erase(name);
271         return NETMANAGER_EXT_SUCCESS;
272     }
273     return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
274 }
275 
DiscoveryFromCache(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)276 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
277 {
278     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
279     std::string name = Decorated(serviceType);
280     std::lock_guard<std::recursive_mutex> guard(mutex_);
281     if (!IsBrowserAvailable(name)) {
282         return false;
283     }
284 
285     if (browserMap_.find(name) == browserMap_.end()) {
286         NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
287         return false;
288     }
289 
290     for (auto &res : browserMap_[name]) {
291         if (res.state == State::REMOVE || res.state == State::DEAD) {
292             continue;
293         }
294         AddTask([cb, info = ConvertResultToInfo(res)]() {
295             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
296             if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
297                 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
298             }
299             return true;
300         });
301     }
302     return true;
303 }
304 
DiscoveryFromNet(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)305 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
306 {
307     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
308     std::string name = Decorated(serviceType);
309     std::lock_guard<std::recursive_mutex> guard(mutex_);
310     browserMap_.insert({name, std::vector<Result>{}});
311     nameCbMap_[name] = cb;
312     // key is serviceTYpe
313     AddEvent(name, [this, name, cb]() {
314         std::lock_guard<std::recursive_mutex> guard(mutex_);
315         if (!IsBrowserAvailable(name)) {
316             return false;
317         }
318         if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
319             return true;
320         }
321         for (auto &res : browserMap_[name]) {
322             std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
323             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
324                              fullName.c_str());
325             if (cacheMap_.find(fullName) == cacheMap_.end() ||
326                 (res.state == State::ADD || res.state == State::REFRESH)) {
327                 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
328                 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
329                 res.state = State::LIVE;
330             }
331             if (res.state == State::REMOVE) {
332                 res.state = State::DEAD;
333                 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
334                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
335                 if (cacheMap_.find(fullName) != cacheMap_.end()) {
336                     res.state = State::ADD;
337                     cacheMap_.erase(fullName);
338                 }
339             }
340         }
341         return false;
342     });
343 
344     AddTask([=]() {
345             MDnsPayloadParser parser;
346             MDnsMessage msg{};
347             msg.questions.emplace_back(DNSProto::Question{
348                 .name = name,
349                 .qtype = DNSProto::RRTYPE_PTR,
350                 .qclass = DNSProto::RRCLASS_IN,
351             });
352             listener_.MulticastAll(parser.ToBytes(msg));
353             return true;
354         }, false);
355     return true;
356 }
357 
Discovery(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)358 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
359 {
360     NETMGR_EXT_LOG_D("mdns_log Discovery");
361     DiscoveryFromCache(serviceType, cb);
362     DiscoveryFromNet(serviceType, cb);
363     return NETMANAGER_EXT_SUCCESS;
364 }
365 
ResolveInstanceFromCache(const std::string & name,const sptr<IResolveCallback> & cb)366 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
367 {
368     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
369     std::lock_guard<std::recursive_mutex> guard(mutex_);
370     if (!IsInstanceCacheAvailable(name)) {
371         NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
372         return false;
373     }
374 
375     NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
376     Result r = cacheMap_[name];
377     if (IsDomainCacheAvailable(r.domain)) {
378         r.ipv6 = cacheMap_[r.domain].ipv6;
379         r.addr = cacheMap_[r.domain].addr;
380 
381         NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
382         AddTask([cb, info = ConvertResultToInfo(r)]() {
383             if (nullptr != cb) {
384                 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
385             }
386             return true;
387         });
388     } else {
389         ResolveFromNet(r.domain, nullptr);
390         NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
391         AddEvent(r.domain, [this, cb, r]() mutable {
392             if (!IsDomainCacheAvailable(r.domain)) {
393                 return false;
394             }
395             r.ipv6 = cacheMap_[r.domain].ipv6;
396             r.addr = cacheMap_[r.domain].addr;
397             if (nullptr != cb) {
398                 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
399             }
400             return true;
401         });
402     }
403     return true;
404 }
405 
ResolveInstanceFromNet(const std::string & name,const sptr<IResolveCallback> & cb)406 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
407 {
408     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
409     {
410         std::lock_guard<std::recursive_mutex> guard(mutex_);
411         cacheMap_[name].state = State::ADD;
412         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
413     }
414     MDnsPayloadParser parser;
415     MDnsMessage msg{};
416     msg.questions.emplace_back(DNSProto::Question{
417         .name = name,
418         .qtype = DNSProto::RRTYPE_SRV,
419         .qclass = DNSProto::RRCLASS_IN,
420     });
421     msg.questions.emplace_back(DNSProto::Question{
422         .name = name,
423         .qtype = DNSProto::RRTYPE_TXT,
424         .qclass = DNSProto::RRCLASS_IN,
425     });
426     msg.header.qdcount = msg.questions.size();
427     AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
428     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
429     return size > 0;
430 }
431 
ResolveFromCache(const std::string & domain,const sptr<IResolveCallback> & cb)432 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
433 {
434     NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
435     std::lock_guard<std::recursive_mutex> guard(mutex_);
436     if (!IsDomainCacheAvailable(domain)) {
437         return false;
438     }
439     AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
440         if (nullptr != cb) {
441             cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
442         }
443         return true;
444     });
445     return true;
446 }
447 
ResolveFromNet(const std::string & domain,const sptr<IResolveCallback> & cb)448 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
449 {
450     NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
451     {
452         std::lock_guard<std::recursive_mutex> guard(mutex_);
453         cacheMap_[domain];
454         cacheMap_[domain].domain = domain;
455     }
456     MDnsPayloadParser parser;
457     MDnsMessage msg{};
458     msg.questions.emplace_back(DNSProto::Question{
459         .name = domain,
460         .qtype = DNSProto::RRTYPE_A,
461         .qclass = DNSProto::RRCLASS_IN,
462     });
463     msg.questions.emplace_back(DNSProto::Question{
464         .name = domain,
465         .qtype = DNSProto::RRTYPE_AAAA,
466         .qclass = DNSProto::RRCLASS_IN,
467     });
468     // key is serviceName
469     AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
470     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
471     return size > 0;
472 }
473 
ResolveInstance(const std::string & instance,const sptr<IResolveCallback> & cb)474 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
475 {
476     NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
477     if (!IsInstanceValid(instance)) {
478         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
479     }
480     std::string name = Decorated(instance);
481     if (!IsDomainValid(name)) {
482         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
483     }
484     if (ResolveInstanceFromCache(name, cb)) {
485         return NETMANAGER_EXT_SUCCESS;
486     }
487     return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
488 }
489 
Announce(const Result & info,bool off)490 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
491 {
492     NETMGR_EXT_LOG_I("mdns_log Announce message");
493     MDnsMessage response{};
494     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
495     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
496     response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
497                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
498                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
499                                                            .ttl = off ? 0U : DEFAULT_TTL,
500                                                            .rdata = name});
501     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
502                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
503                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
504                                                            .ttl = off ? 0U : DEFAULT_TTL,
505                                                            .rdata = DNSProto::RDataSrv{
506                                                                .priority = 0,
507                                                                .weight = 0,
508                                                                .port = static_cast<uint16_t>(info.port),
509                                                                .name = GetHostDomain(),
510                                                            }});
511     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
512                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
513                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
514                                                            .ttl = off ? 0U : DEFAULT_TTL,
515                                                            .rdata = info.txt});
516     MDnsPayloadParser parser;
517     ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
518     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
519 }
520 
ReceivePacket(int sock,const MDnsPayload & payload)521 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
522 {
523     if (payload.size() == 0) {
524         return;
525     }
526     MDnsPayloadParser parser;
527     MDnsMessage msg = parser.FromBytes(payload);
528     if (parser.GetError() != 0) {
529         NETMGR_EXT_LOG_E("parser payload failed");
530         return;
531     }
532     if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
533         ProcessQuestion(sock, msg);
534     } else {
535         ProcessAnswer(sock, msg);
536     }
537 }
538 
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)539 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
540                                     const std::string &name, const std::any &rdata)
541 {
542     rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
543                                                  .rtype = static_cast<uint16_t>(type),
544                                                  .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
545                                                  .ttl = DEFAULT_TTL,
546                                                  .rdata = rdata});
547 }
548 
ProcessQuestion(int sock,const MDnsMessage & msg)549 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
550 {
551     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
552     if (saddrIf == nullptr) {
553         NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
554         return;
555     }
556     std::any anyAddr;
557     DNSProto::RRType anyAddrType;
558     if (saddrIf->sa_family == AF_INET6) {
559         anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
560         anyAddrType = DNSProto::RRTYPE_AAAA;
561     } else {
562         anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
563         anyAddrType = DNSProto::RRTYPE_A;
564     }
565     int phase = 0;
566     MDnsMessage response{};
567     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
568     for (size_t i = 0; i < msg.header.qdcount; ++i) {
569         ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
570     }
571     if (phase < PHASE_DOMAIN) {
572         AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
573     }
574 
575     if (phase != 0 && response.answers.size() > 0) {
576         listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
577     }
578 }
579 
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)580 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
581                                              const DNSProto::Question &qu, int &phase, MDnsMessage &response)
582 {
583     NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
584     std::lock_guard<std::recursive_mutex> guard(mutex_);
585     std::string name = qu.name;
586     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
587         std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
588             if (EndsWith(elem.first, name)) {
589                 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
590                 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
591                              DNSProto::RDataSrv{
592                                  .priority = 0,
593                                  .weight = 0,
594                                  .port = static_cast<uint16_t>(elem.second.port),
595                                  .name = GetHostDomain(),
596                              });
597                 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
598             }
599         });
600         phase = std::max(phase, PHASE_PTR);
601     }
602     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
603         auto iter = srvMap_.find(name);
604         if (iter == srvMap_.end()) {
605             return;
606         }
607         AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
608                      DNSProto::RDataSrv{
609                          .priority = 0,
610                          .weight = 0,
611                          .port = static_cast<uint16_t>(iter->second.port),
612                          .name = GetHostDomain(),
613                      });
614         phase = std::max(phase, PHASE_SRV);
615     }
616     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
617         auto iter = srvMap_.find(name);
618         if (iter == srvMap_.end()) {
619             return;
620         }
621         AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
622         phase = std::max(phase, PHASE_SRV);
623     }
624     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
625         if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
626             return;
627         }
628         AppendRecord(response.answers, anyAddrType, name, anyAddr);
629         phase = std::max(phase, PHASE_DOMAIN);
630     }
631 }
632 
ProcessAnswer(int sock,const MDnsMessage & msg)633 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
634 {
635     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
636     if (saddrIf == nullptr) {
637         return;
638     }
639     bool v6 = (saddrIf->sa_family == AF_INET6);
640     std::set<std::string> changed;
641     for (const auto &answer : msg.answers) {
642         ProcessAnswerRecord(v6, answer, changed);
643     }
644     for (const auto &i : msg.additional) {
645         ProcessAnswerRecord(v6, i, changed);
646     }
647     for (const auto &i : changed) {
648         std::lock_guard<std::recursive_mutex> guard(mutex_);
649         RunTaskQueue(taskOnChange_[i]);
650         KillCache(i);
651     }
652 }
653 
UpdatePtr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)654 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
655 {
656     const std::string *data = std::any_cast<std::string>(&rr.rdata);
657     if (data == nullptr) {
658         return;
659     }
660 
661     std::string name = rr.name;
662     if (browserMap_.find(name) == browserMap_.end()) {
663         return;
664     }
665     auto &results = browserMap_[name];
666     std::string srvName;
667     std::string srvType;
668     ExtractNameAndType(*data, srvName, srvType);
669     if (srvName.empty() || srvType.empty()) {
670         return;
671     }
672     auto res =
673         std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
674     if (res == results.end()) {
675         results.emplace_back(Result{
676             .serviceName = srvName,
677             .serviceType = srvType,
678             .state = State::ADD,
679         });
680     }
681     res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
682     if (res->serviceName != srvName || res->state == State::DEAD) {
683         res->state = State::REFRESH;
684         res->serviceName = srvName;
685     }
686     if (rr.ttl == 0) {
687         res->state = State::REMOVE;
688     }
689     if (res->state != State::LIVE && res->state != State::DEAD) {
690         changed.emplace(name);
691     }
692     res->ttl = rr.ttl;
693     res->refrehTime = MilliSecondsSinceEpoch();
694 }
695 
UpdateSrv(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)696 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
697 {
698     const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
699     if (srv == nullptr) {
700         return;
701     }
702     std::string name = rr.name;
703     if (cacheMap_.find(name) == cacheMap_.end()) {
704         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
705         cacheMap_[name].state = State::ADD;
706         cacheMap_[name].domain = srv->name;
707         cacheMap_[name].port = srv->port;
708     }
709     Result &result = cacheMap_[name];
710     if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
711         if (result.state != State::ADD) {
712             result.state = State::REFRESH;
713         }
714         result.domain = srv->name;
715         result.port = srv->port;
716     }
717     if (rr.ttl == 0) {
718         result.state = State::REMOVE;
719     }
720     if (result.state != State::LIVE && result.state != State::DEAD) {
721         changed.emplace(name);
722     }
723     result.ttl = rr.ttl;
724     result.refrehTime = MilliSecondsSinceEpoch();
725 }
726 
UpdateTxt(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)727 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
728 {
729     const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
730     if (txt == nullptr) {
731         return;
732     }
733     std::string name = rr.name;
734     if (cacheMap_.find(name) == cacheMap_.end()) {
735         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
736         cacheMap_[name].state = State::ADD;
737         cacheMap_[name].txt = *txt;
738     }
739     Result &result = cacheMap_[name];
740     if (result.txt != *txt || result.state == State::DEAD) {
741         if (result.state != State::ADD) {
742             result.state = State::REFRESH;
743         }
744         result.txt = *txt;
745     }
746     if (rr.ttl == 0) {
747         result.state = State::REMOVE;
748     }
749     if (result.state != State::LIVE && result.state != State::DEAD) {
750         changed.emplace(name);
751     }
752     result.ttl = rr.ttl;
753     result.refrehTime = MilliSecondsSinceEpoch();
754 }
755 
UpdateAddr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)756 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
757 {
758     if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
759         return;
760     }
761     const std::string addr = AddrToString(rr.rdata);
762     bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
763     if (addr.empty()) {
764         return;
765     }
766     std::string name = rr.name;
767     if (cacheMap_.find(name) == cacheMap_.end()) {
768         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
769         cacheMap_[name].state = State::ADD;
770         cacheMap_[name].ipv6 = v6rr;
771         cacheMap_[name].addr = addr;
772     }
773     Result &result = cacheMap_[name];
774     if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
775         result.state = State::REFRESH;
776         result.addr = addr;
777         result.ipv6 = v6rr;
778     }
779     if (rr.ttl == 0) {
780         result.state = State::REMOVE;
781     }
782     if (result.state != State::LIVE && result.state != State::DEAD) {
783         changed.emplace(name);
784     }
785     result.ttl = rr.ttl;
786     result.refrehTime = MilliSecondsSinceEpoch();
787 }
788 
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)789 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
790 {
791     NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
792     std::lock_guard<std::recursive_mutex> guard(mutex_);
793     std::string name = rr.name;
794     if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
795         srvMap_.find(name) != srvMap_.end()) {
796         return;
797     }
798     if (rr.rtype == DNSProto::RRTYPE_PTR) {
799         UpdatePtr(v6, rr, changed);
800     } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
801         UpdateSrv(v6, rr, changed);
802     } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
803         UpdateTxt(v6, rr, changed);
804     } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
805         UpdateAddr(v6, rr, changed);
806     } else {
807         NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
808     }
809 }
810 
GetHostDomain()811 std::string MDnsProtocolImpl::GetHostDomain()
812 {
813     if (config_.hostname.empty()) {
814         char buffer[MDNS_MAX_DOMAIN_LABEL];
815         if (gethostname(buffer, sizeof(buffer)) == 0) {
816             config_.hostname = buffer;
817             static auto uid = []() {
818                 std::random_device rd;
819                 return rd();
820             }();
821             config_.hostname += std::to_string(uid);
822         }
823     }
824     return Decorated(config_.hostname);
825 }
826 
AddTask(const Task & task,bool atonce)827 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
828 {
829     {
830         std::lock_guard<std::recursive_mutex> guard(mutex_);
831         taskQueue_.emplace_back(task);
832     }
833     if (atonce) {
834         listener_.TriggerRefresh();
835     }
836 }
837 
ConvertResultToInfo(const MDnsProtocolImpl::Result & result)838 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
839 {
840     MDnsServiceInfo info;
841     info.name = result.serviceName;
842     info.type = result.serviceType;
843     if (!result.addr.empty()) {
844         info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
845     }
846     info.addr = result.addr;
847     info.port = result.port;
848     info.txtRecord = result.txt;
849     return info;
850 }
851 
IsCacheAvailable(const std::string & key)852 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
853 {
854     constexpr int64_t ms2S = 1000LL;
855     NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
856     return cacheMap_.find(key) != cacheMap_.end() &&
857            (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
858 }
859 
IsDomainCacheAvailable(const std::string & key)860 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
861 {
862     return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
863 }
864 
IsInstanceCacheAvailable(const std::string & key)865 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
866 {
867     return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
868 }
869 
IsBrowserAvailable(const std::string & key)870 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
871 {
872     return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
873 }
874 
AddEvent(const std::string & key,const Task & task)875 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
876 {
877     std::lock_guard<std::recursive_mutex> guard(mutex_);
878     taskOnChange_[key].emplace_back(task);
879 }
880 
RunTaskQueue(std::list<Task> & queue)881 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
882 {
883     std::list<Task> tmp;
884     for (auto &&func : queue) {
885         if (!func()) {
886             tmp.emplace_back(func);
887         }
888     }
889     tmp.swap(queue);
890 }
891 
KillCache(const std::string & key)892 void MDnsProtocolImpl::KillCache(const std::string &key)
893 {
894     NETMGR_EXT_LOG_D("mdns_log KillCache");
895     if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
896         for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
897             KillBrowseCache(key, it);
898         }
899     }
900     if (IsCacheAvailable(key)) {
901         std::lock_guard<std::recursive_mutex> guard(mutex_);
902         auto &elem = cacheMap_[key];
903         if (elem.state == State::REMOVE) {
904             elem.state = State::DEAD;
905             cacheMap_.erase(key);
906         } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
907             elem.state = State::LIVE;
908         }
909     }
910 }
911 
KillBrowseCache(const std::string & key,std::vector<Result>::iterator & it)912 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
913 {
914     NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
915     if (it->state == State::REMOVE) {
916         it->state = State::DEAD;
917         if (nameCbMap_.find(key) != nameCbMap_.end()) {
918             NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
919             nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
920         }
921         std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
922         cacheMap_.erase(fullName);
923         it = browserMap_[key].erase(it);
924     } else if (it->state == State::ADD || it->state == State::REFRESH) {
925         it->state = State::LIVE;
926         it++;
927     } else {
928         it++;
929     }
930 }
931 
StopCbMap(const std::string & serviceType)932 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
933 {
934     NETMGR_EXT_LOG_D("mdns_log StopCbMap");
935     std::lock_guard<std::recursive_mutex> guard(mutex_);
936     std::string name = Decorated(serviceType);
937     sptr<IDiscoveryCallback> cb = nullptr;
938     if (nameCbMap_.find(name) != nameCbMap_.end()) {
939         cb = nameCbMap_[name];
940         nameCbMap_.erase(name);
941     }
942     taskOnChange_.erase(name);
943     auto it = browserMap_.find(name);
944     if (it != browserMap_.end()) {
945         if (cb != nullptr) {
946             NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
947             for (auto &&res : it->second) {
948                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
949                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
950             }
951         }
952         browserMap_.erase(name);
953     }
954     return NETMANAGER_SUCCESS;
955 }
956 } // namespace NetManagerStandard
957 } // namespace OHOS
958