1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import sys
20import argparse
21import textwrap
22import re
23import os
24import stat
25
# Section names recognised in a seccomp policy file. Most of them are handled
# through the value_function tables of SeccompPolicyParam and AllowBlockList.
supported_parse_item = ['labelName', 'priority', 'allowList', 'blockList', 'priorityWithArgs', \
                        'allowListWithArgs', 'headFiles', 'selfDefineSyscall', 'returnValue', \
                        'mode', 'privilegedProcessName', 'allowBlockList']

# Architectures that have an AUDIT_ARCH_* entry in architecture_to_number.
supported_architecture = ['arm', 'arm64', 'riscv64']

# C source templates for classic BPF instructions, filled in with str.format().
# The conditional-jump templates take (constant, jt, jf), where jt/jf are the
# number of instructions to skip when the test is true/false; BPF_JA takes the
# unconditional skip count. The trailing comma lets the emitted lines be joined
# into a C initializer list (presumably a sock_filter array - confirm at caller).
BPF_JGE = 'BPF_JUMP(BPF_JMP|BPF_JGE|BPF_K, {}, {}, {}),'
BPF_JGT = 'BPF_JUMP(BPF_JMP|BPF_JGT|BPF_K, {}, {}, {}),'
BPF_JEQ = 'BPF_JUMP(BPF_JMP|BPF_JEQ|BPF_K, {}, {}, {}),'
BPF_JSET = 'BPF_JUMP(BPF_JMP|BPF_JSET|BPF_K, {}, {}, {}),'
BPF_JA = 'BPF_JUMP(BPF_JMP|BPF_JA, {}, 0, 0),'
BPF_LOAD = 'BPF_STMT(BPF_LD|BPF_W|BPF_ABS, {}),'
BPF_LOAD_MEM = 'BPF_STMT(BPF_LD|BPF_MEM, {}),'
BPF_ST = 'BPF_STMT(BPF_ST, {}),'
BPF_AND = 'BPF_STMT(BPF_ALU|BPF_AND|BPF_K, {}),'
BPF_RET_VALUE = 'BPF_STMT(BPF_RET|BPF_K, {}),'

# Operators accepted in argument conditions such as "arg0>=5" or "(arg1&mask)==value".
operation = ['<', '<=', '!=', '==', '>', '>=', '&']

# Policy return keywords -> kernel SECCOMP_RET_* action macros.
ret_str_to_bpf = {
    'KILL_PROCESS': 'SECCOMP_RET_KILL_PROCESS',
    'KILL_THREAD': 'SECCOMP_RET_KILL_THREAD',
    'TRAP': 'SECCOMP_RET_TRAP',
    'ERRNO': 'SECCOMP_RET_ERRNO',
    'USER_NOTIF': 'SECCOMP_RET_USER_NOTIF',
    'TRACE': 'SECCOMP_RET_TRACE',
    'LOG' : 'SECCOMP_RET_LOG',
    'ALLOW': 'SECCOMP_RET_ALLOW'
}

# Policy mode names -> numeric gen_mode used by GenBpfPolicy.
mode_str = {
    'DEFAULT': 0,
    'ONLY_CHECK_ARGS': 1
}

# Architecture names -> AUDIT_ARCH_* constants; add_validate_arch compares
# them against the word loaded from offset 4 of the seccomp data.
architecture_to_number = {
    'arm': 'AUDIT_ARCH_ARM',
    'arm64': 'AUDIT_ARCH_AARCH64',
    'riscv64': 'AUDIT_ARCH_RISCV64'
}
66
67
class ValidateError(Exception):
    """Raised when a seccomp policy entry or an input file fails validation."""

    def __init__(self, msg):
        # The message is the single exception argument, so str(err) == msg.
        super(ValidateError, self).__init__(msg)
71
72
def print_info(info):
    """Print ``info`` to stdout with an ``[INFO]`` prefix."""
    # Wrap in a 1-tuple so a tuple argument is printed as a whole instead of
    # being unpacked as %-format arguments (which raised TypeError).
    print("[INFO] %s" % (info,))
75
76
def is_hex_digit(s):
    """Tell whether ``s`` parses as a base-16 integer.

    Anything ``int(s, 16)`` accepts (``"ff"``, ``"0x1F"``, ...) counts as hex;
    text that ``int`` rejects with ValueError yields False.
    """
    try:
        int(s, 16)
    except ValueError:
        return False
    else:
        return True
84
85
def str_convert_to_int(s):
    """Convert a decimal or hexadecimal string to an int.

    Decimal digits take precedence, so ``"10"`` is ten, not sixteen.

    Returns:
        ``(value, True)`` on success, ``(-1, False)`` when ``s`` is neither
        decimal nor hexadecimal.
    """
    if s.isdigit():
        return int(s), True
    if is_hex_digit(s):
        return int(s, 16), True
    return -1, False
99
100
def is_function_name_exist(arch, function_name, func_name_nr_table):
    """Check that ``function_name`` is a known syscall of ``arch``.

    Args:
        arch: architecture name, used only in the error message.
        function_name: syscall name to look up.
        func_name_nr_table: mapping of syscall name to number for ``arch``.

    Returns:
        True when the name is present.

    Raises:
        ValidateError: if the name is not in the table.
    """
    if function_name not in func_name_nr_table:
        raise ValidateError('{} does not exist in {} function_name_nr_table'.format(function_name, arch))
    return True
106
107
def is_errno_in_valid_range(errno):
    """Check that ``errno`` is a decimal string naming an errno in 1..255.

    Returns:
        True when the value is valid.

    Raises:
        ValidateError: for non-numeric text or values outside 1..255.
    """
    # Test the characters before converting: int() on non-numeric text would
    # otherwise escape as a ValueError instead of a ValidateError.
    if errno.isdecimal() and 0 < int(errno) <= 255:
        return True
    raise ValidateError('{} not within the legal range of errno values.'.format(errno))
113
114
def is_return_errno(return_str):
    """Parse an ``ERRNO(<n>)`` return specification.

    Args:
        return_str: return keyword from the policy, e.g. ``"ERRNO(13)"`` or ``"ALLOW"``.

    Returns:
        ``(True, "ERRNO | <n>")`` when ``return_str`` starts with ``ERRNO``,
        otherwise ``(False, 'not_return_errno')``.

    Raises:
        ValidateError: if an ``ERRNO`` value lacks a ``(...)`` argument or the
            errno is outside 1..255.
    """
    if not return_str.startswith('ERRNO'):
        return False, 'not_return_errno'

    left = return_str.find('(')
    right = return_str.find(')')
    # A missing or misordered parenthesis used to make find() return -1 and
    # the slice below silently pick up unrelated characters.
    if left == -1 or right <= left:
        raise ValidateError('{} is not in ERRNO(<errno>) format'.format(return_str))

    errno_no = return_str[left + 1:right]
    is_errno_in_valid_range(errno_no)  # raises ValidateError when invalid
    return True, 'ERRNO | ' + errno_no
