1 /*
2  * Copyright (C) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "dns_checker.h"
17 #include <arpa/inet.h>
18 #include <chrono>
19 #include <poll.h>
20 #include <net/if_arp.h>
21 #include <net/if.h>
22 #include <unistd.h>
23 
24 #include "securec.h"
25 #include "wifi_log.h"
26 #include "wifi_config_center.h"
27 
28 #ifdef LOG_TAG
29 #undef LOG_TAG
30 #endif
31 #define LOG_TAG "ohwifi_dns_checker"
32 
33 namespace OHOS {
34 namespace Wifi {
35 
/* UDP port the DNS servers are queried on (the well-known DNS port). */
constexpr int DNS_SERVER_PORT = 53;
/* QTYPE written into the query question; 1 is the "A" (IPv4 address) record type. */
constexpr int DNS_ADDRESS_TYPE = 1;
38 
/*
 * DNS message header (RFC 1035 section 4.1.1), overlaid directly on the start of
 * the query/response buffer in checkDnsValid(). Intended to be the 12-byte wire header.
 *
 * The 16-bit id and counter fields are stored in network byte order (the code
 * fills them with htons()).
 * NOTE(review): the flag bit-fields are listed least-significant bit first. That only
 * matches the wire layout on compilers/ABIs that allocate bit-fields from the low bit
 * upward (typical little-endian GCC/Clang targets) — confirm for every supported target.
 */
struct DNS_HEADER {
    unsigned short id;          // query identifier
    unsigned char rd : 1;       // recursion desired
    unsigned char tc : 1;       // truncated message
    unsigned char aa : 1;       // authoritative answer
    unsigned char opCode : 4;   // kind of query (0 = standard query)
    unsigned char qr : 1;       // 0 = query, 1 = response

    unsigned char rCode : 4;    // response code
    unsigned char cd : 1;       // checking disabled
    unsigned char ad : 1;       // authenticated data
    unsigned char z : 1;        // reserved
    unsigned char ra : 1;       // recursion available

    unsigned short qCount;      // number of question entries
    unsigned short ansCount;    // number of answer records
    unsigned short authCount;   // number of authority records
    unsigned short addCount;    // number of additional records
};
58 
/*
 * Fixed-size tail of a DNS question entry (RFC 1035 section 4.1.2). In the query
 * built by checkDnsValid() it directly follows the encoded QNAME, and both fields
 * are written there with htons().
 */
