1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "mdns_protocol_impl.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <iostream>
#include <random>

#include "mdns_manager.h"
#include "mdns_packet_parser.h"
#include "net_conn_client.h"
#include "netmgr_ext_log_wrapper.h"

#include "securec.h"
32
33 namespace OHOS {
34 namespace NetManagerStandard {
35
// Timing knobs (milliseconds unless noted).
// Minimum gap between two Browse() passes. The name keeps its historical spelling because it is used elsewhere.
constexpr uint32_t DEFAULT_INTEVAL_MS{2000};
// A LIVE browse result not refreshed for this long is probed / reported lost by handleOfflineService().
constexpr uint32_t DEFAULT_LOST_MS{10000};
// TTL put on records this responder emits; IsCacheAvailable() treats TTLs as seconds.
constexpr uint32_t DEFAULT_TTL{120};
// Top bit of the rrclass field: the mDNS cache-flush flag.
constexpr uint16_t MDNS_FLUSH_CACHE_BIT{0x8000};

// How far ProcessQuestionRecord() got while answering a query; larger means more specific.
constexpr int PHASE_PTR{1};
constexpr int PHASE_SRV{2};
constexpr int PHASE_DOMAIN{3};
44
/**
 * Render an address stored in a std::any as text.
 *
 * @param addr Holds either an in_addr (IPv4) or an in6_addr (IPv6).
 * @return The presentation form produced by inet_ntop. Returns an empty string if conversion fails
 *         or if addr holds any other type.
 */
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const char *converted = text;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    // Unknown payload types fall through with `text` still all zeros, i.e. "".
    return (converted == nullptr) ? std::string{} : std::string(text);
}
59
/**
 * Wall-clock time as milliseconds since the Unix epoch (std::chrono::system_clock).
 */
int64_t MilliSecondsSinceEpoch()
{
    using std::chrono::milliseconds;
    using std::chrono::system_clock;
    const auto sinceEpoch = system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<milliseconds>(sinceEpoch).count();
}
65
MDnsProtocolImpl()66 MDnsProtocolImpl::MDnsProtocolImpl()
67 {
68 Init();
69 }
70
Init()71 void MDnsProtocolImpl::Init()
72 {
73 NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
74 listener_.Stop();
75 listener_.CloseAllSocket();
76
77 if (config_.configAllIface) {
78 listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
79 } else {
80 listener_.OpenSocketForDefault(config_.ipv6Support);
81 }
82 listener_.SetReceiveHandler(
83 [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
84 listener_.SetFinishedHandler([this](int sock) {
85 std::lock_guard<std::recursive_mutex> guard(mutex_);
86 RunTaskQueue(taskQueue_);
87 });
88 listener_.Start();
89
90 taskQueue_.clear();
91 taskOnChange_.clear();
92 AddTask([this]() { return Browse(); }, false);
93 }
94
Browse()95 bool MDnsProtocolImpl::Browse()
96 {
97 if (lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) {
98 return false;
99 }
100 lastRunTime = MilliSecondsSinceEpoch();
101 std::lock_guard<std::recursive_mutex> guard(mutex_);
102 for (auto &&[key, res] : browserMap_) {
103 NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
104 if (nameCbMap_.find(key) != nameCbMap_.end() &&
105 !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
106 continue;
107 }
108 handleOfflineService(key, res);
109 MDnsPayloadParser parser;
110 MDnsMessage msg{};
111 msg.questions.emplace_back(DNSProto::Question{
112 .name = key,
113 .qtype = DNSProto::RRTYPE_PTR,
114 .qclass = DNSProto::RRCLASS_IN,
115 });
116 listener_.MulticastAll(parser.ToBytes(msg));
117 }
118 return false;
119 }
120
ConnectControl(int32_t sockfd,sockaddr * serverAddr)121 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
122 {
123 uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
124 fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
125 int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
126 if ((ret < 0) && (errno != EINPROGRESS)) {
127 NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
128 return NETMANAGER_EXT_ERR_INTERNAL;
129 }
130 if (ret == 0) {
131 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
132 NETMGR_EXT_LOG_I("connect success.");
133 return NETMANAGER_EXT_SUCCESS;
134 }
135
136 fd_set rset;
137 FD_ZERO(&rset);
138 FD_SET(sockfd, &rset);
139 fd_set wset = rset;
140 timeval tval {1, 0};
141 ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
142 if (ret < 0) { // select error.
143 NETMGR_EXT_LOG_E("select error: %{public}d", errno);
144 return NETMANAGER_EXT_ERR_INTERNAL;
145 }
146 if (ret == 0) { // timeout
147 NETMGR_EXT_LOG_E("connect timeout...");
148 return NETMANAGER_EXT_ERR_INTERNAL;
149 }
150 if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
151 NETMGR_EXT_LOG_E("select error: sockfd not set");
152 return NETMANAGER_EXT_ERR_INTERNAL;
153 }
154
155 int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
156 socklen_t len = sizeof(result);
157 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
158 NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
159 return NETMANAGER_EXT_ERR_INTERNAL;
160 }
161 if (result != 0) { // connect failed.
162 NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
163 return NETMANAGER_EXT_ERR_INTERNAL;
164 }
165 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
166 NETMGR_EXT_LOG_I("lost but connect success.");
167 return NETMANAGER_EXT_SUCCESS;
168 }
169
IsConnectivity(const std::string & ip,int32_t port)170 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
171 {
172 if (ip.empty()) {
173 NETMGR_EXT_LOG_E("ip is empty");
174 return false;
175 }
176
177 int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
178 if (sockfd < 0) {
179 NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
180 return false;
181 }
182
183 struct sockaddr_in serverAddr;
184 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
185 NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
186 close(sockfd);
187 return false;
188 }
189
190 serverAddr.sin_family = AF_INET;
191 serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
192 serverAddr.sin_port = htons(port);
193 if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
194 NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
195 close(sockfd);
196 return false;
197 }
198
199 close(sockfd);
200 return true;
201 }
202
handleOfflineService(const std::string & key,std::vector<Result> & res)203 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
204 {
205 NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
206 for (auto it = res.begin(); it != res.end();) {
207 if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
208 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
209 if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
210 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
211 it++;
212 continue;
213 }
214
215 it->state = State::DEAD;
216 if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
217 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
218 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
219 }
220 it = res.erase(it);
221 cacheMap_.erase(fullName);
222 } else {
223 it++;
224 }
225 }
226 }
227
SetConfig(const MDnsConfig & config)228 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
229 {
230 config_ = config;
231 }
232
GetConfig() const233 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
234 {
235 return config_;
236 }
237
Decorated(const std::string & name) const238 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
239 {
240 return name + config_.topDomain;
241 }
242
Register(const Result & info)243 int32_t MDnsProtocolImpl::Register(const Result &info)
244 {
245 NETMGR_EXT_LOG_D("mdns_log Register");
246 if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
247 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
248 }
249 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
250 if (!IsDomainValid(name)) {
251 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
252 }
253 {
254 std::lock_guard<std::recursive_mutex> guard(mutex_);
255 if (srvMap_.find(name) != srvMap_.end()) {
256 return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
257 }
258 srvMap_.emplace(name, info);
259 }
260 return Announce(info, false);
261 }
262
UnRegister(const std::string & key)263 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
264 {
265 NETMGR_EXT_LOG_D("mdns_log UnRegister");
266 std::string name = Decorated(key);
267 std::lock_guard<std::recursive_mutex> guard(mutex_);
268 if (srvMap_.find(name) != srvMap_.end()) {
269 Announce(srvMap_[name], true);
270 srvMap_.erase(name);
271 return NETMANAGER_EXT_SUCCESS;
272 }
273 return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
274 }
275
DiscoveryFromCache(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)276 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
277 {
278 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
279 std::string name = Decorated(serviceType);
280 std::lock_guard<std::recursive_mutex> guard(mutex_);
281 if (!IsBrowserAvailable(name)) {
282 return false;
283 }
284
285 if (browserMap_.find(name) == browserMap_.end()) {
286 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
287 return false;
288 }
289
290 for (auto &res : browserMap_[name]) {
291 if (res.state == State::REMOVE || res.state == State::DEAD) {
292 continue;
293 }
294 AddTask([cb, info = ConvertResultToInfo(res)]() {
295 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
296 if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
297 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
298 }
299 return true;
300 });
301 }
302 return true;
303 }
304
DiscoveryFromNet(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)305 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
306 {
307 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
308 std::string name = Decorated(serviceType);
309 std::lock_guard<std::recursive_mutex> guard(mutex_);
310 browserMap_.insert({name, std::vector<Result>{}});
311 nameCbMap_[name] = cb;
312 // key is serviceTYpe
313 AddEvent(name, [this, name, cb]() {
314 std::lock_guard<std::recursive_mutex> guard(mutex_);
315 if (!IsBrowserAvailable(name)) {
316 return false;
317 }
318 if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
319 return true;
320 }
321 for (auto &res : browserMap_[name]) {
322 std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
323 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
324 fullName.c_str());
325 if (cacheMap_.find(fullName) == cacheMap_.end() ||
326 (res.state == State::ADD || res.state == State::REFRESH)) {
327 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
328 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
329 res.state = State::LIVE;
330 }
331 if (res.state == State::REMOVE) {
332 res.state = State::DEAD;
333 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
334 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
335 if (cacheMap_.find(fullName) != cacheMap_.end()) {
336 res.state = State::ADD;
337 cacheMap_.erase(fullName);
338 }
339 }
340 }
341 return false;
342 });
343
344 AddTask([=]() {
345 MDnsPayloadParser parser;
346 MDnsMessage msg{};
347 msg.questions.emplace_back(DNSProto::Question{
348 .name = name,
349 .qtype = DNSProto::RRTYPE_PTR,
350 .qclass = DNSProto::RRCLASS_IN,
351 });
352 listener_.MulticastAll(parser.ToBytes(msg));
353 return true;
354 }, false);
355 return true;
356 }
357
Discovery(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)358 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
359 {
360 NETMGR_EXT_LOG_D("mdns_log Discovery");
361 DiscoveryFromCache(serviceType, cb);
362 DiscoveryFromNet(serviceType, cb);
363 return NETMANAGER_EXT_SUCCESS;
364 }
365
ResolveInstanceFromCache(const std::string & name,const sptr<IResolveCallback> & cb)366 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
367 {
368 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
369 std::lock_guard<std::recursive_mutex> guard(mutex_);
370 if (!IsInstanceCacheAvailable(name)) {
371 NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
372 return false;
373 }
374
375 NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
376 Result r = cacheMap_[name];
377 if (IsDomainCacheAvailable(r.domain)) {
378 r.ipv6 = cacheMap_[r.domain].ipv6;
379 r.addr = cacheMap_[r.domain].addr;
380
381 NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
382 AddTask([cb, info = ConvertResultToInfo(r)]() {
383 if (nullptr != cb) {
384 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
385 }
386 return true;
387 });
388 } else {
389 ResolveFromNet(r.domain, nullptr);
390 NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
391 AddEvent(r.domain, [this, cb, r]() mutable {
392 if (!IsDomainCacheAvailable(r.domain)) {
393 return false;
394 }
395 r.ipv6 = cacheMap_[r.domain].ipv6;
396 r.addr = cacheMap_[r.domain].addr;
397 if (nullptr != cb) {
398 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
399 }
400 return true;
401 });
402 }
403 return true;
404 }
405
ResolveInstanceFromNet(const std::string & name,const sptr<IResolveCallback> & cb)406 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
407 {
408 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
409 {
410 std::lock_guard<std::recursive_mutex> guard(mutex_);
411 cacheMap_[name].state = State::ADD;
412 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
413 }
414 MDnsPayloadParser parser;
415 MDnsMessage msg{};
416 msg.questions.emplace_back(DNSProto::Question{
417 .name = name,
418 .qtype = DNSProto::RRTYPE_SRV,
419 .qclass = DNSProto::RRCLASS_IN,
420 });
421 msg.questions.emplace_back(DNSProto::Question{
422 .name = name,
423 .qtype = DNSProto::RRTYPE_TXT,
424 .qclass = DNSProto::RRCLASS_IN,
425 });
426 msg.header.qdcount = msg.questions.size();
427 AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
428 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
429 return size > 0;
430 }
431
ResolveFromCache(const std::string & domain,const sptr<IResolveCallback> & cb)432 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
433 {
434 NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
435 std::lock_guard<std::recursive_mutex> guard(mutex_);
436 if (!IsDomainCacheAvailable(domain)) {
437 return false;
438 }
439 AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
440 if (nullptr != cb) {
441 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
442 }
443 return true;
444 });
445 return true;
446 }
447
ResolveFromNet(const std::string & domain,const sptr<IResolveCallback> & cb)448 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
449 {
450 NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
451 {
452 std::lock_guard<std::recursive_mutex> guard(mutex_);
453 cacheMap_[domain];
454 cacheMap_[domain].domain = domain;
455 }
456 MDnsPayloadParser parser;
457 MDnsMessage msg{};
458 msg.questions.emplace_back(DNSProto::Question{
459 .name = domain,
460 .qtype = DNSProto::RRTYPE_A,
461 .qclass = DNSProto::RRCLASS_IN,
462 });
463 msg.questions.emplace_back(DNSProto::Question{
464 .name = domain,
465 .qtype = DNSProto::RRTYPE_AAAA,
466 .qclass = DNSProto::RRCLASS_IN,
467 });
468 // key is serviceName
469 AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
470 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
471 return size > 0;
472 }
473
ResolveInstance(const std::string & instance,const sptr<IResolveCallback> & cb)474 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
475 {
476 NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
477 if (!IsInstanceValid(instance)) {
478 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
479 }
480 std::string name = Decorated(instance);
481 if (!IsDomainValid(name)) {
482 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
483 }
484 if (ResolveInstanceFromCache(name, cb)) {
485 return NETMANAGER_EXT_SUCCESS;
486 }
487 return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
488 }
489
Announce(const Result & info,bool off)490 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
491 {
492 NETMGR_EXT_LOG_I("mdns_log Announce message");
493 MDnsMessage response{};
494 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
495 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
496 response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
497 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
498 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
499 .ttl = off ? 0U : DEFAULT_TTL,
500 .rdata = name});
501 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
502 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
503 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
504 .ttl = off ? 0U : DEFAULT_TTL,
505 .rdata = DNSProto::RDataSrv{
506 .priority = 0,
507 .weight = 0,
508 .port = static_cast<uint16_t>(info.port),
509 .name = GetHostDomain(),
510 }});
511 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
512 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
513 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
514 .ttl = off ? 0U : DEFAULT_TTL,
515 .rdata = info.txt});
516 MDnsPayloadParser parser;
517 ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
518 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
519 }
520
ReceivePacket(int sock,const MDnsPayload & payload)521 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
522 {
523 if (payload.size() == 0) {
524 return;
525 }
526 MDnsPayloadParser parser;
527 MDnsMessage msg = parser.FromBytes(payload);
528 if (parser.GetError() != 0) {
529 NETMGR_EXT_LOG_E("parser payload failed");
530 return;
531 }
532 if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
533 ProcessQuestion(sock, msg);
534 } else {
535 ProcessAnswer(sock, msg);
536 }
537 }
538
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)539 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
540 const std::string &name, const std::any &rdata)
541 {
542 rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
543 .rtype = static_cast<uint16_t>(type),
544 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
545 .ttl = DEFAULT_TTL,
546 .rdata = rdata});
547 }
548
ProcessQuestion(int sock,const MDnsMessage & msg)549 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
550 {
551 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
552 if (saddrIf == nullptr) {
553 NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
554 return;
555 }
556 std::any anyAddr;
557 DNSProto::RRType anyAddrType;
558 if (saddrIf->sa_family == AF_INET6) {
559 anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
560 anyAddrType = DNSProto::RRTYPE_AAAA;
561 } else {
562 anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
563 anyAddrType = DNSProto::RRTYPE_A;
564 }
565 int phase = 0;
566 MDnsMessage response{};
567 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
568 for (size_t i = 0; i < msg.header.qdcount; ++i) {
569 ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
570 }
571 if (phase < PHASE_DOMAIN) {
572 AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
573 }
574
575 if (phase != 0 && response.answers.size() > 0) {
576 listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
577 }
578 }
579
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)580 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
581 const DNSProto::Question &qu, int &phase, MDnsMessage &response)
582 {
583 NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
584 std::lock_guard<std::recursive_mutex> guard(mutex_);
585 std::string name = qu.name;
586 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
587 std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
588 if (EndsWith(elem.first, name)) {
589 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
590 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
591 DNSProto::RDataSrv{
592 .priority = 0,
593 .weight = 0,
594 .port = static_cast<uint16_t>(elem.second.port),
595 .name = GetHostDomain(),
596 });
597 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
598 }
599 });
600 phase = std::max(phase, PHASE_PTR);
601 }
602 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
603 auto iter = srvMap_.find(name);
604 if (iter == srvMap_.end()) {
605 return;
606 }
607 AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
608 DNSProto::RDataSrv{
609 .priority = 0,
610 .weight = 0,
611 .port = static_cast<uint16_t>(iter->second.port),
612 .name = GetHostDomain(),
613 });
614 phase = std::max(phase, PHASE_SRV);
615 }
616 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
617 auto iter = srvMap_.find(name);
618 if (iter == srvMap_.end()) {
619 return;
620 }
621 AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
622 phase = std::max(phase, PHASE_SRV);
623 }
624 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
625 if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
626 return;
627 }
628 AppendRecord(response.answers, anyAddrType, name, anyAddr);
629 phase = std::max(phase, PHASE_DOMAIN);
630 }
631 }
632
ProcessAnswer(int sock,const MDnsMessage & msg)633 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
634 {
635 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
636 if (saddrIf == nullptr) {
637 return;
638 }
639 bool v6 = (saddrIf->sa_family == AF_INET6);
640 std::set<std::string> changed;
641 for (const auto &answer : msg.answers) {
642 ProcessAnswerRecord(v6, answer, changed);
643 }
644 for (const auto &i : msg.additional) {
645 ProcessAnswerRecord(v6, i, changed);
646 }
647 for (const auto &i : changed) {
648 std::lock_guard<std::recursive_mutex> guard(mutex_);
649 RunTaskQueue(taskOnChange_[i]);
650 KillCache(i);
651 }
652 }
653
UpdatePtr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)654 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
655 {
656 const std::string *data = std::any_cast<std::string>(&rr.rdata);
657 if (data == nullptr) {
658 return;
659 }
660
661 std::string name = rr.name;
662 if (browserMap_.find(name) == browserMap_.end()) {
663 return;
664 }
665 auto &results = browserMap_[name];
666 std::string srvName;
667 std::string srvType;
668 ExtractNameAndType(*data, srvName, srvType);
669 if (srvName.empty() || srvType.empty()) {
670 return;
671 }
672 auto res =
673 std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
674 if (res == results.end()) {
675 results.emplace_back(Result{
676 .serviceName = srvName,
677 .serviceType = srvType,
678 .state = State::ADD,
679 });
680 }
681 res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
682 if (res->serviceName != srvName || res->state == State::DEAD) {
683 res->state = State::REFRESH;
684 res->serviceName = srvName;
685 }
686 if (rr.ttl == 0) {
687 res->state = State::REMOVE;
688 }
689 if (res->state != State::LIVE && res->state != State::DEAD) {
690 changed.emplace(name);
691 }
692 res->ttl = rr.ttl;
693 res->refrehTime = MilliSecondsSinceEpoch();
694 }
695
UpdateSrv(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)696 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
697 {
698 const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
699 if (srv == nullptr) {
700 return;
701 }
702 std::string name = rr.name;
703 if (cacheMap_.find(name) == cacheMap_.end()) {
704 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
705 cacheMap_[name].state = State::ADD;
706 cacheMap_[name].domain = srv->name;
707 cacheMap_[name].port = srv->port;
708 }
709 Result &result = cacheMap_[name];
710 if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
711 if (result.state != State::ADD) {
712 result.state = State::REFRESH;
713 }
714 result.domain = srv->name;
715 result.port = srv->port;
716 }
717 if (rr.ttl == 0) {
718 result.state = State::REMOVE;
719 }
720 if (result.state != State::LIVE && result.state != State::DEAD) {
721 changed.emplace(name);
722 }
723 result.ttl = rr.ttl;
724 result.refrehTime = MilliSecondsSinceEpoch();
725 }
726
UpdateTxt(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)727 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
728 {
729 const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
730 if (txt == nullptr) {
731 return;
732 }
733 std::string name = rr.name;
734 if (cacheMap_.find(name) == cacheMap_.end()) {
735 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
736 cacheMap_[name].state = State::ADD;
737 cacheMap_[name].txt = *txt;
738 }
739 Result &result = cacheMap_[name];
740 if (result.txt != *txt || result.state == State::DEAD) {
741 if (result.state != State::ADD) {
742 result.state = State::REFRESH;
743 }
744 result.txt = *txt;
745 }
746 if (rr.ttl == 0) {
747 result.state = State::REMOVE;
748 }
749 if (result.state != State::LIVE && result.state != State::DEAD) {
750 changed.emplace(name);
751 }
752 result.ttl = rr.ttl;
753 result.refrehTime = MilliSecondsSinceEpoch();
754 }
755
UpdateAddr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)756 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
757 {
758 if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
759 return;
760 }
761 const std::string addr = AddrToString(rr.rdata);
762 bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
763 if (addr.empty()) {
764 return;
765 }
766 std::string name = rr.name;
767 if (cacheMap_.find(name) == cacheMap_.end()) {
768 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
769 cacheMap_[name].state = State::ADD;
770 cacheMap_[name].ipv6 = v6rr;
771 cacheMap_[name].addr = addr;
772 }
773 Result &result = cacheMap_[name];
774 if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
775 result.state = State::REFRESH;
776 result.addr = addr;
777 result.ipv6 = v6rr;
778 }
779 if (rr.ttl == 0) {
780 result.state = State::REMOVE;
781 }
782 if (result.state != State::LIVE && result.state != State::DEAD) {
783 changed.emplace(name);
784 }
785 result.ttl = rr.ttl;
786 result.refrehTime = MilliSecondsSinceEpoch();
787 }
788
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)789 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
790 {
791 NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
792 std::lock_guard<std::recursive_mutex> guard(mutex_);
793 std::string name = rr.name;
794 if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
795 srvMap_.find(name) != srvMap_.end()) {
796 return;
797 }
798 if (rr.rtype == DNSProto::RRTYPE_PTR) {
799 UpdatePtr(v6, rr, changed);
800 } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
801 UpdateSrv(v6, rr, changed);
802 } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
803 UpdateTxt(v6, rr, changed);
804 } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
805 UpdateAddr(v6, rr, changed);
806 } else {
807 NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
808 }
809 }
810
GetHostDomain()811 std::string MDnsProtocolImpl::GetHostDomain()
812 {
813 if (config_.hostname.empty()) {
814 char buffer[MDNS_MAX_DOMAIN_LABEL];
815 if (gethostname(buffer, sizeof(buffer)) == 0) {
816 config_.hostname = buffer;
817 static auto uid = []() {
818 std::random_device rd;
819 return rd();
820 }();
821 config_.hostname += std::to_string(uid);
822 }
823 }
824 return Decorated(config_.hostname);
825 }
826
AddTask(const Task & task,bool atonce)827 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
828 {
829 {
830 std::lock_guard<std::recursive_mutex> guard(mutex_);
831 taskQueue_.emplace_back(task);
832 }
833 if (atonce) {
834 listener_.TriggerRefresh();
835 }
836 }
837
ConvertResultToInfo(const MDnsProtocolImpl::Result & result)838 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
839 {
840 MDnsServiceInfo info;
841 info.name = result.serviceName;
842 info.type = result.serviceType;
843 if (!result.addr.empty()) {
844 info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
845 }
846 info.addr = result.addr;
847 info.port = result.port;
848 info.txtRecord = result.txt;
849 return info;
850 }
851
IsCacheAvailable(const std::string & key)852 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
853 {
854 constexpr int64_t ms2S = 1000LL;
855 NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
856 return cacheMap_.find(key) != cacheMap_.end() &&
857 (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
858 }
859
IsDomainCacheAvailable(const std::string & key)860 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
861 {
862 return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
863 }
864
IsInstanceCacheAvailable(const std::string & key)865 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
866 {
867 return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
868 }
869
IsBrowserAvailable(const std::string & key)870 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
871 {
872 return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
873 }
874
AddEvent(const std::string & key,const Task & task)875 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
876 {
877 std::lock_guard<std::recursive_mutex> guard(mutex_);
878 taskOnChange_[key].emplace_back(task);
879 }
880
RunTaskQueue(std::list<Task> & queue)881 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
882 {
883 std::list<Task> tmp;
884 for (auto &&func : queue) {
885 if (!func()) {
886 tmp.emplace_back(func);
887 }
888 }
889 tmp.swap(queue);
890 }
891
KillCache(const std::string & key)892 void MDnsProtocolImpl::KillCache(const std::string &key)
893 {
894 NETMGR_EXT_LOG_D("mdns_log KillCache");
895 if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
896 for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
897 KillBrowseCache(key, it);
898 }
899 }
900 if (IsCacheAvailable(key)) {
901 std::lock_guard<std::recursive_mutex> guard(mutex_);
902 auto &elem = cacheMap_[key];
903 if (elem.state == State::REMOVE) {
904 elem.state = State::DEAD;
905 cacheMap_.erase(key);
906 } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
907 elem.state = State::LIVE;
908 }
909 }
910 }
911
KillBrowseCache(const std::string & key,std::vector<Result>::iterator & it)912 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
913 {
914 NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
915 if (it->state == State::REMOVE) {
916 it->state = State::DEAD;
917 if (nameCbMap_.find(key) != nameCbMap_.end()) {
918 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
919 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
920 }
921 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
922 cacheMap_.erase(fullName);
923 it = browserMap_[key].erase(it);
924 } else if (it->state == State::ADD || it->state == State::REFRESH) {
925 it->state = State::LIVE;
926 it++;
927 } else {
928 it++;
929 }
930 }
931
StopCbMap(const std::string & serviceType)932 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
933 {
934 NETMGR_EXT_LOG_D("mdns_log StopCbMap");
935 std::lock_guard<std::recursive_mutex> guard(mutex_);
936 std::string name = Decorated(serviceType);
937 sptr<IDiscoveryCallback> cb = nullptr;
938 if (nameCbMap_.find(name) != nameCbMap_.end()) {
939 cb = nameCbMap_[name];
940 nameCbMap_.erase(name);
941 }
942 taskOnChange_.erase(name);
943 auto it = browserMap_.find(name);
944 if (it != browserMap_.end()) {
945 if (cb != nullptr) {
946 NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
947 for (auto &&res : it->second) {
948 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
949 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
950 }
951 }
952 browserMap_.erase(name);
953 }
954 return NETMANAGER_SUCCESS;
955 }
956 } // namespace NetManagerStandard
957 } // namespace OHOS
958