1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <asm/unistd.h>
#include <fcntl.h>
#include <gtest/gtest.h>
#include <sched.h>
#include <sys/prctl.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <syscall.h>
#include <unistd.h>

#include <cerrno>
#include <climits>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <iostream>

#include "seccomp_policy.h"
35 
// Signature of a syscall probe: takes no arguments and reports success as a bool.
using SyscallFunc = bool (*)();
constexpr int SLEEP_TIME_100MS = 100 * 1000; // 100ms expressed in microseconds, passed to usleep()
constexpr int SLEEP_TIME_1S = 1;             // 1s expressed in seconds, passed to sleep()
39 
40 using namespace testing::ext;
41 using namespace std;
42 
43 namespace OHOS {
44 namespace MiscServices {
45 class SeccompUnitTest : public testing::Test {
46 public:
SeccompUnitTest()47     SeccompUnitTest(){};
~SeccompUnitTest()48     virtual ~SeccompUnitTest(){};
SetUpTestCase()49     static void SetUpTestCase(){};
TearDownTestCase()50     static void TearDownTestCase(){};
51 
SetUp()52     void SetUp()
53     {
54         /*
55          * Wait for 1 second to prevent the generated crash file
56          * from being overwritten because the crash interval is too short
57          * and the crash file's name is constructed by time stamp.
58          */
59         sleep(SLEEP_TIME_1S);
60     };
61 
TearDown()62     void TearDown(){};
TestBody(void)63     void TestBody(void){};
64 
StartChild(SeccompFilterType type,const char * filterName,SyscallFunc func)65     static pid_t StartChild(SeccompFilterType type, const char *filterName, SyscallFunc func)
66     {
67         pid_t pid = fork();
68         if (pid == 0) {
69             if (prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) {
70                 std::cout << "PR_SET_NO_NEW_PRIVS set fail " << std::endl;
71                 exit(EXIT_FAILURE);
72             }
73 
74             if (!SetSeccompPolicyWithName(type, filterName)) {
75                 std::cout << "SetSeccompPolicy set fail fiterName is " << filterName << std::endl;
76                 exit(EXIT_FAILURE);
77             }
78 
79             if (!func()) {
80                 std::cout << "func excute fail" << std::endl;
81                 exit(EXIT_FAILURE);
82             }
83 
84             std::cout << "func excute success" << std::endl;
85 
86             exit(EXIT_SUCCESS);
87         }
88         return pid;
89     }
90 
CheckStatus(int status,bool isAllow)91     static int CheckStatus(int status, bool isAllow)
92     {
93         if (WEXITSTATUS(status) == EXIT_FAILURE) {
94             return -1;
95         }
96 
97         if (WIFSIGNALED(status)) {
98             if (WTERMSIG(status) == SIGSYS) {
99                 std::cout << "child process exit with SIGSYS" << std::endl;
100                 return isAllow ? -1 : 0;
101             }
102         } else {
103             std::cout << "child process finished normally" << std::endl;
104             return isAllow ? 0 : -1;
105         }
106 
107         return -1;
108     }
109 
CheckSyscall(SeccompFilterType type,const char * filterName,SyscallFunc func,bool isAllow)110     static int CheckSyscall(SeccompFilterType type, const char *filterName, SyscallFunc func, bool isAllow)
111     {
112         sigset_t set;
113         int status;
114         pid_t pid;
115         int flag = 0;
116         struct timespec waitTime = { 5, 0 };
117 
118         sigemptyset(&set);
119         sigaddset(&set, SIGCHLD);
120         sigprocmask(SIG_BLOCK, &set, nullptr);
121         sigaddset(&set, SIGSYS);
122         if (signal(SIGCHLD, SIG_DFL) == nullptr) {
123             std::cout << "signal failed:" << strerror(errno) << std::endl;
124         }
125         if (signal(SIGSYS, SIG_DFL) == nullptr) {
126             std::cout << "signal failed:" << strerror(errno) << std::endl;
127         }
128 
129         /* Sleeping for avoiding influencing child proccess wait for other threads
130          * which were created by other unittests to release global rwlock. The global
131          * rwlock will be used by function dlopen in child process */
132         usleep(SLEEP_TIME_100MS);
133 
134         pid = StartChild(type, filterName, func);
135         if (pid == -1) {
136             std::cout << "fork failed:" << strerror(errno) << std::endl;
137             return -1;
138         }
139         if (sigtimedwait(&set, nullptr, &waitTime) == -1) { /* Wait for 5 seconds */
140             if (errno == EAGAIN) {
141                 flag = 1;
142             } else {
143                 std::cout << "sigtimedwait failed:" << strerror(errno) << std::endl;
144             }
145         }
146 
147         if (waitpid(pid, &status, 0) != pid) {
148             std::cout << "waitpid failed:" << strerror(errno) << std::endl;
149             return -1;
150         }
151 
152         if (flag != 0) {
153             std::cout << "Child process time out" << std::endl;
154         }
155 
156         return CheckStatus(status, isAllow);
157     }
158 
CheckSendfile()159     static bool CheckSendfile()
160     {
161         int ret = syscall(__NR_sendfile, 0, 0, nullptr, 0);
162         if (ret == 0) {
163             return true;
164         }
165 
166         return false;
167     }
168 
CheckVmsplice()169     static bool CheckVmsplice()
170     {
171         int ret = syscall(__NR_vmsplice, 0, nullptr, 0, 0);
172         if (ret == 0) {
173             return true;
174         }
175 
176         return false;
177     }
178 
CheckSocketpair()179     static bool CheckSocketpair()
180     {
181         int ret = syscall(__NR_socketpair, 0, 0, 0, nullptr);
182         if (ret == 0) {
183             return true;
184         }
185 
186         return false;
187     }
188 
CheckListen()189     static bool CheckListen()
190     {
191         int ret = syscall(__NR_listen, 0, 0);
192         if (ret == 0) {
193             return true;
194         }
195 
196         return false;
197     }
198 
CheckAccept()199     static bool CheckAccept()
200     {
201         int ret = syscall(__NR_accept, 0, nullptr, nullptr);
202         if (ret == 0) {
203             return true;
204         }
205 
206         return false;
207     }
208 
CheckAccept4()209     static bool CheckAccept4()
210     {
211         int ret = syscall(__NR_accept4, 0, nullptr, nullptr, 0);
212         if (ret == 0) {
213             return true;
214         }
215 
216         return false;
217     }
218 
CheckGetsockname()219     static bool CheckGetsockname()
220     {
221         int ret = syscall(__NR_getsockname, 0, nullptr, nullptr);
222         if (ret == 0) {
223             return true;
224         }
225 
226         return false;
227     }
228 
CheckGetpeername()229     static bool CheckGetpeername()
230     {
231         int ret = syscall(__NR_getpeername, 0, nullptr, nullptr);
232         if (ret == 0) {
233             return true;
234         }
235 
236         return false;
237     }
238 
CheckShutdown()239     static bool CheckShutdown()
240     {
241         int ret = syscall(__NR_shutdown, 0, 0);
242         if (ret == 0) {
243             return true;
244         }
245 
246         return false;
247     }
248 
CheckSendmsg()249     static bool CheckSendmsg()
250     {
251         int ret = syscall(__NR_sendmsg, 0, nullptr, 0);
252         if (ret == 0) {
253             return true;
254         }
255 
256         return false;
257     }
258 
CheckRecvmmsg()259     static bool CheckRecvmmsg()
260     {
261         int ret = syscall(__NR_recvmmsg, 0, nullptr, 0, 0, nullptr);
262         if (ret == 0) {
263             return true;
264         }
265 
266         return false;
267     }
268 #if defined __aarch64__
CheckSetuid()269     static bool CheckSetuid()
270     {
271         int uid = syscall(__NR_setuid, 1);
272         if (uid == 0) {
273             return true;
274         }
275 
276         return false;
277     }
278 
279 #elif defined __arm__
CheckSetuid32()280     static bool CheckSetuid32()
281     {
282         uid_t uid = syscall(__NR_setuid32, 123);
283         if (uid >= 0) {
284             return true;
285         }
286         return false;
287     }
288 
CheckSendfile64()289     static bool CheckSendfile64()
290     {
291         int ret = syscall(__NR_sendfile64, 0, 0, nullptr, 0);
292         if (ret == 0) {
293             return true;
294         }
295 
296         return false;
297     }
CheckRecvmmsgTime64()298     static bool CheckRecvmmsgTime64()
299     {
300         int ret = syscall(__NR_recvmmsg_time64, 0, nullptr, 0, 0, nullptr);
301         if (ret == 0) {
302             return true;
303         }
304 
305         return false;
306     }
307 #endif
308 
TestInputMethodExtSycall()309     void TestInputMethodExtSycall()
310     {
311         int ret = -1;
312         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckSendfile, false);
313         EXPECT_EQ(ret, 0);
314 
315         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckVmsplice, false);
316         EXPECT_EQ(ret, 0);
317 
318 #if defined __aarch64__
319         // system blocklist
320         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckSetuid, false);
321         EXPECT_EQ(ret, 0);
322 #elif defined __arm__
323         // system blocklist
324         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckSetuid32, false);
325         EXPECT_EQ(ret, 0);
326         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckSendfile64, false);
327         EXPECT_EQ(ret, 0);
328         ret = CheckSyscall(APP, IMF_EXTENTOIN_NAME, CheckRecvmmsgTime64, false);
329         EXPECT_EQ(ret, 0);
330 #endif
331     }
332 };
333 
334 /**
335  * @tc.name: TestInputMethodExtSeccomp
336  * @tc.desc: Verify the input method extenstion's seccomp policy.
337  * @tc.type: FUNC
338  * @tc.require: issueI9PUAS
339  */
340 HWTEST_F(SeccompUnitTest, TestInputMethodExtSycall, TestSize.Level1)
341 {
342     SeccompUnitTest test;
343     test.TestInputMethodExtSycall();
344 }
345 } // namespace MiscServices
346 } // namespace OHOS
347