124
125
def function_name_to_nr(function_name_list, func_name_nr_table):
    """Translate syscall names to their numbers, silently ignoring unknown names.

    Returns:
        set of syscall numbers for the names present in ``func_name_nr_table``.
    """
    return {func_name_nr_table[name] for name in function_name_list
            if name in func_name_nr_table}
129
130
def filter_syscalls_nr(name_to_nr):
    """Keep only syscall-number macros and strip their prefixes.

    ``__NR_read`` becomes ``read``. The architecture-specific prefixes
    ``__NR_arm_``, ``__NR_riscv_``, ``__ARM_NR_`` and ``__RISCV_NR_`` are
    removed the same way. Every other name is dropped.

    Args:
        name_to_nr: mapping of macro name to its evaluated value.

    Returns:
        dict mapping bare syscall names to numbers.
    """
    # Longer prefixes precede "__NR_" so e.g. "__NR_arm_" is stripped whole.
    # "__RISCV_NR_" used to be rejected by an up-front filter, which left its
    # stripping branch unreachable.
    prefixes = ("__NR_arm_", "__NR_riscv_", "__NR_", "__ARM_NR_", "__RISCV_NR_")
    syscalls = {}
    for syscall_name, nr in name_to_nr.items():
        for prefix in prefixes:
            if syscall_name.startswith(prefix):
                syscalls[syscall_name[len(prefix):]] = nr
                break
    return syscalls
150
151
def parse_syscall_file(file_name):
    """Read ``#define`` lines from a syscall header and map syscall names to numbers.

    Each ``#define NAME EXPR`` is evaluated after every identifier in EXPR
    (two characters or longer) has been replaced by the value of an earlier
    define, so aliases such as ``#define __NR_x (__NR_SYSCALL_BASE + 1)``
    resolve. Lines whose expression cannot be evaluated, or that do not
    evaluate to an integer, are skipped.

    Args:
        file_name: path of the header to read.

    Returns:
        dict mapping bare syscall names to numbers (see filter_syscalls_nr).
    """
    const_pattern = re.compile(
        r'^\s*#define\s+([A-Za-z_][A-Za-z0-9_]+)\s+(.+)\s*$')
    mark_pattern = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]+\b')
    name_to_nr = {}
    with open(file_name) as f:
        for line in f:
            k = const_pattern.match(line)
            if k is None:
                continue
            name = k.group(1)
            expr = mark_pattern.sub(lambda x: str(name_to_nr.get(x.group(0))), k.group(2))
            try:
                # NOTE(review): eval() executes text taken from the input file.
                # This is only acceptable because the file is a build-time
                # generated header; never feed untrusted input here. Empty
                # namespaces keep leftover single-letter names from resolving
                # to this function's locals (e.g. the open file ``f``).
                nr = eval(expr, {'__builtins__': {}}, {})
            except (KeyError, SyntaxError, NameError, TypeError, ValueError,
                    ArithmeticError, AttributeError):
                continue
            # Only integers are syscall numbers; None (alias of an unknown
            # macro), strings or floats would later break numeric sorting.
            if isinstance(nr, int):
                name_to_nr[name] = nr

    return filter_syscalls_nr(name_to_nr)
172
173
def gen_syscall_nr_table(file_name, func_name_nr_table):
    """Parse one ``libsyscall_to_nr_<arch>`` file into ``func_name_nr_table``.

    The architecture key comes from the file name, so
    ``.../libsyscall_to_nr_arm64`` is stored under ``'arm64'``.

    Args:
        file_name: path of the syscall-number definition file.
        func_name_nr_table: dict updated in place with ``{arch: {name: nr}}``.

    Returns:
        The same ``func_name_nr_table`` object.

    Raises:
        ValidateError: if the file name does not contain
            ``libsyscall_to_nr_<arch>`` or no syscall numbers were parsed.
    """
    match = re.search(r"libsyscall_to_nr_([^/]+)", file_name)
    if match is None:
        raise ValidateError('{} is not a libsyscall_to_nr_<arch> file'.format(file_name))
    arch = match.group(1)

    syscall_table = parse_syscall_file(file_name)
    # The original post-assignment key check could never fail; an empty
    # result is the observable sign that parsing did not work.
    if not syscall_table:
        raise ValidateError("parse syscall file failed")
    func_name_nr_table[arch] = syscall_table
    return func_name_nr_table
