1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netlink_socket_diag.h"
17 
18 #include <arpa/inet.h>
19 #include <cstring>
20 #include <net/if.h>
21 #include <netinet/tcp.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24 
25 #include "fwmark.h"
26 #include "net_manager_constants.h"
27 #include "netmanager_base_common_utils.h"
28 #include "netnative_log_wrapper.h"
29 #include "securec.h"
30 
31 namespace OHOS {
32 namespace nmd {
33 using namespace NetManagerStandard;
34 
namespace {
// Size, in bytes, of the stack buffer used to read netlink dump responses.
constexpr uint32_t KERNEL_BUFFER_SIZE = 8U * 1024U;
// Index of the 32-bit word of an in6_addr that is examined for IPv4-mapped addresses.
constexpr uint8_t ADDR_POSITION = 3U;
// Size, in bytes, of the text buffers used when formatting IPv6 addresses.
constexpr int32_t DOMAIN_IP_ADDR_MAX_LEN = 128;
// An address `a` is treated as loopback when (a & LOCKBACK_MASK) == LOCKBACK_DEFINE,
// i.e. when its most significant octet is 0x7f.
constexpr uint32_t LOCKBACK_MASK = 0xFF000000U;
constexpr uint32_t LOCKBACK_DEFINE = 0x7F000000U;
// UID compared against inet_diag_msg::idiag_uid when filtering sockets by destroy type.
constexpr uid_t PUSH_UID = 7023;
} // namespace
43 
~NetLinkSocketDiag()44 NetLinkSocketDiag::~NetLinkSocketDiag()
45 {
46     CloseNetlinkSocket();
47 }
48 
InLookBack(uint32_t a)49 bool NetLinkSocketDiag::InLookBack(uint32_t a)
50 {
51     return (a & LOCKBACK_MASK) == LOCKBACK_DEFINE;
52 }
53 
CreateNetlinkSocket()54 bool NetLinkSocketDiag::CreateNetlinkSocket()
55 {
56     dumpSock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
57     if (dumpSock_ < 0) {
58         NETNATIVE_LOGE("Create netlink socket for dump failed, error[%{public}d]: %{public}s", errno, strerror(errno));
59         return false;
60     }
61 
62     destroySock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
63     if (destroySock_ < 0) {
64         NETNATIVE_LOGE("Create netlink socket for destroy failed, error[%{public}d]: %{public}s", errno,
65                        strerror(errno));
66         close(dumpSock_);
67         return false;
68     }
69 
70     sockaddr_nl nl = {.nl_family = AF_NETLINK};
71     if ((connect(dumpSock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0) ||
72         (connect(destroySock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0)) {
73         NETNATIVE_LOGE("Connect to netlink socket failed, error[%{public}d]: %{public}s", errno, strerror(errno));
74         CloseNetlinkSocket();
75         return false;
76     }
77     return true;
78 }
79 
CloseNetlinkSocket()80 void NetLinkSocketDiag::CloseNetlinkSocket()
81 {
82     close(dumpSock_);
83     close(destroySock_);
84     dumpSock_ = -1;
85     destroySock_ = -1;
86 }
87 
ExecuteDestroySocket(uint8_t proto,const inet_diag_msg * msg)88 int32_t NetLinkSocketDiag::ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg)
89 {
90     if (msg == nullptr) {
91         NETNATIVE_LOGE("inet_diag_msg is nullptr");
92         return NETMANAGER_ERR_LOCAL_PTR_NULL;
93     }
94 
95     SockDiagRequest request;
96     request.nlh_.nlmsg_type = SOCK_DESTROY;
97     request.nlh_.nlmsg_flags = NLM_F_REQUEST;
98     request.nlh_.nlmsg_len = sizeof(request);
99 
100     request.req_ = {.sdiag_family = msg->idiag_family,
101                     .sdiag_protocol = proto,
102                     .idiag_states = static_cast<uint32_t>(1 << msg->idiag_state),
103                     .id = msg->id};
104     ssize_t writeLen = write(destroySock_, &request, sizeof(request));
105     if (writeLen < static_cast<ssize_t>(sizeof(request))) {
106         NETNATIVE_LOGE("Write destroy request to socket failed errno[%{public}d]: strerror:%{public}s", errno,
107                        strerror(errno));
108         return NETMANAGER_ERR_INTERNAL;
109     }
110 
111     int32_t ret = GetErrorFromKernel(destroySock_);
112     if (ret == NETMANAGER_SUCCESS) {
113         socketsDestroyed_++;
114     }
115     return ret;
116 }
117 
GetErrorFromKernel(int32_t fd)118 int32_t NetLinkSocketDiag::GetErrorFromKernel(int32_t fd)
119 {
120     Ack ack;
121     ssize_t bytesread = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
122     if (bytesread < 0) {
123         NETNATIVE_LOGE("Get error info from kernel failed errno[%{public}d]: strerror:%{public}s", errno,
124                        strerror(errno));
125         return (errno == EAGAIN) ? NETMANAGER_SUCCESS : -errno;
126     }
127     if (bytesread == static_cast<ssize_t>(sizeof(ack)) && ack.hdr_.nlmsg_type == NLMSG_ERROR) {
128         recv(fd, &ack, sizeof(ack), 0);
129         NETNATIVE_LOGE("Receive NLMSG_ERROR:[%{public}d] from kernel", ack.err_.error);
130         return NETMANAGER_ERR_INTERNAL;
131     }
132     return NETMANAGER_SUCCESS;
133 }
134 
IsLoopbackSocket(const inet_diag_msg * msg)135 bool NetLinkSocketDiag::IsLoopbackSocket(const inet_diag_msg *msg)
136 {
137     if (msg->idiag_family == AF_INET) {
138         return InLookBack(htonl(msg->id.idiag_src[0])) || InLookBack(htonl(msg->id.idiag_dst[0]));
139     }
140 
141     if (msg->idiag_family == AF_INET6) {
142         const struct in6_addr *src = (const struct in6_addr *)&msg->id.idiag_src;
143         const struct in6_addr *dst = (const struct in6_addr *)&msg->id.idiag_dst;
144         return (IN6_IS_ADDR_V4MAPPED(src) && InLookBack(src->s6_addr32[ADDR_POSITION])) ||
145                (IN6_IS_ADDR_V4MAPPED(dst) && InLookBack(dst->s6_addr32[ADDR_POSITION])) || IN6_IS_ADDR_LOOPBACK(src) ||
146                IN6_IS_ADDR_LOOPBACK(dst);
147     }
148     return false;
149 }
150 
IsMatchNetwork(const inet_diag_msg * msg,const std::string & ipAddr)151 bool NetLinkSocketDiag::IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr)
152 {
153     if (msg->idiag_family == AF_INET) {
154         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET) {
155             return false;
156         }
157 
158         in_addr_t addr = inet_addr(ipAddr.c_str());
159         if (addr == msg->id.idiag_src[0] || addr == msg->id.idiag_dst[0]) {
160             return true;
161         }
162     }
163 
164     if (msg->idiag_family == AF_INET6) {
165         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET6) {
166             return false;
167         }
168 
169         char src[DOMAIN_IP_ADDR_MAX_LEN] = {0};
170         char dst[DOMAIN_IP_ADDR_MAX_LEN] = {0};
171         inet_ntop(AF_INET6, msg->id.idiag_src, src, sizeof(src));
172         inet_ntop(AF_INET6, msg->id.idiag_dst, dst, sizeof(dst));
173         if (src == ipAddr || dst == ipAddr) {
174             return true;
175         }
176     }
177     return false;
178 }
179 
ProcessSockDiagDumpResponse(uint8_t proto,const std::string & ipAddr,bool excludeLoopback)180 int32_t NetLinkSocketDiag::ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback)
181 {
182     char buf[KERNEL_BUFFER_SIZE] = {0};
183     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
184     if (readBytes < 0) {
185         NETNATIVE_LOGE("Failed to read socket, errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
186         return NETMANAGER_ERR_INTERNAL;
187     }
188     while (readBytes > 0) {
189         uint32_t len = static_cast<uint32_t>(readBytes);
190         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
191             if (nlh->nlmsg_type == NLMSG_ERROR) {
192                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
193                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
194                                strerror(-err->error));
195                 return err->error;
196             } else if (nlh->nlmsg_type == NLMSG_DONE) {
197                 return NETMANAGER_SUCCESS;
198             } else {
199                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
200                 SockDiagDumpCallback(proto, msg, ipAddr, excludeLoopback);
201             }
202         }
203         readBytes = read(dumpSock_, buf, sizeof(buf));
204         if (readBytes < 0) {
205             return -errno;
206         }
207     }
208     return NETMANAGER_SUCCESS;
209 }
210 
SendSockDiagDumpRequest(uint8_t proto,uint8_t family,uint32_t states)211 int32_t NetLinkSocketDiag::SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states)
212 {
213     SockDiagRequest request;
214     size_t len = sizeof(request);
215     iovec iov;
216     iov.iov_base = &request;
217     iov.iov_len = len;
218     request.nlh_.nlmsg_type = SOCK_DIAG_BY_FAMILY;
219     request.nlh_.nlmsg_flags = (NLM_F_REQUEST | NLM_F_DUMP);
220     request.nlh_.nlmsg_len = len;
221 
222     request.req_ = {.sdiag_family = family, .sdiag_protocol = proto, .idiag_states = states};
223 
224     ssize_t writeLen = writev(dumpSock_, &iov, (sizeof(iov) / sizeof(iovec)));
225     if (writeLen != static_cast<ssize_t>(len)) {
226         NETNATIVE_LOGE("Write dump request failed errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
227         return NETMANAGER_ERR_INTERNAL;
228     }
229 
230     return GetErrorFromKernel(dumpSock_);
231 }
232 
SockDiagDumpCallback(uint8_t proto,const inet_diag_msg * msg,const std::string & ipAddr,bool excludeLoopback)233 void NetLinkSocketDiag::SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr,
234                                              bool excludeLoopback)
235 {
236     if (msg == nullptr) {
237         NETNATIVE_LOGE("msg is nullptr");
238         return;
239     }
240 
241     if (socketDestroyType_ == SocketDestroyType::DESTROY_SPECIAL_CELLULAR && msg->idiag_uid != PUSH_UID) {
242         return;
243     }
244 
245     if (socketDestroyType_ == SocketDestroyType::DESTROY_DEFAULT_CELLULAR && msg->idiag_uid == PUSH_UID) {
246         return;
247     }
248 
249     if (excludeLoopback && IsLoopbackSocket(msg)) {
250         NETNATIVE_LOGE("Loop back socket, no need to close.");
251         return;
252     }
253 
254     if (!IsMatchNetwork(msg, ipAddr)) {
255         NETNATIVE_LOG_D("Socket is not associated with the network");
256         return;
257     }
258 
259     ExecuteDestroySocket(proto, msg);
260 }
261 
DestroyLiveSockets(const char * ipAddr,bool excludeLoopback)262 void NetLinkSocketDiag::DestroyLiveSockets(const char *ipAddr, bool excludeLoopback)
263 {
264     NETNATIVE_LOG_D("DestroySocketsLackingNetwork in");
265     if (ipAddr == nullptr) {
266         NETNATIVE_LOGE("Ip address is nullptr.");
267         return;
268     }
269 
270     if (!CreateNetlinkSocket()) {
271         NETNATIVE_LOGE("Create netlink diag socket failed.");
272         return;
273     }
274 
275     const int32_t proto = IPPROTO_TCP;
276     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
277 
278     for (const int family : {AF_INET, AF_INET6}) {
279         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
280         if (ret != NETMANAGER_SUCCESS) {
281             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
282             break;
283         }
284         ret = ProcessSockDiagDumpResponse(proto, ipAddr, excludeLoopback);
285         if (ret != NETMANAGER_SUCCESS) {
286             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
287             break;
288         }
289     }
290 
291     NETNATIVE_LOG_D("Destroyed %{public}d sockets", socketsDestroyed_);
292 }
293 
SetSocketDestroyType(const std::string & netCapabilities)294 int32_t NetLinkSocketDiag::SetSocketDestroyType(const std::string &netCapabilities)
295 {
296     const std::string capSpecialCellularStr = "NET_CAPABILITY_INTERNAL_DEFAULT";
297     const std::string bearerCellularStr = "BEARER_CELLULAR";
298     if (netCapabilities.find(capSpecialCellularStr) != std::string::npos) {
299         socketDestroyType_ = SocketDestroyType::DESTROY_SPECIAL_CELLULAR;
300     } else if (netCapabilities.find(bearerCellularStr) != std::string::npos) {
301         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT_CELLULAR;
302     } else {
303         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT;
304     }
305     return 0;
306 }
307 } // namespace nmd
308 } // namespace OHOS