1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netlink/netlink_listener.h"
17 
18 #include <memory>
19 #include <sys/socket.h>
20 #include <unistd.h>
21 #include <linux/netlink.h>
22 
23 #include "securec.h"
24 #include "storage_service_errno.h"
25 #include "storage_service_log.h"
26 
// Timeout, in milliseconds, passed to poll() by RunListener(); on timeout the loop simply polls again.
constexpr int POLL_IDLE_TIME { 1000 };
// Capacity of the uevent receive buffer; datagrams filling it completely are discarded as possibly truncated.
constexpr int UEVENT_MSG_LEN { 1024 };
29 
30 namespace OHOS {
31 namespace StorageDaemon {
UeventKernelMulticastRecv(int32_t socket,char * buffer,size_t length)32 ssize_t UeventKernelMulticastRecv(int32_t socket, char *buffer, size_t length)
33 {
34     struct iovec iov = { buffer, length };
35     struct sockaddr_nl addr;
36     char control[CMSG_SPACE(sizeof(struct ucred))];
37     struct msghdr hdr = {
38         .msg_name = &addr,
39         .msg_namelen = sizeof(addr),
40         .msg_iov = &iov,
41         .msg_iovlen = 1,
42         .msg_control = control,
43         .msg_controllen = sizeof(control),
44         .msg_flags = 0,
45     };
46     struct cmsghdr *cmsg;
47 
48     ssize_t n = recvmsg(socket, &hdr, 0);
49     if (n <= 0) {
50         LOGE("Recvmsg failed, errno %{public}d", errno);
51         return n;
52     }
53 
54     if (addr.nl_groups == 0 || addr.nl_pid != 0) {
55         return E_ERR;
56     }
57 
58     cmsg = CMSG_FIRSTHDR(&hdr);
59     if (cmsg == nullptr || cmsg->cmsg_type != SCM_CREDENTIALS) {
60         LOGE("SCM_CREDENTIALS check failed");
61         return E_ERR;
62     }
63 
64     struct ucred cred;
65     if (memcpy_s(&cred, sizeof(cred), CMSG_DATA(cmsg), sizeof(struct ucred)) != EOK || cred.uid != 0) {
66         LOGE("Uid check failed");
67         return E_ERR;
68     }
69 
70     return n;
71 }
72 
RecvUeventMsg()73 void NetlinkListener::RecvUeventMsg()
74 {
75     auto msg = std::make_unique<char[]>(UEVENT_MSG_LEN + 1);
76 
77     while (1) {
78         auto count = UeventKernelMulticastRecv(socketFd_, msg.get(), UEVENT_MSG_LEN);
79         if (count <= 0) {
80             (void)memset_s(msg.get(), UEVENT_MSG_LEN + 1, 0, UEVENT_MSG_LEN + 1);
81             break;
82         }
83         if (count >= UEVENT_MSG_LEN) {
84             continue;
85         }
86 
87         msg.get()[count] = '\0';
88         OnEvent(msg.get());
89     }
90 }
91 
ReadMsg(int32_t fd_count,struct pollfd ufds[2])92 int32_t NetlinkListener::ReadMsg(int32_t fd_count, struct pollfd ufds[2])
93 {
94     int32_t i;
95     for (i = 0; i < fd_count; i++) {
96         if (ufds[i].revents == 0) {
97             continue;
98         }
99 
100         if (ufds[i].fd == socketPipe_[0]) {
101             int32_t msg = 0;
102             if (read(socketPipe_[0], &msg, 1) < 0) {
103                 LOGE("Read socket pipe failed");
104                 return E_ERR;
105             }
106             if (msg == 0) {
107                 LOGI("Stop listener");
108                 return E_ERR;
109             }
110         } else if (ufds[i].fd == socketFd_) {
111             if ((static_cast<uint32_t>(ufds[i].revents) & POLLIN)) {
112                 RecvUeventMsg();
113                 continue;
114             }
115             if ((static_cast<uint32_t>(ufds[i].revents)) & (POLLERR | POLLHUP)) {
116                 LOGE("POLLERR | POLLHUP");
117                 return E_ERR;
118             }
119         }
120     }
121     return E_OK;
122 }
123 
RunListener()124 void NetlinkListener::RunListener()
125 {
126     struct pollfd ufds[2];
127     int32_t idle_time = POLL_IDLE_TIME;
128 
129     while (1) {
130         int32_t fd_count = 0;
131 
132         ufds[fd_count].fd = socketPipe_[0];
133         ufds[fd_count].events = POLLIN;
134         ufds[fd_count].revents = 0;
135         fd_count++;
136 
137         if (socketFd_ > -1) {
138             ufds[fd_count].fd = socketFd_;
139             ufds[fd_count].events = POLLIN;
140             ufds[fd_count].revents = 0;
141             fd_count++;
142         }
143 
144         int32_t n = poll(ufds, fd_count, idle_time);
145         if (n < 0) {
146             if (errno == EAGAIN || errno == EINTR) {
147                 continue;
148             }
149             break;
150         } else if (!n) {
151             continue;
152         }
153 
154         if (ReadMsg(fd_count, ufds) != 0) {
155             return;
156         }
157     }
158 }
159 
EventProcess(void * object)160 void NetlinkListener::EventProcess(void *object)
161 {
162     if (object == nullptr) {
163         return;
164     }
165 
166     NetlinkListener* me = reinterpret_cast<NetlinkListener *>(object);
167     me->RunListener();
168 }
169 
StartListener()170 int32_t NetlinkListener::StartListener()
171 {
172     if (socketFd_ < 0) {
173         return E_ERR;
174     }
175 
176     if (pipe(socketPipe_) == -1) {
177         LOGE("Pipe error");
178         return E_ERR;
179     }
180     socketThread_ = std::make_unique<std::thread>([this]() { this->EventProcess(static_cast<void *>(this)); });
181     if (socketThread_ == nullptr) {
182         (void)close(socketPipe_[0]);
183         (void)close(socketPipe_[1]);
184         socketPipe_[0] = socketPipe_[1] = -1;
185         return E_ERR;
186     }
187 
188     return E_OK;
189 }
190 
StopListener()191 int32_t NetlinkListener::StopListener()
192 {
193     int32_t msg = 0;
194     write(socketPipe_[1], &msg, 1);
195 
196     if (socketThread_ != nullptr && socketThread_->joinable()) {
197         socketThread_->join();
198     }
199 
200     (void)close(socketPipe_[0]);
201     (void)close(socketPipe_[1]);
202     socketPipe_[0] = socketPipe_[1] = -1;
203 
204     return E_OK;
205 }
206 
NetlinkListener(int32_t socket)207 NetlinkListener::NetlinkListener(int32_t socket)
208 {
209     socketFd_ = socket;
210 }
211 } // StorageDaemon
212 } // OHOS
213