1 /*
2 * Copyright (C) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "dns_checker.h"

#include <arpa/inet.h>
#include <net/if_arp.h>
#include <net/if.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cerrno>
#include <chrono>
#include <cstring>
#include <string>
#include <utility>

#include "securec.h"
#include "wifi_log.h"
#include "wifi_config_center.h"
27
28 #ifdef LOG_TAG
29 #undef LOG_TAG
30 #endif
31 #define LOG_TAG "ohwifi_dns_checker"
32
33 namespace OHOS {
34 namespace Wifi {
35
// UDP port the DNS query is sent to (the standard DNS service port).
constexpr int DNS_SERVER_PORT = 53;
// QTYPE placed in the outgoing question; 1 is the DNS "A" (IPv4 host address) record type.
constexpr int DNS_ADDRESS_TYPE = 1;
38
/*
 * DNS message header (RFC 1035 section 4.1.1). It is written into / read out of the packet buffer
 * directly, so its in-memory layout must equal the 12-byte wire header.
 * Multi-byte fields hold network byte order values (callers use htons/ntohs).
 * NOTE(review): the flag bit-fields are declared low-bit first within each byte. That matches the wire
 * order only on compilers that allocate bit-fields from the least significant bit (GCC/Clang on
 * little-endian targets). Confirm for every supported toolchain.
 */
struct DNS_HEADER {
    unsigned short id;        // query identifier, echoed back in the response
    unsigned char rd : 1;     // recursion desired
    unsigned char tc : 1;     // truncated message
    unsigned char aa : 1;     // authoritative answer
    unsigned char opCode : 4; // kind of query
    unsigned char qr : 1;     // 0 = query, 1 = response

    unsigned char rCode : 4;  // response code
    unsigned char cd : 1;     // checking disabled
    unsigned char ad : 1;     // authenticated data
    unsigned char z : 1;      // reserved
    unsigned char ra : 1;     // recursion available

    unsigned short qCount;    // number of question entries
    unsigned short ansCount;  // number of answer records
    unsigned short authCount; // number of authority records
    unsigned short addCount;  // number of additional records
};
// sizeof(DNS_HEADER) is used as the on-wire header length, which is fixed at 12 bytes.
static_assert(sizeof(DNS_HEADER) == 12, "DNS_HEADER must match the 12-byte DNS wire header");
58
/*
 * Fixed trailer of a DNS question entry (RFC 1035 section 4.1.2) that follows the encoded QNAME.
 * Both fields hold network byte order values.
 */