180
181
class SeccompPolicyParam:
    """Collect and validate the seccomp policy entries of one architecture.

    Each policy section is fed entry by entry through ``value_function``
    (section name -> handler). ``update_final_list`` then merges the staged
    sets into the ``final_*`` sets and resets the staging sets.

    Args:
        arch: architecture name, used in error messages.
        function_name_nr_table: mapping of syscall name to number for ``arch``.
        is_debug: when the string ``'false'``, the ``LOG`` return value is rejected.
    """

    def __init__(self, arch, function_name_nr_table, is_debug):
        self.arch = arch
        self.priority = set()
        self.allow_list = set()
        self.blocklist = set()
        self.priority_with_args = set()
        self.allow_list_with_args = set()
        self.head_files = set()
        self.self_define_syscall = set()
        self.final_allow_list = set()
        self.final_priority = set()
        self.final_priority_with_args = set()
        self.final_allow_list_with_args = set()
        self.return_value = ''
        self.mode = 'DEFAULT'
        self.is_debug = is_debug
        self.function_name_nr_table = function_name_nr_table
        # Policy section name -> handler for one entry of that section.
        self.value_function = {
            'priority': self.update_priority,
            'allowList': self.update_allow_list,
            'blockList': self.update_blocklist,
            'allowListWithArgs': self.update_allow_list_with_args,
            'priorityWithArgs': self.update_priority_with_args,
            'headFiles': self.update_head_files,
            'selfDefineSyscall': self.update_self_define_syscall,
            'returnValue': self.update_return_value,
            'mode': self.update_mode
        }

    @staticmethod
    def _syscall_name_of(item):
        """Return the syscall name of an entry: the text before ':' (if any), stripped."""
        return item.partition(':')[0].strip()

    def clear_list(self):
        """Reset the staging sets after they have been merged.

        In ONLY_CHECK_ARGS mode the merged plain allow/priority sets are
        cleared too, leaving only the argument-checked entries.
        """
        self.priority.clear()
        self.allow_list.clear()
        self.allow_list_with_args.clear()
        self.priority_with_args.clear()
        if self.mode == 'ONLY_CHECK_ARGS':
            self.final_allow_list.clear()
            self.final_priority.clear()

    def update_list(self, function_name, to_update_list):
        """Add ``function_name`` to ``to_update_list`` after checking it is a known syscall.

        Raises:
            ValidateError: if the syscall is unknown for this architecture.
        """
        if is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
            to_update_list.add(function_name)
            return True
        return False

    def update_priority(self, function_name):
        """Stage a plain priority entry."""
        return self.update_list(function_name, self.priority)

    def update_allow_list(self, function_name):
        """Stage a plain allow-list entry."""
        return self.update_list(function_name, self.allow_list)

    def update_blocklist(self, function_name):
        """Add a block-list entry (block entries are never cleared by clear_list)."""
        return self.update_list(function_name, self.blocklist)

    def _update_list_with_args(self, function_name_with_args, to_update_list):
        """Validate the syscall of a ``name: condition`` entry and stage the whole entry.

        Raises:
            ValidateError: if the entry has no ':' or the syscall is unknown.
        """
        # Without a ':' the old slicing dropped the last character of the
        # name, so malformed entries were rejected only by accident.
        if ':' not in function_name_with_args:
            raise ValidateError('{} is not in "syscall: condition" format'.format(function_name_with_args))
        function_name = self._syscall_name_of(function_name_with_args)
        if is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
            to_update_list.add(function_name_with_args)
            return True
        return False

    def update_priority_with_args(self, function_name_with_args):
        """Stage a ``name: condition`` priority entry."""
        return self._update_list_with_args(function_name_with_args, self.priority_with_args)

    def update_allow_list_with_args(self, function_name_with_args):
        """Stage a ``name: condition`` allow-list entry."""
        return self._update_list_with_args(function_name_with_args, self.allow_list_with_args)

    def update_head_files(self, head_files):
        """Record an include spec such as ``"foo.h"`` or ``<foo.h>``.

        Raises:
            ValidateError: unless the value is a non-empty name wrapped in
                double quotes or angle brackets.
        """
        is_quoted = head_files.startswith('"') and head_files.endswith('"')
        is_angled = head_files.startswith('<') and head_files.endswith('>')
        # The length guard must cover both delimiter styles; operator
        # precedence previously let "<>" through and crashed on "".
        if len(head_files) > 2 and (is_quoted or is_angled):
            self.head_files.add(head_files)
            return True

        raise ValidateError('{} is not legal by headFiles format'.format(head_files))

    def update_self_define_syscall(self, self_define_syscall):
        """Record a self-defined syscall number (decimal or hex).

        Raises:
            ValidateError: if the value is not a number or is already used by
                a syscall in the table.
        """
        nr, digit_flag = str_convert_to_int(self_define_syscall)
        if digit_flag and nr not in self.function_name_nr_table.values():
            self.self_define_syscall.add(nr)
            return True

        raise ValidateError('{} is not a number or is already used by other syscall'.format(self_define_syscall))

    def update_return_value(self, return_str):
        """Set the default return value (a ret_str_to_bpf key or ``ERRNO(n)``).

        Raises:
            ValidateError: for unknown values, or ``LOG`` when ``is_debug`` is ``'false'``.
        """
        is_ret_errno, return_string = is_return_errno(return_str)
        if is_ret_errno:
            self.return_value = return_string
            return True
        if return_str in ret_str_to_bpf:
            if self.is_debug == 'false' and return_str == 'LOG':
                raise ValidateError("LOG return value is not allowed in user mode")
            self.return_value = return_str
            return True

        raise ValidateError('{} not in {}'.format(return_str, ret_str_to_bpf.keys()))

    def update_mode(self, mode):
        """Set the policy mode.

        Raises:
            ValidateError: if ``mode`` is not a key of mode_str.
        """
        if mode in mode_str:
            self.mode = mode
            return True
        raise ValidateError('{} not in [DEFAULT, ONLY_CHECK_ARGS]'.format(mode))

    def check_allow_list(self, allow_list):
        """Ensure no entry of ``allow_list`` names a blocked syscall.

        Returns:
            True when no conflict exists.

        Raises:
            ValidateError: on the first allowed syscall found in the block list.
        """
        for item in allow_list:
            # Strip so "openat : if ..." is matched against "openat".
            syscall = self._syscall_name_of(item)
            if syscall in self.blocklist:
                raise ValidateError('{} of allow list  is in block list'.format(syscall))
        return True

    def check_all_allow_list(self):
        """Check every merged allow set against the block list.

        Returns:
            False if a self-defined syscall number is also blocked, else True.

        Raises:
            ValidateError: if an allowed syscall name is in the block list.
        """
        for allowed in (self.final_allow_list, self.final_priority,
                        self.final_priority_with_args, self.final_allow_list_with_args):
            self.check_allow_list(allowed)
        block_nr_list = function_name_to_nr(self.blocklist, self.function_name_nr_table)
        return not any(nr in block_nr_list for nr in self.self_define_syscall)

    def update_final_list(self):
        """Merge the staged sets into the final sets and drop duplicates.

        A syscall that has an argument-checked entry is removed from the
        plain lists, and priority entries are removed from the allow list.
        The staging sets are cleared afterwards.
        """
        self.final_allow_list |= self.allow_list
        self.final_priority |= self.priority
        self.final_allow_list_with_args |= self.allow_list_with_args
        self.final_priority_with_args |= self.priority_with_args
        priority_names_with_args = {self._syscall_name_of(item)
                                    for item in self.final_priority_with_args}
        allow_names_with_args = {self._syscall_name_of(item)
                                 for item in self.final_allow_list_with_args}
        self.final_allow_list = self.final_allow_list - self.final_priority - \
            priority_names_with_args - allow_names_with_args
        self.final_priority = self.final_priority - priority_names_with_args - \
            allow_names_with_args
        self.clear_list()
