1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef INCLUDE_NETLINK_SOCK_DIAG_H
17 #define INCLUDE_NETLINK_SOCK_DIAG_H
18 
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdint>
#include <string>
26 
27 namespace OHOS {
28 namespace nmd {
/**
 * Selects which sockets NetLinkSocketDiag destroys; stored in
 * NetLinkSocketDiag::socketDestroyType_ and set via SetSocketDestroyType().
 *
 * Declared directly in the named namespace rather than in an unnamed one:
 * an unnamed namespace in a header creates a distinct type per translation
 * unit, and NetLinkSocketDiag (external linkage) holds a member of this type,
 * which would make its definition differ between TUs (-Wsubobject-linkage).
 */
enum class SocketDestroyType {
    DESTROY_DEFAULT_CELLULAR, // presumably: default cellular network sockets — confirm in .cpp
    DESTROY_SPECIAL_CELLULAR, // presumably: special-purpose cellular network sockets — confirm in .cpp
    DESTROY_DEFAULT,          // initial value of NetLinkSocketDiag::socketDestroyType_
};
36 class NetLinkSocketDiag final {
37 public:
38     NetLinkSocketDiag() = default;
39     ~NetLinkSocketDiag();
40 
41     /**
42      * Destroy all 'active' TCP sockets that no longer exist.
43      *
44      * @param ipAddr Network IP address
45      * @param excludeLoopback “true” to exclude loopback.
46      */
47     void DestroyLiveSockets(const char *ipAddr, bool excludeLoopback);
48 
49     /**
50      * This method set the socketDestroyType_, which used to choose the correct socket.
51      * to destroy.
52      * @param netCapabilities Net capabilities in string format.
53      * @return The result of the method is returned.
54      */
55     int32_t SetSocketDestroyType(const std::string &netCapabilities);
56 
57 private:
58     static bool InLookBack(uint32_t a);
59 
60     bool CreateNetlinkSocket();
61     void CloseNetlinkSocket();
62     int32_t ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg);
63     int32_t GetErrorFromKernel(int32_t fd);
64     bool IsLoopbackSocket(const inet_diag_msg *msg);
65     bool IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr);
66     int32_t ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback);
67     int32_t SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
68     void SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr, bool excludeLoopback);
69 
70 private:
71     struct SockDiagRequest {
72         nlmsghdr nlh_;
73         inet_diag_req_v2 req_;
74     };
75     struct MarkMatch {
76         inet_diag_bc_op op_;
77         uint32_t mark_;
78         uint32_t mask_;
79     };
80     struct ByteCode {
81         MarkMatch netIdMatch_;
82         MarkMatch controlMatch_;
83         inet_diag_bc_op controlJump_;
84     };
85     struct Ack {
86         nlmsghdr hdr_;
87         nlmsgerr err_;
88     };
89 
90     int32_t dumpSock_ = -1;
91     int32_t destroySock_ = -1;
92     int32_t socketsDestroyed_ = 0;
93     SocketDestroyType socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT;
94 };
95 } // namespace nmd
96 } // namespace OHOS
97 #endif // INCLUDE_NETLINK_SOCK_DIAG_H