1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <netinet/ip.h>
17 #include <netinet/udp.h>
18 #include <thread>
19 #include <pthread.h>
20 #include <unistd.h>
21 #include <sys/epoll.h>
22 
23 #include "dns_config_client.h"
24 #include "dns_param_cache.h"
25 #include "netnative_log_wrapper.h"
26 #include "netsys_udp_transfer.h"
27 #include "singleton.h"
28 #include "ffrt.h"
29 
30 #include "dns_proxy_listen.h"
31 
32 namespace OHOS {
33 namespace nmd {
34 uint16_t DnsProxyListen::netId_ = 0;
35 std::atomic_bool DnsProxyListen::proxyListenSwitch_ = false;
36 std::mutex DnsProxyListen::listenerMutex_;
37 constexpr uint16_t DNS_PROXY_PORT = 53;
38 constexpr uint8_t RESPONSE_FLAG = 0x80;
39 constexpr uint8_t RESPONSE_FLAG_USED = 80;
40 constexpr size_t FLAG_BUFF_LEN = 1;
41 constexpr size_t FLAG_BUFF_OFFSET = 2;
42 constexpr size_t DNS_HEAD_LENGTH = 12;
43 constexpr int32_t EPOLL_TASK_NUMBER = 10;
DnsProxyListen()44 DnsProxyListen::DnsProxyListen() : proxySockFd_(-1), proxySockFd6_(-1) {}
~DnsProxyListen()45 DnsProxyListen::~DnsProxyListen()
46 {
47     if (proxySockFd_ > 0) {
48         close(proxySockFd_);
49         proxySockFd_ = -1;
50     }
51     if (proxySockFd6_ > 0) {
52         close(proxySockFd6_);
53         proxySockFd6_ = -1;
54     }
55     if (epollFd_ > 0) {
56         close(epollFd_);
57         epollFd_ = -1;
58     }
59     serverIdxOfSocket.clear();
60 }
61 
DnsParseBySocket(std::unique_ptr<RecvBuff> & recvBuff,std::unique_ptr<AlignedSockAddr> & clientSock)62 void DnsProxyListen::DnsParseBySocket(std::unique_ptr<RecvBuff> &recvBuff, std::unique_ptr<AlignedSockAddr> &clientSock)
63 {
64     int32_t socketFd = -1;
65     if (clientSock->sa.sa_family == AF_INET) {
66         socketFd = socket(AF_INET, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, IPPROTO_UDP);
67     } else if (clientSock->sa.sa_family == AF_INET6) {
68         socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, IPPROTO_UDP);
69     }
70     if (socketFd < 0) {
71         NETNATIVE_LOGE("socketFd create socket failed %{public}d", errno);
72         return;
73     }
74     if (!PollUdpDataTransfer::MakeUdpNonBlock(socketFd)) {
75         NETNATIVE_LOGE("MakeNonBlock error %{public}d: %{public}s", errno, strerror(errno));
76         close(socketFd);
77         return;
78     }
79     serverIdxOfSocket.emplace(std::piecewise_construct, std::forward_as_tuple(socketFd),
80                               std::forward_as_tuple(socketFd, std::move(clientSock), std::move(recvBuff)));
81     SendRequest2Server(socketFd);
82 }
83 
GetDnsProxyServers(std::vector<std::string> & servers,size_t serverIdx)84 bool DnsProxyListen::GetDnsProxyServers(std::vector<std::string> &servers, size_t serverIdx)
85 {
86     std::vector<std::string> domains;
87     uint16_t baseTimeoutMsec;
88     uint8_t retryCount;
89     DnsParamCache::GetInstance().GetResolverConfig(DnsProxyListen::netId_, servers, domains, baseTimeoutMsec,
90                                                    retryCount);
91     if (serverIdx >= servers.size()) {
92         NETNATIVE_LOGE("no server useful");
93         return false;
94     }
95     return true;
96 }
97 
MakeAddrInfo(std::vector<std::string> & servers,size_t serverIdx,AlignedSockAddr & addrParse,AlignedSockAddr & clientSock)98 bool DnsProxyListen::MakeAddrInfo(std::vector<std::string> &servers, size_t serverIdx, AlignedSockAddr &addrParse,
99                                   AlignedSockAddr &clientSock)
100 {
101     if (clientSock.sa.sa_family == AF_INET) {
102         if (servers[serverIdx].find(".") == std::string::npos) {
103             return false;
104         }
105         addrParse.sin.sin_family = AF_INET;
106         addrParse.sin.sin_port = htons(DNS_PROXY_PORT);
107         addrParse.sin.sin_addr.s_addr = inet_addr(servers[serverIdx].c_str());
108         if (addrParse.sin.sin_addr.s_addr == INADDR_NONE) {
109             NETNATIVE_LOGE("Input ipv4 dns server %{private}s is not correct!", servers[serverIdx].c_str());
110             return false;
111         }
112     } else if (clientSock.sa.sa_family == AF_INET6) {
113         if (servers[serverIdx].find(":") == std::string::npos) {
114             return false;
115         }
116         addrParse.sin6.sin6_family = AF_INET6;
117         addrParse.sin6.sin6_port = htons(DNS_PROXY_PORT);
118         inet_pton(AF_INET6, servers[serverIdx].c_str(), &(addrParse.sin6.sin6_addr));
119         if (IN6_IS_ADDR_UNSPECIFIED(&addrParse.sin6.sin6_addr)) {
120             NETNATIVE_LOGE("Input ipv6 dns server %{private}s is not correct!", servers[serverIdx].c_str());
121             return false;
122         }
123     } else {
124         NETNATIVE_LOGE("current clientSock type is error!");
125         return false;
126     }
127     return true;
128 }
129 
SendRequest2Server(int32_t socketFd)130 void DnsProxyListen::SendRequest2Server(int32_t socketFd)
131 {
132     auto iter = serverIdxOfSocket.find(socketFd);
133     if (iter == serverIdxOfSocket.end()) {
134         NETNATIVE_LOGE("no idx found");
135         return;
136     }
137     auto serverIdx = iter->second.GetIdx();
138     std::vector<std::string> servers;
139     if (!GetDnsProxyServers(servers, serverIdx)) {
140         serverIdxOfSocket.erase(iter);
141         return;
142     }
143     iter->second.IncreaseIdx();
144     epoll_ctl(epollFd_, EPOLL_CTL_DEL, socketFd, nullptr);
145     socklen_t addrLen;
146     AlignedSockAddr &addrParse = iter->second.GetAddr();
147     AlignedSockAddr &clientSock = iter->second.GetClientSock();
148     if (!MakeAddrInfo(servers, serverIdx, addrParse, clientSock)) {
149         return SendRequest2Server(socketFd);
150     }
151     if (PollUdpDataTransfer::PollUdpSendData(socketFd, iter->second.GetRecvBuff().questionsBuff,
152                                              iter->second.GetRecvBuff().questionLen, addrParse, addrLen) < 0) {
153         NETNATIVE_LOGE("send failed %{public}d: %{public}s", errno, strerror(errno));
154         return SendRequest2Server(socketFd);
155     }
156     iter->second.endTime = std::chrono::system_clock::now() + std::chrono::milliseconds(EPOLL_TIMEOUT);
157     if (epoll_ctl(epollFd_, EPOLL_CTL_ADD, socketFd, iter->second.GetEventPtr()) < 0) {
158         NETNATIVE_LOGE("epoll add sock %{public}d failed, errno: %{public}d", socketFd, errno);
159         serverIdxOfSocket.erase(iter);
160     }
161 }
162 
SendDnsBack2Client(int32_t socketFd)163 void DnsProxyListen::SendDnsBack2Client(int32_t socketFd)
164 {
165     NETNATIVE_LOG_D("epoll send back to client.");
166     auto iter = serverIdxOfSocket.find(socketFd);
167     if (iter == serverIdxOfSocket.end()) {
168         NETNATIVE_LOGE("no idx found");
169         return;
170     }
171     AlignedSockAddr &addrParse = iter->second.GetAddr();
172     AlignedSockAddr &clientSock = iter->second.GetClientSock();
173     int32_t proxySocket = proxySockFd_;
174     socklen_t addrLen = 0;
175     if (clientSock.sa.sa_family == AF_INET) {
176         proxySocket = proxySockFd_;
177         addrLen = sizeof(sockaddr_in);
178     } else {
179         proxySocket = proxySockFd6_;
180         addrLen = sizeof(sockaddr_in6);
181     }
182     char requesData[MAX_REQUESTDATA_LEN] = {0};
183     int32_t resLen =
184         PollUdpDataTransfer::PollUdpRecvData(socketFd, requesData, MAX_REQUESTDATA_LEN, addrParse, addrLen);
185     if (resLen > 0 && CheckDnsResponse(requesData, MAX_REQUESTDATA_LEN)) {
186         NETNATIVE_LOG_D("send %{public}d back to client.", socketFd);
187         DnsSendRecvParseData(proxySocket, requesData, resLen, iter->second.GetClientSock());
188         serverIdxOfSocket.erase(iter);
189         return;
190     }
191     NETNATIVE_LOGE("response not correct, retry for next server.");
192     SendRequest2Server(socketFd);
193 }
194 
DnsSendRecvParseData(int32_t clientSocket,char * requesData,int32_t resLen,AlignedSockAddr & proxyAddr)195 void DnsProxyListen::DnsSendRecvParseData(int32_t clientSocket, char *requesData, int32_t resLen,
196                                           AlignedSockAddr &proxyAddr)
197 {
198     socklen_t addrLen = 0;
199     if (proxyAddr.sa.sa_family == AF_INET) {
200         addrLen = sizeof(sockaddr_in);
201     } else {
202         addrLen = sizeof(sockaddr_in6);
203     }
204     if (PollUdpDataTransfer::PollUdpSendData(clientSocket, requesData, resLen, proxyAddr, addrLen) < 0) {
205         NETNATIVE_LOGE("send failed %{public}d: %{public}s", errno, strerror(errno));
206     }
207 }
208 
StartListen()209 void DnsProxyListen::StartListen()
210 {
211     NETNATIVE_LOGI("StartListen proxySockFd_ : %{public}d, proxySockFd6_ : %{public}d", proxySockFd_, proxySockFd6_);
212     epoll_event proxyEvent;
213     epoll_event proxy6Event;
214     if (!InitForListening(proxyEvent, proxy6Event)) {
215         return;
216     }
217     epoll_event eventsReceived[EPOLL_TASK_NUMBER];
218     while (true) {
219         int32_t nfds =
220             epoll_wait(epollFd_, eventsReceived, EPOLL_TASK_NUMBER, serverIdxOfSocket.empty() ? -1 : EPOLL_TIMEOUT);
221         NETNATIVE_LOG_D("now socket num: %{public}zu", serverIdxOfSocket.size());
222         if (nfds < 0) {
223             NETNATIVE_LOGE("epoll errno: %{public}d", errno);
224             continue; // now ignore all errno.
225         }
226         if (nfds == 0) {
227             // dns timeout
228             EpollTimeout();
229             continue;
230         }
231         for (int i = 0; i < nfds; ++i) {
232             if (eventsReceived[i].data.fd == proxySockFd_ || eventsReceived[i].data.fd == proxySockFd6_) {
233                 int32_t family = (eventsReceived[i].data.fd == proxySockFd_) ? AF_INET : AF_INET6;
234                 GetRequestAndTransmit(family);
235             } else {
236                 SendDnsBack2Client(eventsReceived[i].data.fd);
237             }
238         }
239         CollectSocks();
240     }
241 }
GetRequestAndTransmit(int32_t family)242 void DnsProxyListen::GetRequestAndTransmit(int32_t family)
243 {
244     NETNATIVE_LOG_D("epoll got request from client.");
245     auto recvBuff = std::make_unique<RecvBuff>();
246     if (recvBuff == nullptr) {
247         NETNATIVE_LOGE("recvBuff mem failed");
248         return;
249     }
250     (void)memset_s(recvBuff->questionsBuff, MAX_REQUESTDATA_LEN, 0, MAX_REQUESTDATA_LEN);
251 
252     auto clientAddr = std::make_unique<AlignedSockAddr>();
253     if (clientAddr == nullptr) {
254         NETNATIVE_LOGE("clientAddr mem failed");
255         return;
256     }
257 
258     if (family == AF_INET) {
259         socklen_t len = sizeof(sockaddr_in);
260         recvBuff->questionLen = recvfrom(proxySockFd_, recvBuff->questionsBuff, MAX_REQUESTDATA_LEN, 0,
261                                          reinterpret_cast<sockaddr *>(&(clientAddr->sin)), &len);
262     } else {
263         socklen_t len = sizeof(sockaddr_in6);
264         recvBuff->questionLen = recvfrom(proxySockFd6_, recvBuff->questionsBuff, MAX_REQUESTDATA_LEN, 0,
265                                          reinterpret_cast<sockaddr *>(&(clientAddr->sin6)), &len);
266     }
267     if (recvBuff->questionLen <= 0) {
268         NETNATIVE_LOGE("read errno %{public}d", errno);
269         return;
270     }
271     if (!CheckDnsQuestion(recvBuff->questionsBuff, MAX_REQUESTDATA_LEN)) {
272         NETNATIVE_LOGE("read buff is not dns question");
273         return;
274     }
275     DnsParseBySocket(recvBuff, clientAddr);
276 }
277 
InitListenForIpv4()278 void DnsProxyListen::InitListenForIpv4()
279 {
280     std::lock_guard<std::mutex> lock(listenerMutex_);
281     if (proxySockFd_ < 0) {
282         proxySockFd_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
283         if (proxySockFd_ < 0) {
284             NETNATIVE_LOGE("proxySockFd_ create socket failed %{public}d", errno);
285             return;
286         }
287     }
288     sockaddr_in proxyAddr{};
289     proxyAddr.sin_family = AF_INET;
290     proxyAddr.sin_addr.s_addr = htonl(INADDR_ANY);
291     proxyAddr.sin_port = htons(DNS_PROXY_PORT);
292     if (bind(proxySockFd_, (sockaddr *)&proxyAddr, sizeof(proxyAddr)) == -1) {
293         NETNATIVE_LOGE("bind errno %{public}d: %{public}s", errno, strerror(errno));
294         close(proxySockFd_);
295         proxySockFd_ = -1;
296         return;
297     }
298 }
299 
InitListenForIpv6()300 void DnsProxyListen::InitListenForIpv6()
301 {
302     std::lock_guard<std::mutex> lock(listenerMutex_);
303     if (proxySockFd6_ < 0) {
304         proxySockFd6_ = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
305         if (proxySockFd6_ < 0) {
306             NETNATIVE_LOGE("proxySockFd_ create socket failed %{public}d", errno);
307             return;
308         }
309     }
310     sockaddr_in6 proxyAddr6{};
311     proxyAddr6.sin6_family = AF_INET6;
312     proxyAddr6.sin6_addr = in6addr_any;
313     proxyAddr6.sin6_port = htons(DNS_PROXY_PORT);
314     int on = 1;
315     if (setsockopt(proxySockFd6_, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
316         NETNATIVE_LOGE("setsockopt failed");
317         close(proxySockFd6_);
318         proxySockFd6_ = -1;
319         return;
320     }
321     if (bind(proxySockFd6_, (sockaddr *)&proxyAddr6, sizeof(proxyAddr6)) == -1) {
322         NETNATIVE_LOGE("bind6 errno %{public}d: %{public}s", errno, strerror(errno));
323         close(proxySockFd6_);
324         proxySockFd6_ = -1;
325         return;
326     }
327 }
328 
InitForListening(epoll_event & proxyEvent,epoll_event & proxy6Event)329 bool DnsProxyListen::InitForListening(epoll_event &proxyEvent, epoll_event &proxy6Event)
330 {
331     InitListenForIpv4();
332     InitListenForIpv6();
333     epollFd_ = epoll_create1(0);
334     if (epollFd_ < 0) {
335         NETNATIVE_LOGE("epoll_create1 errno %{public}d: %{public}s", errno, strerror(errno));
336         clearResource();
337         return false;
338     }
339     if (proxySockFd_ > 0) {
340         proxyEvent.data.fd = proxySockFd_;
341         proxyEvent.events = EPOLLIN;
342         if (epoll_ctl(epollFd_, EPOLL_CTL_ADD, proxySockFd_, &proxyEvent) < 0) {
343             NETNATIVE_LOGE("EPOLL_CTL_ADD proxy errno %{public}d: %{public}s", errno, strerror(errno));
344             clearResource();
345             return false;
346         }
347     }
348     if (proxySockFd6_ > 0) {
349         proxy6Event.data.fd = proxySockFd6_;
350         proxy6Event.events = EPOLLIN;
351         if (epoll_ctl(epollFd_, EPOLL_CTL_ADD, proxySockFd6_, &proxy6Event) < 0) {
352             NETNATIVE_LOGE("EPOLL_CTL_ADD proxy6 errno %{public}d: %{public}s", errno, strerror(errno));
353             clearResource();
354             return false;
355         }
356     }
357     if (proxySockFd_ < 0 && proxySockFd6_ < 0) {
358         NETNATIVE_LOGE("InitForListening ipv4/ipv6 error!");
359         clearResource();
360         return false;
361     }
362     collectTime = std::chrono::system_clock::now() + std::chrono::milliseconds(EPOLL_TIMEOUT);
363     return true;
364 }
365 
CollectSocks()366 void DnsProxyListen::CollectSocks()
367 {
368     if (std::chrono::system_clock::now() >= collectTime) {
369         NETNATIVE_LOG_D("collect socks");
370         std::list<int32_t> sockTemp;
371         for (const auto &[sock, request] : serverIdxOfSocket) {
372             if (std::chrono::system_clock::now() >= request.endTime) {
373                 sockTemp.push_back(sock);
374             }
375         }
376         for (const auto sock : sockTemp) {
377             SendRequest2Server(sock);
378         }
379         collectTime = std::chrono::system_clock::now() + std::chrono::milliseconds(EPOLL_TIMEOUT);
380     }
381 }
382 
EpollTimeout()383 void DnsProxyListen::EpollTimeout()
384 {
385     NETNATIVE_LOGE("epoll timeout, try next server.");
386     if (serverIdxOfSocket.size() > 0) {
387         std::list<int32_t> sockTemp;
388         std::transform(serverIdxOfSocket.cbegin(), serverIdxOfSocket.cend(), std::back_inserter(sockTemp),
389                        [](auto &iter) { return iter.first; });
390         for (const auto sock : sockTemp) {
391             SendRequest2Server(sock);
392         }
393     }
394     collectTime = std::chrono::system_clock::now() + std::chrono::milliseconds(EPOLL_TIMEOUT);
395 }
396 
CheckDnsQuestion(char * recBuff,size_t recLen)397 bool DnsProxyListen::CheckDnsQuestion(char *recBuff, size_t recLen)
398 {
399     if (recLen < DNS_HEAD_LENGTH) {
400         return false;
401     }
402     uint8_t flagBuff;
403     char *recFlagBuff = recBuff + FLAG_BUFF_OFFSET;
404     if (memcpy_s(reinterpret_cast<char *>(&flagBuff), FLAG_BUFF_LEN, recFlagBuff, FLAG_BUFF_LEN) != 0) {
405         return false;
406     }
407     int reqFlag = (flagBuff & RESPONSE_FLAG) / RESPONSE_FLAG_USED;
408     if (reqFlag) {
409         return false; // answer
410     } else {
411         return true; // question
412     }
413 }
414 
CheckDnsResponse(char * recBuff,size_t recLen)415 bool DnsProxyListen::CheckDnsResponse(char *recBuff, size_t recLen)
416 {
417     if (recLen < FLAG_BUFF_LEN + FLAG_BUFF_OFFSET) {
418         return false;
419     }
420     uint8_t flagBuff;
421     char *recFlagBuff = recBuff + FLAG_BUFF_OFFSET;
422     if (memcpy_s(reinterpret_cast<char *>(&flagBuff), FLAG_BUFF_LEN, recFlagBuff, FLAG_BUFF_LEN) != 0) {
423         return false;
424     }
425     int reqFlag = (flagBuff & RESPONSE_FLAG) / RESPONSE_FLAG_USED;
426     if (reqFlag) {
427         return true; // answer
428     } else {
429         return false; // question
430     }
431 }
432 
OnListen()433 void DnsProxyListen::OnListen()
434 {
435     DnsProxyListen::proxyListenSwitch_ = true;
436     NETNATIVE_LOGI("DnsProxy OnListen");
437 }
438 
OffListen()439 void DnsProxyListen::OffListen()
440 {
441     if (proxySockFd_ > 0) {
442         close(proxySockFd_);
443         proxySockFd_ = -1;
444     }
445     if (proxySockFd6_ > 0) {
446         close(proxySockFd6_);
447         proxySockFd6_ = -1;
448     }
449     NETNATIVE_LOGI("DnsProxy OffListen");
450 }
451 
SetParseNetId(uint16_t netId)452 void DnsProxyListen::SetParseNetId(uint16_t netId)
453 {
454     DnsProxyListen::netId_ = netId;
455     NETNATIVE_LOGI("SetParseNetId");
456 }
457 
clearResource()458 void DnsProxyListen::clearResource()
459 {
460     if (proxySockFd_ > 0) {
461         close(proxySockFd_);
462         proxySockFd_ = -1;
463     }
464     if (proxySockFd6_ > 0) {
465         close(proxySockFd6_);
466         proxySockFd6_ = -1;
467     }
468     if (epollFd_ > 0) {
469         close(epollFd_);
470         epollFd_ = -1;
471     }
472     serverIdxOfSocket.clear();
473 }
474 
475 template<typename... Args>
emplace(Args &&...args)476 auto DnsProxyListen::DnsSocketHolder::emplace(Args&&... args) ->
477 decltype(DnsSocketHolderBase::emplace(std::forward<Args>(args)...))
478 {
479     if (size() >= MAX_SOCKET_CAPACITY) {
480         NETNATIVE_LOG_D("Socket num over capacity, throw oldest socket.");
481         DnsSocketHolderBase::erase(lruCache.front());
482         lruCache.pop_front();
483     }
484     auto iter = DnsSocketHolderBase::emplace(std::forward<Args>(args)...);
485     iter.first->second.SetLruIterator(lruCache.insert(lruCache.end(), iter.first));
486     return iter;
487 }
488 
erase(iterator position)489 auto DnsProxyListen::DnsSocketHolder::erase(iterator position) -> decltype(DnsSocketHolderBase::erase(position))
490 {
491     lruCache.erase(position->second.GetLruIterator());
492     return DnsSocketHolderBase::erase(position);
493 }
494 } // namespace nmd
495 } // namespace OHOS
496