324
325
class GenBpfPolicy:
    """Generate classic-BPF seccomp filter instructions as C source lines.

    Instructions are appended to ``bpf_policy`` as strings built from the
    module-level ``BPF_*`` templates. ``arch`` selects 32-bit (``arm``) or
    64-bit (``arm64``/``riscv64``) argument comparisons.

    The comparison helpers take ``jt``/``jf``: how many instructions to skip
    after the final comparison when the condition holds / does not hold.
    Multi-instruction 64-bit sequences adjust the offsets of their earlier
    jumps so every exit lands on those same targets. On 64-bit arches they
    expect ``load_arg`` to have left the argument's high word in A and its
    low word in M[0].
    """

    def __init__(self):
        self.arch = ''
        # Sorted [first, last] runs of consecutive syscall numbers (gen_range_list).
        self.syscall_nr_range = []
        # Generated instructions, one C source string per entry.
        self.bpf_policy = []
        # Thresholds in the order nr_range_to_bpf_policy compares against them.
        self.syscall_nr_policy_list = []
        # {arch: {syscall_name: nr}}
        self.function_name_nr_table_dict = {}
        self.gen_mode = 0
        # Cleared when generation met an invalid value; add_validate_arch is then a no-op.
        self.flag = True
        self.return_value = ''
        # Condition operator -> generator of its comparison jumps.
        self.operate_func_table = {
            '<': self.gen_bpf_lt,
            '<=': self.gen_bpf_le,
            '==': self.gen_bpf_eq,
            '!=': self.gen_bpf_ne,
            '>': self.gen_bpf_gt,
            '>=': self.gen_bpf_ge,
            '&': self.gen_bpf_set,
        }

    @staticmethod
    def gen_bpf_eq32(const_str, jt, jf):
        """32-bit ``A == const``."""
        return [BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_eq64(const_str, jt, jf):
        """64-bit equality: high word in A first, then the low word from M[0]."""
        bpf_policy = []
        # High word differs: skip the two remaining instructions, then jf.
        bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
        bpf_policy.append(BPF_LOAD_MEM.format(0))
        bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
        return bpf_policy

    @staticmethod
    def gen_bpf_gt32(const_str, jt, jf):
        """32-bit unsigned ``A > const``."""
        return [BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def _gen_cmp64(low_cmp_template, const_str, jt, jf):
        """Shared body of the 64-bit ``>``/``>=`` tests; only the low-word jump differs."""
        bpf_policy = []
        high_str = '((unsigned long)' + const_str + ') >> 32'
        number, digit_flag = str_convert_to_int(const_str)
        # Integer shift instead of float division: a negative constant has a
        # non-zero unsigned high word and must take the full path below.
        if digit_flag and number >> 32 == 0:
            # Constant's high word is 0: any non-zero high word in A already
            # satisfies the test, otherwise fall through to the low word.
            bpf_policy.append(BPF_JGT.format(high_str, jt + 2, 0))
        else:
            bpf_policy.append(BPF_JGT.format(high_str, jt + 3, 0))
            bpf_policy.append(BPF_JEQ.format(high_str, 0, jf + 2))
        bpf_policy.append(BPF_LOAD_MEM.format(0))
        bpf_policy.append(low_cmp_template.format(const_str + ' & 0xffffffff', jt, jf))
        return bpf_policy

    @staticmethod
    def gen_bpf_gt64(const_str, jt, jf):
        """64-bit unsigned ``arg > const``."""
        return GenBpfPolicy._gen_cmp64(BPF_JGT, const_str, jt, jf)

    @staticmethod
    def gen_bpf_ge32(const_str, jt, jf):
        """32-bit unsigned ``A >= const``."""
        return [BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_ge64(const_str, jt, jf):
        """64-bit unsigned ``arg >= const``."""
        return GenBpfPolicy._gen_cmp64(BPF_JGE, const_str, jt, jf)

    @staticmethod
    def gen_bpf_set32(const_str, jt, jf):
        """32-bit ``A & const != 0``."""
        return [BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf)]

    @staticmethod
    def gen_bpf_set64(const_str, jt, jf):
        """64-bit ``arg & const != 0``: a hit in the high word jumps straight to jt."""
        bpf_policy = []
        bpf_policy.append(BPF_JSET.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
        bpf_policy.append(BPF_LOAD_MEM.format(0))
        bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
        return bpf_policy

    @staticmethod
    def gen_bpf_valid_syscall_nr(syscall_nr, cur_size):
        """Load the syscall number (offset 0) and skip ``cur_size`` instructions unless it matches."""
        bpf_policy = []
        bpf_policy.append(BPF_LOAD.format(0))
        bpf_policy.append(BPF_JEQ.format(syscall_nr, 0, cur_size))
        return bpf_policy

    @staticmethod
    def check_arg_str(arg_atom):
        """Validate an ``argN`` prefix and return ``(N, True)``.

        Raises:
            ValidateError: unless the prefix is ``arg`` followed by a digit 0-5.
        """
        if arg_atom[0:3] != 'arg':
            raise ValidateError('format ERROR, {} is not equal to arg'.format(arg_atom))

        # Checked explicitly so "argx" or a short atom reports ValidateError
        # instead of escaping as ValueError/IndexError.
        arg_digit = arg_atom[3:4]
        if arg_digit not in ('0', '1', '2', '3', '4', '5'):
            raise ValidateError('arg num out of the scope 0~5')

        return int(arg_digit), True

    @staticmethod
    def check_operation_str(operation_atom):
        """Return ``(op, True)`` for a two-char operator, or its first char if that is one.

        Raises:
            ValidateError: if neither is a supported operator.
        """
        operation_str = operation_atom
        if operation_str not in operation:
            operation_str = operation_atom[0]
            if operation_str not in operation:
                raise ValidateError('operation not in [<, <=, !=, ==, >, >=, &]')
        return operation_str, True

    @staticmethod
    def gen_mask_equal_bpf(arg_id, mask, value, cur_size):
        """Emit ``(argN & mask) == value``; on failure skip ``cur_size`` + the group's RET.

        On success control continues with the next instruction, i.e. the
        remaining ``&&`` atoms of the group.
        """
        bpf_policy = []
        # high 4 bytes; a mismatch skips the 3 low-word instructions, the
        # remaining atoms (cur_size) and the RET.
        bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
        bpf_policy.append(BPF_AND.format('((uint64_t)' + mask + ') >> 32'))
        bpf_policy.append(BPF_JEQ.format('((uint64_t)' + value + ') >> 32', 0, cur_size + 4))

        # low 4 bytes; success must fall through (jt=0) so later && atoms are
        # still checked - jt=cur_size jumped straight to the RET.
        bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
        bpf_policy.append(BPF_AND.format(mask))
        bpf_policy.append(BPF_JEQ.format(value, 0, cur_size + 1))

        return bpf_policy

    def update_arch(self, arch):
        """Switch to ``arch`` and forget range data from the previous architecture."""
        self.arch = arch
        self.syscall_nr_range = []
        self.syscall_nr_policy_list = []

    def update_function_name_nr_table(self, func_name_nr_table):
        """Set the ``{arch: {syscall_name: nr}}`` lookup table."""
        self.function_name_nr_table_dict = func_name_nr_table

    def clear_bpf_policy(self):
        """Discard all generated instructions."""
        self.bpf_policy.clear()

    def get_gen_flag(self):
        """Return False once generation has hit an invalid value."""
        return self.flag

    def set_gen_flag(self, flag):
        """Set the generation flag from the truthiness of ``flag``."""
        self.flag = bool(flag)

    def set_gen_mode(self, mode):
        """Set the numeric gen_mode from a mode_str key (None for unknown keys)."""
        self.gen_mode = mode_str.get(mode)

    def set_return_value(self, return_value):
        """Set the action used by subsequently generated RET instructions.

        ``ERRNO(n)`` is stored as ``'ERRNO | n'``. An unknown keyword leaves
        ``return_value`` unchanged and clears the generation flag.
        """
        is_ret_errno, return_string = is_return_errno(return_value)
        if is_ret_errno:
            self.return_value = return_string
            return
        if return_value not in ret_str_to_bpf:
            # Report through the generation flag; this used to call
            # set_gen_mode(False), which reset gen_mode to None instead.
            self.set_gen_flag(False)
            return

        self.return_value = return_value

    @staticmethod
    def _ret_stmt(return_value):
        """Return the RET instruction for a plain keyword or an ``'ERRNO | n'`` value."""
        if return_value.startswith('ERRNO'):
            # 'ERRNO | 13' -> 'SECCOMP_RET_ERRNO | 13'
            return BPF_RET_VALUE.format(return_value.replace('ERRNO', ret_str_to_bpf.get('ERRNO')))
        return BPF_RET_VALUE.format(ret_str_to_bpf.get(return_value))

    def _is_64bit(self):
        """True for architectures whose arguments are compared as two 32-bit words."""
        return self.arch in ('arm64', 'riscv64')

    def gen_bpf_eq(self, const_str, jt, jf):
        """Equality test for the current arch ([] for an unsupported arch)."""
        if self.arch == 'arm':
            return self.gen_bpf_eq32(const_str, jt, jf)
        elif self._is_64bit():
            return self.gen_bpf_eq64(const_str, jt, jf)
        return []

    def gen_bpf_ne(self, const_str, jt, jf):
        """Inequality: equality with the jump targets swapped."""
        return self.gen_bpf_eq(const_str, jf, jt)

    def gen_bpf_gt(self, const_str, jt, jf):
        """Unsigned greater-than for the current arch."""
        if self.arch == 'arm':
            return self.gen_bpf_gt32(const_str, jt, jf)
        elif self._is_64bit():
            return self.gen_bpf_gt64(const_str, jt, jf)
        return []

    def gen_bpf_le(self, const_str, jt, jf):
        """Less-or-equal: greater-than with the jump targets swapped."""
        return self.gen_bpf_gt(const_str, jf, jt)

    def gen_bpf_ge(self, const_str, jt, jf):
        """Unsigned greater-or-equal for the current arch."""
        if self.arch == 'arm':
            return self.gen_bpf_ge32(const_str, jt, jf)
        elif self._is_64bit():
            return self.gen_bpf_ge64(const_str, jt, jf)
        return []

    def gen_bpf_lt(self, const_str, jt, jf):
        """Less-than: greater-or-equal with the jump targets swapped."""
        return self.gen_bpf_ge(const_str, jf, jt)

    def gen_bpf_set(self, const_str, jt, jf):
        """Bit test ``arg & const`` for the current arch."""
        if self.arch == 'arm':
            return self.gen_bpf_set32(const_str, jt, jf)
        elif self._is_64bit():
            return self.gen_bpf_set64(const_str, jt, jf)
        return []

    def gen_range_list(self, syscall_nr_list):
        """Collapse syscall numbers into sorted ``[first, last]`` runs of consecutive values.

        The result replaces ``syscall_nr_range``. An empty input leaves it
        empty so a later gen_bpf_policy call cannot re-emit stale ranges.
        """
        self.syscall_nr_range.clear()
        if not syscall_nr_list:
            return

        syscall_nr_list_order = sorted(syscall_nr_list)
        range_temp = [syscall_nr_list_order[0], syscall_nr_list_order[0]]

        for i in range(len(syscall_nr_list_order) - 1):
            if syscall_nr_list_order[i + 1] != syscall_nr_list_order[i] + 1:
                range_temp[1] = syscall_nr_list_order[i]
                self.syscall_nr_range.append(range_temp)
                range_temp = [syscall_nr_list_order[i + 1], syscall_nr_list_order[i + 1]]

        range_temp[1] = syscall_nr_list_order[-1]
        self.syscall_nr_range.append(range_temp)

    def gen_policy_syscall_nr(self, min_index, max_index, cur_syscall_nr_range):
        """Append thresholds for ranges min_index..max_index in pre-order of a binary split."""
        middle_index = int((min_index + max_index + 1) / 2)

        if middle_index == min_index:
            # Leaf: first number past the end of this range.
            self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][1] + 1)
            return

        self.syscall_nr_policy_list.append(cur_syscall_nr_range[middle_index][0])
        self.gen_policy_syscall_nr(min_index, middle_index - 1, cur_syscall_nr_range)
        self.gen_policy_syscall_nr(middle_index, max_index, cur_syscall_nr_range)

    def gen_policy_syscall_nr_list(self, cur_syscall_nr_range):
        """Rebuild ``syscall_nr_policy_list`` for ``cur_syscall_nr_range`` (empty input -> empty list)."""
        self.syscall_nr_policy_list.clear()
        if not cur_syscall_nr_range:
            return
        self.syscall_nr_policy_list.append(cur_syscall_nr_range[0][0])
        self.gen_policy_syscall_nr(0, len(cur_syscall_nr_range) - 1, cur_syscall_nr_range)

    def calculate_step(self, index):
        """Return the jump offset from threshold ``index`` to the next larger threshold."""
        # NOTE(review): raises UnboundLocalError if no later threshold is larger.
        for i in range(index + 1, len(self.syscall_nr_policy_list)):
            if self.syscall_nr_policy_list[index] < self.syscall_nr_policy_list[i]:
                step = i - index
                break
        return step - 1

    def nr_range_to_bpf_policy(self, cur_syscall_nr_range):
        """Emit a JGE decision tree over ``cur_syscall_nr_range`` followed by RET ALLOW."""
        self.gen_policy_syscall_nr_list(cur_syscall_nr_range)
        syscall_list_len = len(self.syscall_nr_policy_list)

        if syscall_list_len == 0:
            return

        self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[0], 0, syscall_list_len))

        range_max_set = {k[1] for k in cur_syscall_nr_range}

        for i in range(1, syscall_list_len):
            if self.syscall_nr_policy_list[i] - 1 in range_max_set:
                self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[i],
                                                      syscall_list_len - i, syscall_list_len - i - 1))
            else:
                step = self.calculate_step(i)
                self.bpf_policy.append(BPF_JGE.format(self.syscall_nr_policy_list[i], step, 0))

        self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))

    def count_alone_range(self):
        """Count ranges that hold a single syscall number."""
        return sum(1 for item in self.syscall_nr_range if item[0] == item[1])

    def gen_transverse_bpf_policy(self):
        """Emit a linear chain of tests over ``syscall_nr_range`` followed by RET ALLOW."""
        if not self.syscall_nr_range:
            return
        cnt = self.count_alone_range()
        total_instruction_num = cnt + (len(self.syscall_nr_range) - cnt) * 2
        i = 0
        for item in self.syscall_nr_range:
            if item[0] == item[1]:
                if i == total_instruction_num - 1:
                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 1))
                else:
                    self.bpf_policy.append(BPF_JEQ.format(item[0], total_instruction_num - i - 1, 0))
                i += 1
            else:
                self.bpf_policy.append(BPF_JGE.format(item[0], 0, total_instruction_num - i))
                i += 1
                if i == total_instruction_num - 1:
                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 1, total_instruction_num - i - 1))
                else:
                    self.bpf_policy.append(BPF_JGE.format(item[1] + 1, 0, total_instruction_num - i - 1))
                i += 1

        self.bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_ALLOW'))

    def gen_bpf_policy(self, syscall_nr_list):
        """Emit allow rules for ``syscall_nr_list``.

        Lists made only of isolated numbers use a linear chain of tests.
        Otherwise the ranges are split into chunks of at most 127 (presumably
        so jump offsets fit BPF's 8-bit jt/jf fields), each with its own
        decision tree, highest chunk first.
        """
        self.gen_range_list(syscall_nr_list)
        range_size = int((len(self.syscall_nr_range) - 1) / 127) + 1
        alone_range_cnt = self.count_alone_range()
        if alone_range_cnt == len(self.syscall_nr_range):
            # Scattered distribution
            self.gen_transverse_bpf_policy()
            return

        if range_size == 1:
            self.nr_range_to_bpf_policy(self.syscall_nr_range)
        else:
            for i in range(0, range_size):
                if i == 0:
                    self.nr_range_to_bpf_policy(self.syscall_nr_range[-127 * (i + 1):])
                elif i == range_size - 1:
                    self.nr_range_to_bpf_policy(self.syscall_nr_range[:-127 * i])
                else:
                    self.nr_range_to_bpf_policy(self.syscall_nr_range[-127 * (i + 1): -127 * i])

    def load_arg(self, arg_id):
        """Load argument ``arg_id`` (little endian).

        arm: low word into A. 64-bit arches: low word to M[0], high word to
        M[1], leaving the high word in A.
        """
        bpf_policy = []
        if self.arch == 'arm':
            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
        elif self._is_64bit():
            # low 4 bytes
            bpf_policy.append(BPF_LOAD.format(16 + arg_id * 8))
            bpf_policy.append(BPF_ST.format(0))
            # high 4 bytes
            bpf_policy.append(BPF_LOAD.format(20 + arg_id * 8))
            bpf_policy.append(BPF_ST.format(1))

        return bpf_policy

    def compile_atom(self, atom, cur_size):
        """Compile one condition; ``cur_size`` is the size of the code that follows it in the group.

        Raises:
            ValidateError: if the atom is shorter than 6 characters.
        """
        if len(atom) < 6:
            raise ValidateError('{} format ERROR '.format(atom))

        if atom[0] == '(':
            return self.compile_mask_equal_atom(atom, cur_size)
        return self.compile_single_operation_atom(atom, cur_size)

    def compile_mask_equal_atom(self, atom, cur_size):
        """Compile ``(argN&mask)==value``; other operators yield []."""
        bpf_policy = []
        left_brace_pos = atom.find('(')
        right_brace_pos = atom.rfind(')')
        inside_brace_content = atom[left_brace_pos + 1: right_brace_pos]
        outside_brace_content = atom[right_brace_pos + 1:]

        arg_res = self.check_arg_str(inside_brace_content[0:4])
        if not arg_res[1]:
            return bpf_policy

        operation_res_inside = self.check_operation_str(inside_brace_content[4:6])
        if operation_res_inside[0] != '&' or not operation_res_inside[1]:
            return bpf_policy

        mask = inside_brace_content[4 + len(operation_res_inside[0]):]

        operation_res_outside = self.check_operation_str(outside_brace_content[0:2])
        if operation_res_outside[0] != '==' or not operation_res_outside[1]:
            return bpf_policy

        value = outside_brace_content[len(operation_res_outside[0]):]

        return self.gen_mask_equal_bpf(arg_res[0], mask, value, cur_size)

    def compile_single_operation_atom(self, atom, cur_size):
        """Compile ``argN<op>const``; failure skips ``cur_size`` + the group's RET."""
        bpf_policy = []
        arg_res = self.check_arg_str(atom[0:4])
        if not arg_res[1]:
            return bpf_policy

        operation_res = self.check_operation_str(atom[4:6])
        if not operation_res[1]:
            return bpf_policy

        const_str = atom[4 + len(operation_res[0]):]

        if not const_str:
            return bpf_policy

        bpf_policy += self.load_arg(arg_res[0])
        bpf_policy += self.operate_func_table.get(operation_res[0])(const_str, 0, cur_size + 1)

        return bpf_policy

    def parse_args_with_condition(self, group):
        """Compile an ``&&``-joined group (``&&`` binds tighter than ``||``)."""
        atoms = group.split('&&')
        bpf_policy = []
        # Compile back to front so each atom knows how much code follows it.
        for atom in reversed(atoms):
            bpf_policy = self.compile_atom(atom, len(bpf_policy)) + bpf_policy
        return bpf_policy

    def parse_sub_group(self, group):
        """Compile ``cond||cond;return X``: each ``||`` alternative ends in its own RET X.

        Raises:
            ValidateError: if the ``;return`` part is missing.
        """
        bpf_policy = []
        group_info = group.split(';')
        if len(group_info) < 2 or not group_info[1].startswith('return'):
            raise ValidateError('allow list with args do not have return part')
        operation_part = group_info[0]
        return_part = group_info[1]

        self.set_return_value(return_part[len('return'):])
        and_cond_groups = operation_part.split('||')
        for and_condition_group in and_cond_groups:
            bpf_policy += self.parse_args_with_condition(and_condition_group)
            # ERRNO-aware: a plain dict lookup emitted "RET None" for ERRNO(n).
            bpf_policy.append(self._ret_stmt(self.return_value))
        return bpf_policy

    def parse_else_part(self, else_part):
        """Set ``return_value`` from the ``else ... return X;`` clause."""
        return_value = else_part.split(';')[0][else_part.find('return') + len('return'):]
        self.set_return_value(return_value)

    def parse_args(self, function_name, line, skip):
        """Compile an ``if/elif/else`` rule for one syscall.

        The code is prefixed with a syscall-number check that jumps over it
        for other syscalls; ``skip`` shortens that jump.
        """
        bpf_policy = []
        group_info = line.split('else')
        else_part = group_info[-1]
        group = group_info[0].split('elif')
        for sub_group in group:
            bpf_policy += self.parse_sub_group(sub_group)
        self.parse_else_part(else_part)
        bpf_policy.append(self._ret_stmt(self.return_value))
        syscall_nr = self.function_name_nr_table_dict.get(self.arch).get(function_name)
        # load syscall nr
        bpf_policy = self.gen_bpf_valid_syscall_nr(syscall_nr, len(bpf_policy) - skip) + bpf_policy
        return bpf_policy

    def gen_bpf_policy_with_args(self, allow_list_with_args, mode, return_value):
        """Append argument-checking filters for every ``name: if ...`` entry.

        In ONLY_CHECK_ARGS mode the last entry's syscall-number jump is
        shortened by 2. ``return_value`` is unused (kept for interface compatibility).
        """
        self.set_gen_mode(mode)
        lines = list(allow_list_with_args)
        # Computed once rather than rebuilding the list on every iteration.
        last_line = lines[-1] if lines else None
        skip = 0
        for line in lines:
            if self.gen_mode == 1 and line == last_line:
                skip = 2
            line = line.replace(' ', '')
            pos = line.find(':')
            function_name = line[:pos]

            left_line = line[pos + 1:]
            if not left_line.startswith('if'):
                continue

            self.bpf_policy += self.parse_args(function_name, left_line[2:], skip)

    def add_load_syscall_nr(self):
        """Append a load of the syscall number (offset 0)."""
        self.bpf_policy.append(BPF_LOAD.format(0))

    def add_return_value(self, return_value):
        """Append a RET for a plain keyword or an ``'ERRNO | n'`` value."""
        self.bpf_policy.append(self._ret_stmt(return_value))

    def add_validate_arch(self, arches, skip_step):
        """Prepend an architecture check (arch word at offset 4).

        One arch: mismatch returns SECCOMP_RET_TRAP. Two arches: arches[0]
        falls into the existing policy, arches[1] executes ``BPF_JA(skip_step)``
        (presumably to reach the second arch's filter), others trap. Any other
        count discards the policy. No-op when the policy is empty or the
        generation flag is cleared.
        """
        if not self.bpf_policy or not self.flag:
            return
        bpf_policy = []
        # load arch
        bpf_policy.append(BPF_LOAD.format(4))
        if len(arches) == 2:
            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[0]), 3, 0))
            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[1]), 0, 1))
            bpf_policy.append(BPF_JA.format(skip_step))
            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_TRAP'))
        elif len(arches) == 1:
            bpf_policy.append(BPF_JEQ.format(architecture_to_number.get(arches[0]), 1, 0))
            bpf_policy.append(BPF_RET_VALUE.format('SECCOMP_RET_TRAP'))
        else:
            # Return here: prepending the LOAD would leave a dangling instruction.
            self.bpf_policy = []
            return

        self.bpf_policy = bpf_policy + self.bpf_policy