struct QUESTION {
    unsigned short qtype;  // record type being queried (e.g. 1 = A)
    unsigned short qclass; // query class (1 = IN)
};
// sizeof(QUESTION) is used as the on-wire QTYPE + QCLASS length.
static_assert(sizeof(QUESTION) == 4, "QUESTION must match the 4-byte QTYPE/QCLASS wire layout");
63
DnsChecker()64 DnsChecker::DnsChecker() : dnsSocket(0), socketCreated(false), isRunning(true)
65 {}
66
~DnsChecker()67 DnsChecker::~DnsChecker()
68 {
69 Stop();
70 }
71
Start(std::string priDns,std::string secondDns)72 void DnsChecker::Start(std::string priDns, std::string secondDns)
73 {
74 if (socketCreated) {
75 Stop();
76 }
77 isRunning = true;
78 dnsSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
79 if (dnsSocket < 0) {
80 LOGE("DnsChecker:socket error : %{public}d", errno);
81 dnsSocket = 0;
82 return;
83 }
84 std::string ifaceName = WifiConfigCenter::GetInstance().GetStaIfaceName();
85 struct ifreq ifaceReq;
86 if (strncpy_s(ifaceReq.ifr_name, sizeof(ifaceReq.ifr_name), ifaceName.c_str(), ifaceName.size()) != EOK) {
87 LOGE("DnsChecker copy ifaceReq failed.");
88 close(dnsSocket);
89 dnsSocket = 0;
90 return;
91 }
92 if (setsockopt(dnsSocket, SOL_SOCKET, SO_BINDTODEVICE, reinterpret_cast<char *>(&ifaceReq),
93 sizeof(ifaceReq)) == -1) {
94 LOGE("DnsChecker start SO_BINDTODEVICE error:%{public}d.", errno);
95 close(dnsSocket);
96 dnsSocket = 0;
97 return;
98 }
99 socketCreated = true;
100 primaryDnsAddress = priDns;
101 secondDnsAddress = secondDns;
102 }
103
Stop()104 void DnsChecker::Stop()
105 {
106 if (!socketCreated) {
107 return;
108 }
109 close(dnsSocket);
110 dnsSocket = 0;
111 socketCreated = false;
112 }
113
formatHostAdress(char * hostAddress,const char * host)114 void DnsChecker::formatHostAdress(char* hostAddress, const char* host)
115 {
116 if (!hostAddress || !host) {
117 return;
118 }
119 int lock = 0;
120 int len = strlen(host);
121 for (int i = 0; i < len; i++) {
122 if (host[i] == '.') {
123 *hostAddress++ = i - lock;
124 for (; lock < i; lock++) {
125 *hostAddress++ = host[lock];
126 }
127 lock++;
128 }
129 }
130 *hostAddress++ = '\0';
131 }
StopDnsCheck()132 void DnsChecker::StopDnsCheck()
133 {
134 isRunning = false;
135 }
136
DoDnsCheck(std::string url,int timeoutMillis)137 bool DnsChecker::DoDnsCheck(std::string url, int timeoutMillis)
138 {
139 LOGI("DoDnsCheck Enter.");
140 int len1 = static_cast<int>(url.find("/generate_204"));
141 int len = len1 - strlen("http://");
142 std::string host = url.substr(strlen("http://"), len);
143 host = host + ".";
144 LOGD("DoDnsCheck url=%{public}s", host.c_str());
145 if (!isRunning) {
146 return false;
147 }
148 bool dnsValid = checkDnsValid(host, primaryDnsAddress, timeoutMillis) ||
149 checkDnsValid(host, secondDnsAddress, timeoutMillis);
150 if (!dnsValid) {
151 LOGE("all dns can not work.");
152 }
153 return dnsValid;
154 }
155
recvDnsData(char * buff,int size,int timeout)156 int DnsChecker::recvDnsData(char* buff, int size, int timeout)
157 {
158 if (dnsSocket < 0) {
159 LOGE("invalid socket fd");
160 return 0;
161 }
162
163 pollfd fds[1];
164 fds[0].fd = dnsSocket;
165 fds[0].events = POLLIN;
166 if (poll(fds, 1, timeout) <= 0) {
167 return 0;
168 }
169
170 int nBytes;
171 do {
172 nBytes = read(dnsSocket, buff, size);
173 if (nBytes < 0) {
174 LOGE("recvfrom failed %{public}d", errno);
175 return false;
176 }
177 } while (nBytes == -1 && isRunning);
178
179 return nBytes < 0 ? 0 : nBytes;
180 }
181
checkDnsValid(std::string host,std::string dnsAddress,int timeoutMillis)182 bool DnsChecker::checkDnsValid(std::string host, std::string dnsAddress, int timeoutMillis)
183 {
184 if (!socketCreated && !isRunning) {
185 LOGE("DnsChecker checkDnsValid failed, socket not created");
186 return false;
187 }
188 if (dnsAddress.empty()) {
189 LOGE("DnsChecker dnsAddress is empty!");
190 return false;
191 }
192 struct sockaddr_in dest;
193 dest.sin_family = AF_INET;
194 dest.sin_port = htons(DNS_SERVER_PORT);
195 dest.sin_addr.s_addr = inet_addr(dnsAddress.c_str());
196 char buff[2048] = {0};
197 struct DNS_HEADER *dns = (struct DNS_HEADER*)&buff;
198 dns->id = (unsigned short)htons(getpid());
199 dns->rd = 1;
200 dns->qCount = htons(1);
201 char* hostAddress = static_cast<char*>(&buff[sizeof(struct DNS_HEADER)]);
202 formatHostAdress(hostAddress, host.c_str());
203 struct QUESTION *qinfo = (struct QUESTION *)&buff[sizeof(struct DNS_HEADER) +
204 (strlen(static_cast<const char*>(hostAddress)) + 1)];
205 qinfo->qtype = htons(DNS_ADDRESS_TYPE);
206 qinfo->qclass = htons(1);
207 int len = static_cast<int>(sizeof(struct DNS_HEADER) +
208 (strlen(static_cast<const char*>(hostAddress) + 1) + sizeof(QUESTION)));
209 if (sendto(dnsSocket, buff, len, 0, (struct sockaddr*)&dest, sizeof(dest)) < 0) {
210 LOGE("send dns data failed.");
211 return false;
212 }
213 int64_t elapsed = 0;
214 int leftMillis = timeoutMillis;
215 std::chrono::steady_clock::time_point startTime = std::chrono::steady_clock::now();
216 while (leftMillis > 0 && isRunning) {
217 int readLen = recvDnsData(buff, sizeof(buff), leftMillis);
218 if (readLen >= static_cast<int>(sizeof(struct DNS_HEADER))) {
219 dns = reinterpret_cast<struct DNS_HEADER*>(buff);
220 LOGI("dns recv ansCount:%{public}d", dns->ansCount);
221 return dns->ansCount > 0;
222 }
223 std::chrono::steady_clock::time_point current = std::chrono::steady_clock::now();
224 elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(current - startTime).count();
225 leftMillis -= static_cast<int>(elapsed);
226 }
227 return false;
228 }
229 }
230 }
231