struct QUESTION {
    unsigned short qtype;   // requested record type (DNS_ADDRESS_TYPE)
    unsigned short qclass;  // query class (1 = IN)
};
63 
DnsChecker()64 DnsChecker::DnsChecker() : dnsSocket(0), socketCreated(false), isRunning(true)
65 {}
66 
~DnsChecker()67 DnsChecker::~DnsChecker()
68 {
69     Stop();
70 }
71 
Start(std::string priDns,std::string secondDns)72 void DnsChecker::Start(std::string priDns, std::string secondDns)
73 {
74     if (socketCreated) {
75         Stop();
76     }
77     isRunning = true;
78     dnsSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
79     if (dnsSocket < 0) {
80         LOGE("DnsChecker:socket error : %{public}d", errno);
81         dnsSocket = 0;
82         return;
83     }
84     std::string ifaceName = WifiConfigCenter::GetInstance().GetStaIfaceName();
85     struct ifreq ifaceReq;
86     if (strncpy_s(ifaceReq.ifr_name, sizeof(ifaceReq.ifr_name), ifaceName.c_str(), ifaceName.size()) != EOK) {
87         LOGE("DnsChecker copy ifaceReq failed.");
88         close(dnsSocket);
89         dnsSocket = 0;
90         return;
91     }
92     if (setsockopt(dnsSocket, SOL_SOCKET, SO_BINDTODEVICE, reinterpret_cast<char *>(&ifaceReq),
93         sizeof(ifaceReq)) == -1) {
94         LOGE("DnsChecker start SO_BINDTODEVICE error:%{public}d.", errno);
95         close(dnsSocket);
96         dnsSocket = 0;
97         return;
98     }
99     socketCreated = true;
100     primaryDnsAddress = priDns;
101     secondDnsAddress = secondDns;
102 }
103 
Stop()104 void DnsChecker::Stop()
105 {
106     if (!socketCreated) {
107         return;
108     }
109     close(dnsSocket);
110     dnsSocket = 0;
111     socketCreated = false;
112 }
113 
formatHostAdress(char * hostAddress,const char * host)114 void DnsChecker::formatHostAdress(char* hostAddress, const char* host)
115 {
116     if (!hostAddress || !host) {
117         return;
118     }
119     int lock = 0;
120     int len = strlen(host);
121     for (int i = 0; i < len; i++) {
122         if (host[i] == '.') {
123             *hostAddress++ = i - lock;
124             for (; lock < i; lock++) {
125                 *hostAddress++ = host[lock];
126             }
127             lock++;
128         }
129     }
130     *hostAddress++ = '\0';
131 }
StopDnsCheck()132 void DnsChecker::StopDnsCheck()
133 {
134     isRunning = false;
135 }
136 
DoDnsCheck(std::string url,int timeoutMillis)137 bool DnsChecker::DoDnsCheck(std::string url, int timeoutMillis)
138 {
139     LOGI("DoDnsCheck Enter.");
140     int len1 = static_cast<int>(url.find("/generate_204"));
141     int len =  len1 - strlen("http://");
142     std::string host = url.substr(strlen("http://"), len);
143     host = host + ".";
144     LOGD("DoDnsCheck url=%{public}s", host.c_str());
145     if (!isRunning) {
146         return false;
147     }
148     bool dnsValid = checkDnsValid(host, primaryDnsAddress, timeoutMillis) ||
149         checkDnsValid(host, secondDnsAddress, timeoutMillis);
150     if (!dnsValid) {
151         LOGE("all dns can not work.");
152     }
153     return dnsValid;
154 }
155 
recvDnsData(char * buff,int size,int timeout)156 int DnsChecker::recvDnsData(char* buff, int size, int timeout)
157 {
158     if (dnsSocket < 0) {
159         LOGE("invalid socket fd");
160         return 0;
161     }
162 
163     pollfd fds[1];
164     fds[0].fd = dnsSocket;
165     fds[0].events = POLLIN;
166     if (poll(fds, 1, timeout) <= 0) {
167         return 0;
168     }
169 
170     int nBytes;
171     do {
172         nBytes = read(dnsSocket, buff, size);
173         if (nBytes < 0) {
174             LOGE("recvfrom failed %{public}d", errno);
175             return false;
176         }
177     } while (nBytes == -1 && isRunning);
178 
179     return nBytes < 0 ? 0 : nBytes;
180 }
181 
checkDnsValid(std::string host,std::string dnsAddress,int timeoutMillis)182 bool DnsChecker::checkDnsValid(std::string host, std::string dnsAddress, int timeoutMillis)
183 {
184     if (!socketCreated && !isRunning) {
185         LOGE("DnsChecker checkDnsValid failed, socket not created");
186         return false;
187     }
188     if (dnsAddress.empty()) {
189         LOGE("DnsChecker dnsAddress is empty!");
190         return false;
191     }
192     struct sockaddr_in dest;
193     dest.sin_family = AF_INET;
194     dest.sin_port = htons(DNS_SERVER_PORT);
195     dest.sin_addr.s_addr = inet_addr(dnsAddress.c_str());
196     char buff[2048] = {0};
197     struct DNS_HEADER *dns = (struct DNS_HEADER*)&buff;
198     dns->id = (unsigned short)htons(getpid());
199     dns->rd = 1;
200     dns->qCount = htons(1);
201     char* hostAddress = static_cast<char*>(&buff[sizeof(struct DNS_HEADER)]);
202     formatHostAdress(hostAddress, host.c_str());
203     struct QUESTION *qinfo = (struct QUESTION *)&buff[sizeof(struct DNS_HEADER) +
204         (strlen(static_cast<const char*>(hostAddress)) + 1)];
205     qinfo->qtype = htons(DNS_ADDRESS_TYPE);
206     qinfo->qclass = htons(1);
207     int len = static_cast<int>(sizeof(struct DNS_HEADER) +
208         (strlen(static_cast<const char*>(hostAddress) + 1) + sizeof(QUESTION)));
209     if (sendto(dnsSocket, buff, len, 0, (struct sockaddr*)&dest, sizeof(dest)) < 0) {
210         LOGE("send dns data failed.");
211         return false;
212     }
213     int64_t elapsed = 0;
214     int leftMillis = timeoutMillis;
215     std::chrono::steady_clock::time_point startTime = std::chrono::steady_clock::now();
216     while (leftMillis > 0 && isRunning) {
217         int readLen = recvDnsData(buff, sizeof(buff), leftMillis);
218         if (readLen >= static_cast<int>(sizeof(struct DNS_HEADER))) {
219             dns = reinterpret_cast<struct DNS_HEADER*>(buff);
220             LOGI("dns recv ansCount:%{public}d", dns->ansCount);
221             return dns->ansCount > 0;
222         }
223         std::chrono::steady_clock::time_point current = std::chrono::steady_clock::now();
224         elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(current - startTime).count();
225         leftMillis -= static_cast<int>(elapsed);
226     }
227     return false;
228 }
229 }
230 }
231