812
813
class AllowBlockList:
    """Collect, for one architecture, syscalls a privileged process may use.

    Lines are fed through ``value_function`` keyed by the current policy
    section. A ``privilegedProcessName`` line toggles whether the following
    ``allowBlockList`` entries belong to ``filter_name``; only then are known
    syscall names added to ``reduced_block_list``.
    """

    def __init__(self, filter_name, arch, function_name_nr_table):
        self.is_valid = False
        self.arch = arch
        self.filter_name = filter_name
        self.reduced_block_list = set()
        self.function_name_nr_table = function_name_nr_table
        # Section name -> handler for that section's lines.
        self.value_function = dict(
            privilegedProcessName=self.update_flag,
            allowBlockList=self.update_reduced_block_list,
        )

    def update_flag(self, name):
        """Mark subsequent entries as relevant iff ``name`` is this filter's name."""
        self.is_valid = bool(self.filter_name == name)

    def update_reduced_block_list(self, function_name):
        """Record ``function_name`` if currently active and known; return whether it was added."""
        if not self.is_valid:
            return False
        if not is_function_name_exist(self.arch, function_name, self.function_name_nr_table):
            return False
        self.reduced_block_list.add(function_name)
        return True
837
838
class SeccompPolicyParser:
    """Parse seccomp policy files and write the resulting BPF filter as C source.

    One ``SeccompPolicyParam`` and one ``AllowBlockList`` are created per
    supported architecture. The block-list, privileged-process and regular
    policy files are parsed into them, then ``GenBpfPolicy`` turns the final
    lists into instruction strings written to ``args.dst_file``.
    """

    def __init__(self):
        self.cur_parse_item = ''
        self.arches = set()
        self.bpf_generator = GenBpfPolicy()
        self.seccomp_policy_param = dict()
        self.reduced_block_list_parm = dict()
        # True while parsing privileged-process files (routes lines to AllowBlockList).
        self.key_process_flag = False
        self.is_debug = False

    def update_is_debug(self, is_debug):
        """Enable debug mode unless ``is_debug`` is exactly the string ``'false'``."""
        # NOTE(review): any other value, including None when --is-debug is
        # omitted, enables debug mode (which allows LOG return values).
        if is_debug == 'false':
            self.is_debug = False
        else:
            self.is_debug = True

    def update_arch(self, target_cpu):
        """Add the architectures whose policies are generated for ``target_cpu``.

        ``arm64`` also generates the 32-bit ``arm`` policy; unknown values add nothing.
        """
        if target_cpu == "arm":
            self.arches.add(target_cpu)
        elif target_cpu == "arm64":
            self.arches.add("arm")
            self.arches.add(target_cpu)
        elif target_cpu == "riscv64":
            self.arches.add(target_cpu)

    def update_block_list(self):
        """Remove privileged-process exemptions from every architecture's block list."""
        for arch in supported_architecture:
            self.seccomp_policy_param.get(arch).blocklist -= self.reduced_block_list_parm.get(arch).reduced_block_list

    def update_parse_item(self, line):
        """Switch the current section from an ``@section`` line if it is supported."""
        item = line[1:]
        if item in supported_parse_item:
            self.cur_parse_item = item
            print_info('start deal with {}'.format(self.cur_parse_item))

    def check_allow_list(self):
        """Disable generation if any active architecture fails its allow-list check."""
        for arch in self.arches:
            if not self.seccomp_policy_param.get(arch).check_all_allow_list():
                self.bpf_generator.set_gen_flag(False)

    def clear_file_syscall_list(self):
        """Fold per-file lists into the final lists and reset the per-file section state."""
        for arch in self.arches:
            self.seccomp_policy_param.get(arch).update_final_list()
        self.cur_parse_item = ''
        self.cur_arch = ''

    def _dispatch_line(self, arch, value):
        """Pass ``value`` to the current section's handler for ``arch``.

        Raises:
            ValidateError: for an unknown architecture or an unhandled section.
        """
        params = self.reduced_block_list_parm if self.key_process_flag else self.seccomp_policy_param
        param = params.get(arch)
        if param is None:
            raise ValidateError('architecture {} not supported'.format(arch))
        handler = param.value_function.get(self.cur_parse_item)
        if handler is None:
            raise ValidateError('item {} not supported here'.format(self.cur_parse_item))
        handler(value)

    def parse_line(self, line):
        """Route one policy line to the current section handler.

        All spaces are removed. A trailing ``;arch1,arch2`` (or ``;all``)
        selects the target architectures and is stripped from the value;
        without it the line applies to every architecture in ``self.arches``.
        """
        if not self.cur_parse_item:
            return
        line = line.replace(' ', '')
        pos = line.rfind(';')
        if pos < 0:
            for arch in self.arches:
                self._dispatch_line(arch, line)
            return
        arches = line[pos + 1:].split(',')
        if arches[0] == 'all':
            arches = supported_architecture
        for arch in arches:
            self._dispatch_line(arch, line[:pos])

    def parse_open_file(self, fp):
        """Parse an iterable of policy lines, then finalize and validate them.

        Blank lines and ``#`` comments are skipped; ``@name`` lines select the
        section; other lines are ignored until a section is selected.
        """
        for line in fp:
            line = line.strip()
            if not line or line[0] == '#':
                continue
            if line[0] == '@':
                self.update_parse_item(line)
                continue
            if not self.cur_parse_item:
                continue
            self.parse_line(line)
        self.clear_file_syscall_list()
        self.check_allow_list()

    def parse_file(self, file_path):
        """Open ``file_path`` as UTF-8 text and parse it."""
        with open(file_path, encoding='utf-8') as fp:
            self.parse_open_file(fp)

    def gen_seccomp_policy_of_arch(self, arch):
        """Append the instructions for ``arch`` to the shared generator.

        Order: syscall-number load, priority syscalls, priority rules with
        arguments, allowed syscalls, allowed rules with arguments, then the
        default return value.

        Raises:
            ValidateError: if no return value is defined, or if a LOG return
                appears while not in debug mode.
        """
        cur_policy_param = self.seccomp_policy_param.get(arch)

        if not cur_policy_param.return_value:
            raise ValidateError('return value not defined')

        #get final allow_list
        syscall_nr_allow_list = function_name_to_nr(cur_policy_param.final_allow_list, \
                                                    cur_policy_param.function_name_nr_table) \
                                                    | cur_policy_param.self_define_syscall
        syscall_nr_priority = function_name_to_nr(cur_policy_param.final_priority, \
                                                  cur_policy_param.function_name_nr_table)
        self.bpf_generator.update_arch(arch)

        #load syscall nr
        if syscall_nr_allow_list or syscall_nr_priority:
            self.bpf_generator.add_load_syscall_nr()
        self.bpf_generator.gen_bpf_policy(syscall_nr_priority)
        self.bpf_generator.gen_bpf_policy_with_args(sorted(cur_policy_param.final_priority_with_args), \
            cur_policy_param.mode, cur_policy_param.return_value)
        self.bpf_generator.gen_bpf_policy(syscall_nr_allow_list)
        self.bpf_generator.gen_bpf_policy_with_args(sorted(cur_policy_param.final_allow_list_with_args), \
            cur_policy_param.mode, cur_policy_param.return_value)

        self.bpf_generator.add_return_value(cur_policy_param.return_value)
        if not self.is_debug and any('SECCOMP_RET_LOG' in line for line in self.bpf_generator.bpf_policy):
            raise ValidateError("LOG return value is not allowed in user mode")

    def gen_seccomp_policy(self):
        """Generate policies for the (at most two) active architectures plus the arch check."""
        arches = sorted(self.arches)
        if not arches:
            return
        self.gen_seccomp_policy_of_arch(arches[0])
        # Jump distance over the trap instruction and the first arch's policy.
        skip_step = len(self.bpf_generator.bpf_policy) + 1
        if len(arches) == 2:
            self.gen_seccomp_policy_of_arch(arches[1])

        self.bpf_generator.add_validate_arch(arches, skip_step)

    def gen_output_file(self, args):
        """Write the filter array and its size constant as C source to ``args.dst_file``.

        Raises:
            ValidateError: if no instructions were generated.
        """
        if not self.bpf_generator.bpf_policy:
            raise ValidateError("bpf_policy is empty!")

        header = textwrap.dedent('''\

            #include <linux/filter.h>
            #include <stddef.h>
            #include <linux/seccomp.h>
            #include <linux/audit.h>
            ''')
        extra_header = set()
        for arch in self.arches:
            extra_header |= self.seccomp_policy_param.get(arch).head_files
        extra_header_list = ['#include ' + i for i in sorted(extra_header)]
        filter_name = 'g_' + args.filter_name + 'SeccompFilter'

        array_name = textwrap.dedent('''

            const struct sock_filter {}[] = {{
            ''').format(filter_name)

        footer = textwrap.dedent('''\

            }};

            const size_t {} = sizeof({}) / sizeof(struct sock_filter);
            ''').format(filter_name + 'Size', filter_name)

        content = header + '\n'.join(extra_header_list) + array_name + \
            '    ' + '\n    '.join(self.bpf_generator.bpf_policy) + footer

        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        modes = stat.S_IWUSR | stat.S_IRUSR | stat.S_IWGRP | stat.S_IRGRP
        with os.fdopen(os.open(args.dst_file, flags, modes), 'w', encoding='utf-8') as output_file:
            output_file.write(content)

    def gen_seccomp_policy_code(self, args):
        """Run the full pipeline for the parsed command-line ``args``.

        Omitted repeatable file options are treated as empty lists.

        Raises:
            ValidateError: if ``args.target_cpu`` is not supported.
        """
        if args.target_cpu not in supported_architecture:
            raise ValidateError('target cpu not supported')
        src_files = args.src_files or []
        blocklist_files = args.blocklist_file or []
        keyprocess_files = args.keyprocess_file or []

        function_name_nr_table_dict = {}
        for file_name in src_files:
            file_name_tmp = os.path.basename(file_name)
            if not file_name_tmp.lower().startswith('libsyscall_to_nr_'):
                continue
            function_name_nr_table_dict = gen_syscall_nr_table(file_name, function_name_nr_table_dict)

        for arch in supported_architecture:
            self.seccomp_policy_param.update(
                {arch: SeccompPolicyParam(arch, function_name_nr_table_dict.get(arch), args.is_debug)})
            self.reduced_block_list_parm.update(
                {arch: AllowBlockList(args.filter_name, arch, function_name_nr_table_dict.get(arch))})

        self.bpf_generator.update_function_name_nr_table(function_name_nr_table_dict)

        self.update_arch(args.target_cpu)
        self.update_is_debug(args.is_debug)

        for file_name in blocklist_files:
            if file_name.lower().endswith('blocklist.seccomp.policy'):
                self.parse_file(file_name)

        for file_name in keyprocess_files:
            if file_name.lower().endswith('privileged_process.seccomp.policy'):
                self.key_process_flag = True
                self.parse_file(file_name)
                self.key_process_flag = False

        self.update_block_list()

        for file_name in src_files:
            if file_name.lower().endswith('.policy'):
                self.parse_file(file_name)

        if self.bpf_generator.get_gen_flag():
            self.gen_seccomp_policy()

        # Re-checked: generation may have been disabled while generating.
        if self.bpf_generator.get_gen_flag():
            self.gen_output_file(args)
1046
1047
def main():
    """Command-line entry point: parse arguments and generate the seccomp filter source."""
    parser = argparse.ArgumentParser(
        description='Generates a seccomp-bpf policy')
    # Repeatable options default to [] so an omitted flag yields an empty
    # list instead of None.
    parser.add_argument('--src-files', type=str, action='append', default=[],
                        help=('The input files\n'))

    parser.add_argument('--blocklist-file', type=str, action='append', default=[],
                        help=('input basic blocklist file(s)\n'))

    parser.add_argument('--keyprocess-file', type=str, action='append', default=[],
                        help=('input key process file(s)\n'))

    parser.add_argument('--dst-file',
                        help='The output path for the policy files')

    parser.add_argument('--filter-name', type=str,
                        help='Name of seccomp bpf array generated by this script')

    parser.add_argument('--target-cpu', type=str,
                        help=('please input target cpu arm, arm64 or riscv64\n'))

    parser.add_argument('--is-debug', type=str,
                        help=('please input is_debug true or false\n'))

    args = parser.parse_args()

    generator = SeccompPolicyParser()
    generator.gen_seccomp_policy_code(args)


if __name__ == '__main__':
    sys.exit(main())
1080