1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "dns_param_cache.h"

#include <algorithm>
#include <iterator>

#include "netmanager_base_common_utils.h"
#ifdef FEATURE_NET_FIREWALL_ENABLE
#include "netfirewall_parcel.h"
#include <ctime>
#endif
25 
26 namespace OHOS::nmd {
27 using namespace OHOS::NetManagerStandard::CommonUtils;
28 namespace {
GetVectorData(const std::vector<std::string> & data,std::string & result)29 void GetVectorData(const std::vector<std::string> &data, std::string &result)
30 {
31     result.append("{ ");
32     std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
33     result.append("}\n");
34 }
35 constexpr int RES_TIMEOUT = 5000;    // min. milliseconds between retries
36 constexpr int RES_DEFAULT_RETRY = 2; // Default
37 
38 #ifdef FEATURE_NET_FIREWALL_ENABLE
39 constexpr int32_t USER_ID_DIVIDOR  = 200000;
40 #endif
41 } // namespace
42 
DnsParamCache()43 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
44 
GetInstance()45 DnsParamCache &DnsParamCache::GetInstance()
46 {
47     static DnsParamCache instance;
48     return instance;
49 }
50 
SelectNameservers(const std::vector<std::string> & servers)51 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
52 {
53     std::vector<std::string> res = servers;
54     if (res.size() > MAX_SERVER_NUM) {
55         res.resize(MAX_SERVER_NUM);
56     }
57     return res;
58 }
59 
CreateCacheForNet(uint16_t netId)60 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId)
61 {
62     NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
63     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
64     auto it = serverConfigMap_.find(netId);
65     if (it != serverConfigMap_.end()) {
66         NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
67         return -EEXIST;
68     }
69     serverConfigMap_[netId].SetNetId(netId);
70     return 0;
71 }
72 
DestroyNetworkCache(uint16_t netId)73 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId)
74 {
75     NETNATIVE_LOG_D("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
76     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
77     auto it = serverConfigMap_.find(netId);
78     if (it == serverConfigMap_.end()) {
79         return -ENOENT;
80     }
81     serverConfigMap_.erase(it);
82     if (defaultNetId_ == netId) {
83         defaultNetId_ = 0;
84     }
85     return 0;
86 }
87 
SetResolverConfig(uint16_t netId,uint16_t baseTimeoutMsec,uint8_t retryCount,const std::vector<std::string> & servers,const std::vector<std::string> & domains)88 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
89                                          const std::vector<std::string> &servers,
90                                          const std::vector<std::string> &domains)
91 {
92     std::vector<std::string> nameservers = SelectNameservers(servers);
93     NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
94                     static_cast<int>(nameservers.size()));
95 
96     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
97 
98     // select_domains
99     auto it = serverConfigMap_.find(netId);
100     if (it == serverConfigMap_.end()) {
101         NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
102         return -ENOENT;
103     }
104 
105     auto oldDnsServers = it->second.GetServers();
106     std::sort(oldDnsServers.begin(), oldDnsServers.end());
107 
108     auto newDnsServers = servers;
109     std::sort(newDnsServers.begin(), newDnsServers.end());
110 
111     if (oldDnsServers != newDnsServers) {
112         it->second.GetCache().Clear();
113     }
114 
115     it->second.SetNetId(netId);
116     it->second.SetServers(servers);
117     it->second.SetDomains(domains);
118     if (retryCount == 0) {
119         it->second.SetRetryCount(RES_DEFAULT_RETRY);
120     } else {
121         it->second.SetRetryCount(retryCount);
122     }
123     if (baseTimeoutMsec == 0) {
124         it->second.SetTimeoutMsec(RES_TIMEOUT);
125     } else {
126         it->second.SetTimeoutMsec(baseTimeoutMsec);
127     }
128     return 0;
129 }
130 
SetDefaultNetwork(uint16_t netId)131 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
132 {
133     defaultNetId_ = netId;
134 }
135 
EnableIpv6(uint16_t netId)136 void DnsParamCache::EnableIpv6(uint16_t netId)
137 {
138     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
139     auto it = serverConfigMap_.find(netId);
140     if (it == serverConfigMap_.end()) {
141         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
142         return;
143     }
144 
145     it->second.EnableIpv6();
146 }
147 
IsIpv6Enable(uint16_t netId)148 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
149 {
150     if (netId == 0) {
151         netId = defaultNetId_;
152     }
153 
154     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
155     auto it = serverConfigMap_.find(netId);
156     if (it == serverConfigMap_.end()) {
157         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
158         return false;
159     }
160 
161     return it->second.IsIpv6Enable();
162 }
163 
GetResolverConfig(uint16_t netId,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)164 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
165                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
166                                          uint8_t &retryCount)
167 {
168     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
169     if (netId == 0) {
170         netId = defaultNetId_;
171         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
172     }
173 
174     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
175     auto it = serverConfigMap_.find(netId);
176     if (it == serverConfigMap_.end()) {
177         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
178         return -ENOENT;
179     }
180 
181     servers = it->second.GetServers();
182 #ifdef FEATURE_NET_FIREWALL_ENABLE
183     std::vector<std::string> dns;
184     if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
185         DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
186         servers.assign(dns.begin(), dns.end());
187     }
188 #endif
189     domains = it->second.GetDomains();
190     baseTimeoutMsec = it->second.GetTimeoutMsec();
191     retryCount = it->second.GetRetryCount();
192 
193     return 0;
194 }
195 
GetResolverConfig(uint16_t netId,uint32_t uid,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)196 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
197                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
198                                          uint8_t &retryCount)
199 {
200     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
201     if (netId == 0) {
202         netId = defaultNetId_;
203         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
204     }
205 
206     {
207         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
208         for (auto mem : vpnUidRanges_) {
209             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
210                 NETNATIVE_LOG_D("is vpn hap");
211                 auto it = serverConfigMap_.find(vpnNetId_);
212                 if (it == serverConfigMap_.end()) {
213                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
214                     break;
215                 }
216                 servers = it->second.GetServers();
217 #ifdef FEATURE_NET_FIREWALL_ENABLE
218                 std::vector<std::string> dns;
219                 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
220                     DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
221                     servers.assign(dns.begin(), dns.end());
222                 }
223 #endif
224                 domains = it->second.GetDomains();
225                 baseTimeoutMsec = it->second.GetTimeoutMsec();
226                 retryCount = it->second.GetRetryCount();
227                 return 0;
228             }
229         }
230     }
231     return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
232 }
233 
GetDefaultNetwork() const234 int32_t DnsParamCache::GetDefaultNetwork() const
235 {
236     return defaultNetId_;
237 }
238 
SetDnsCache(uint16_t netId,const std::string & hostName,const AddrInfo & addrInfo)239 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
240 {
241     if (netId == 0) {
242         netId = defaultNetId_;
243     }
244     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
245 #ifdef FEATURE_NET_FIREWALL_ENABLE
246     int32_t appUid = GetCallingUid();
247     if (IsInterceptDomain(appUid, hostName)) {
248         DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
249         return;
250     }
251 #endif
252     auto it = serverConfigMap_.find(netId);
253     if (it == serverConfigMap_.end()) {
254         DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
255         return;
256     }
257 
258     it->second.GetCache().Put(hostName, addrInfo);
259 }
260 
GetDnsCache(uint16_t netId,const std::string & hostName)261 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
262 {
263     if (netId == 0) {
264         netId = defaultNetId_;
265     }
266 
267     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
268 #ifdef FEATURE_NET_FIREWALL_ENABLE
269     int32_t appUid = GetCallingUid();
270     if (IsInterceptDomain(appUid, hostName)) {
271         NotifyDomianIntercept(appUid, hostName);
272         AddrInfo fakeAddr = { 0 };
273         fakeAddr.aiFamily = AF_UNSPEC;
274         fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
275         fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
276         fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
277         return { fakeAddr };
278     }
279 #endif
280 
281     auto it = serverConfigMap_.find(netId);
282     if (it == serverConfigMap_.end()) {
283         DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
284         return {};
285     }
286 
287     return it->second.GetCache().Get(hostName);
288 }
289 
SetCacheDelayed(uint16_t netId,const std::string & hostName)290 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
291 {
292     if (netId == 0) {
293         netId = defaultNetId_;
294     }
295 
296     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
297     auto it = serverConfigMap_.find(netId);
298     if (it == serverConfigMap_.end()) {
299         DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
300         return;
301     }
302 
303     it->second.SetCacheDelayed(hostName);
304 }
305 
AddUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)306 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
307 {
308     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
309     NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
310     vpnNetId_ = netId;
311     auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
312     std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
313     return 0;
314 }
315 
DelUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)316 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
317 {
318     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
319     NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
320     vpnNetId_ = 0;
321     auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
322                                    uidRanges.end(), vpnUidRanges_.begin());
323     vpnUidRanges_.erase(end, vpnUidRanges_.end());
324     return 0;
325 }
326 
IsVpnOpen() const327 bool DnsParamCache::IsVpnOpen() const
328 {
329     return vpnUidRanges_.size();
330 }
331 
332 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)333 int32_t DnsParamCache::GetUserId(int32_t appUid)
334 {
335     int32_t userId = appUid / USER_ID_DIVIDOR;
336     return userId > 0 ? userId : currentUserId_;
337 }
338 
GetDnsServersByAppUid(int32_t appUid,std::vector<std::string> & servers)339 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
340 {
341     if (netFirewallDnsRuleMap_.empty()) {
342         return false;
343     }
344     DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
345     auto it = netFirewallDnsRuleMap_.find(appUid);
346     if (it == netFirewallDnsRuleMap_.end()) {
347         // if appUid not found, try to find invalid appUid=0;
348         it = netFirewallDnsRuleMap_.find(0);
349     }
350     if (it != netFirewallDnsRuleMap_.end()) {
351         int32_t userId = GetUserId(appUid);
352         std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
353         for (const auto &rule : rules) {
354             if (rule->userId != userId) {
355                 continue;
356             }
357             servers.emplace_back(rule->primaryDns);
358             servers.emplace_back(rule->standbyDns);
359         }
360         return true;
361     }
362     return false;
363 }
364 
SetFirewallRules(NetFirewallRuleType type,const std::vector<sptr<NetFirewallBaseRule>> & ruleList,bool isFinish)365 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
366                                         const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
367 {
368     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
369     NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
370     if (ruleList.empty()) {
371         NETNATIVE_LOGE("SetFirewallRules: rules is empty");
372         return -1;
373     }
374     int32_t ret = 0;
375     switch (type) {
376         case NetFirewallRuleType::RULE_DNS: {
377             for (const auto &rule : ruleList) {
378                 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
379             }
380             if (isFinish) {
381                 ret = SetFirewallDnsRules(firewallDnsRules_);
382                 firewallDnsRules_.clear();
383             }
384             break;
385         }
386         case NetFirewallRuleType::RULE_DOMAIN: {
387             for (const auto &rule : ruleList) {
388                 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
389             }
390             if (isFinish) {
391                 ret = SetFirewallDomainRules(firewallDomainRules_);
392                 firewallDomainRules_.clear();
393             }
394             break;
395         }
396         default:
397             break;
398     }
399     return ret;
400 }
401 
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> & ruleList)402 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
403 {
404     for (const auto &rule : ruleList) {
405         std::vector<sptr<NetFirewallDnsRule>> rules;
406         auto it = netFirewallDnsRuleMap_.find(rule->appUid);
407         if (it != netFirewallDnsRuleMap_.end()) {
408             rules = it->second;
409         }
410         rules.emplace_back(std::move(rule));
411         netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
412     }
413     return 0;
414 }
415 
GetFirewallRuleAction(int32_t appUid,const std::vector<sptr<NetFirewallDomainRule>> & rules)416 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
417                                                         const std::vector<sptr<NetFirewallDomainRule>> &rules)
418 {
419     int32_t userId = GetUserId(appUid);
420     for (const auto &rule : rules) {
421         if (rule->userId != userId) {
422             continue;
423         }
424         if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
425             return rule->ruleAction;
426         }
427     }
428 
429     return FirewallRuleAction::RULE_INVALID;
430 }
431 
checkEmpty4InterceptDomain(const std::string & hostName)432 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
433 {
434     if (hostName.empty()) {
435         return true;
436     }
437     if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
438         return false;
439     }
440     if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
441         return false;
442     }
443     return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
444 }
445 
IsInterceptDomain(int32_t appUid,const std::string & hostName)446 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName)
447 {
448     if (checkEmpty4InterceptDomain(hostName)) {
449         return (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
450     }
451     std::string host = hostName.substr(0, hostName.find(' '));
452     DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
453     std::transform(host.begin(), host.end(), host.begin(), ::tolower);
454     std::vector<sptr<NetFirewallDomainRule>> rules;
455     FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
456     auto it = netFirewallDomainRulesAllowMap_.find(host);
457     if (it != netFirewallDomainRulesAllowMap_.end()) {
458         rules = it->second;
459         exactAllowAction = GetFirewallRuleAction(appUid, rules);
460     }
461     FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
462     auto iter = netFirewallDomainRulesDenyMap_.find(host);
463     if (iter != netFirewallDomainRulesDenyMap_.end()) {
464         rules = iter->second;
465         exactDenyAction = GetFirewallRuleAction(appUid, rules);
466     }
467     FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
468     if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
469         wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
470     }
471     FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
472     if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
473         wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
474     }
475     bool allow = false;
476     bool deny = false;
477     if ((exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
478         (wildcardAllowAction != FirewallRuleAction::RULE_INVALID)) {
479         allow = true;
480     }
481     if ((exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
482         (wildcardDenyAction != FirewallRuleAction::RULE_INVALID)) {
483         deny = true;
484     }
485     if (allow && !deny) {
486         return false;
487     }
488     if (!allow && deny) {
489         return true;
490     }
491     return (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
492 }
493 
SetFirewallDefaultAction(FirewallRuleAction inDefault,FirewallRuleAction outDefault)494 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
495 {
496     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
497     DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
498     firewallDefaultAction_ = outDefault;
499     return 0;
500 }
501 
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> & rule,const std::string & domain)502 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
503 {
504     std::vector<sptr<NetFirewallDomainRule>> rules;
505     std::string suffix(domain);
506     auto wildcardCharIndex = suffix.find('*');
507     if (wildcardCharIndex != std::string::npos) {
508         suffix = suffix.substr(wildcardCharIndex + 1);
509     }
510     DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
511     std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
512     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
513         if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
514             rules.emplace_back(std::move(rule));
515             domainDenyLsmTrie_->Update(suffix, rules);
516             return;
517         }
518         rules.emplace_back(std::move(rule));
519         domainDenyLsmTrie_->Insert(suffix, rules);
520     } else {
521         if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
522             rules.emplace_back(std::move(rule));
523             domainAllowLsmTrie_->Update(suffix, rules);
524             return;
525         }
526         rules.emplace_back(std::move(rule));
527         domainAllowLsmTrie_->Insert(suffix, rules);
528     }
529 }
530 
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> & rule,const std::string & raw)531 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
532 {
533     DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
534     std::string domain(raw);
535     std::vector<sptr<NetFirewallDomainRule>> rules;
536     std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
537     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
538         auto it = netFirewallDomainRulesDenyMap_.find(domain);
539         if (it != netFirewallDomainRulesDenyMap_.end()) {
540             rules = it->second;
541         }
542 
543         rules.emplace_back(std::move(rule));
544         netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
545     } else {
546         auto it = netFirewallDomainRulesAllowMap_.find(domain);
547         if (it != netFirewallDomainRulesAllowMap_.end()) {
548             rules = it->second;
549         }
550 
551         rules.emplace_back(rule);
552         netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
553     }
554 }
555 
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> & ruleList)556 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
557 {
558     if (!domainAllowLsmTrie_) {
559         domainAllowLsmTrie_ =
560             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
561     }
562     if (!domainDenyLsmTrie_) {
563         domainDenyLsmTrie_ =
564             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
565     }
566     for (const auto &rule : ruleList) {
567         for (const auto &param : rule->domains) {
568             if (param.isWildcard) {
569                 BuildFirewallDomainLsmTrie(rule, param.domain);
570             } else {
571                 BuildFirewallDomainMap(rule, param.domain);
572             }
573         }
574     }
575     return 0;
576 }
577 
ClearFirewallRules(NetFirewallRuleType type)578 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
579 {
580     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
581     switch (type) {
582         case NetFirewallRuleType::RULE_DNS:
583             firewallDnsRules_.clear();
584             netFirewallDnsRuleMap_.clear();
585             break;
586         case NetFirewallRuleType::RULE_DOMAIN: {
587             firewallDomainRules_.clear();
588             netFirewallDomainRulesAllowMap_.clear();
589             netFirewallDomainRulesDenyMap_.clear();
590             if (domainAllowLsmTrie_) {
591                 domainAllowLsmTrie_ = nullptr;
592             }
593             if (domainDenyLsmTrie_) {
594                 domainDenyLsmTrie_ = nullptr;
595             }
596             break;
597         }
598         case NetFirewallRuleType::RULE_ALL: {
599             firewallDnsRules_.clear();
600             netFirewallDnsRuleMap_.clear();
601             firewallDomainRules_.clear();
602             netFirewallDomainRulesAllowMap_.clear();
603             netFirewallDomainRulesDenyMap_.clear();
604             if (domainAllowLsmTrie_) {
605                 domainAllowLsmTrie_ = nullptr;
606             }
607             if (domainDenyLsmTrie_) {
608                 domainDenyLsmTrie_ = nullptr;
609             }
610             break;
611         }
612         default:
613             break;
614     }
615     return 0;
616 }
617 
NotifyDomianIntercept(int32_t appUid,const std::string & hostName)618 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
619 {
620     if (hostName.empty()) {
621         return;
622     }
623     std::string host = hostName.substr(0, hostName.find(' '));
624     NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
625     sptr<NetManagerStandard::InterceptRecord> record = new (std::nothrow) NetManagerStandard::InterceptRecord();
626     record->time = (int32_t)time(NULL);
627     record->appUid = appUid;
628     record->domain = host;
629 
630     if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
631         if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
632             return;
633         }
634     }
635     oldRecord_ = record;
636     for (const auto &callback : callbacks_) {
637         callback->OnIntercept(record);
638     }
639 }
640 
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)641 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
642 {
643     if (!callback) {
644         return -1;
645     }
646 
647     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
648     callbacks_.emplace_back(callback);
649 
650     return 0;
651 }
652 
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)653 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
654 {
655     if (!callback) {
656         return -1;
657     }
658 
659     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
660     for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
661         if (*it == callback) {
662             callbacks_.erase(it);
663             return 0;
664         }
665     }
666     return -1;
667 }
668 #endif
669 
GetDumpInfo(std::string & info)670 void DnsParamCache::GetDumpInfo(std::string &info)
671 {
672     std::string dnsData;
673     static const std::string TAB = "  ";
674     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
675     std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
676         dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
677         dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
678         dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
679         dnsData.append(TAB + "Servers:");
680         GetVectorData(serverConfig.second.GetServers(), dnsData);
681         dnsData.append(TAB + "Domains:");
682         GetVectorData(serverConfig.second.GetDomains(), dnsData);
683     });
684     info.append(dnsData);
685 }
686 } // namespace OHOS::nmd
687