1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2023 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import argparse
20import sys
21import os
22import stat
23import generate_code_from_policy as gen_policy
24
25
def parse_line(fp, arch_nr):
    """Collect the syscall numbers reported by seccomp audit records.

    Every line of *fp* that contains both ``audit`` and ``type=1326`` is
    treated as a seccomp audit record. Its ``arch=`` and ``syscall=`` fields
    are extracted. The syscall number is converted with
    ``gen_policy.str_convert_to_int`` and added to ``arch_nr[<arch name>]``.

    Args:
        fp: Iterable of text lines (e.g. an open log file).
        arch_nr: Mapping from architecture name ('arm', 'arm64', 'riscv64')
            to a set of syscall numbers; updated in place.

    Some records are reported and skipped instead of aborting the whole run:
    those whose arch id is not in ``arch_id_map``, those whose architecture
    has no entry in *arch_nr*, and those lacking an ``arch=`` or
    ``syscall=`` field.
    """
    # Values of the record's "arch=" field mapped to the architecture names
    # used as keys of arch_nr.
    arch_id_map = {
        '40000028': 'arm',
        'c00000b7': 'arm64',
        'c00000f3': 'riscv64'
    }
    for line in fp:
        line = line.strip()
        if 'audit' not in line or 'type=1326' not in line:
            continue

        # Split the record into key=value tokens so the fields are found
        # regardless of their position or of which optional fields (such as
        # "compat=") are present. The first occurrence of a key wins.
        fields = {}
        for token in line.split():
            key, sep, value = token.partition('=')
            if sep and key not in fields:
                fields[key] = value

        arch = arch_id_map.get(fields.get('arch'))
        nr_text = fields.get('syscall')
        if arch not in arch_nr or not nr_text:
            print('skip unsupported audit record: {}'.format(line))
            continue

        syscall, _ = gen_policy.str_convert_to_int(nr_text)
        arch_nr[arch].add(syscall)
41
42
def get_item_content(name_nr_table, arch_nr_table):
    """Render the observed syscalls as the text of an '@allowList' policy.

    Args:
        name_nr_table: ``{arch: {nr: syscall_name}}`` lookup tables.
        arch_nr_table: ``{arch: iterable of syscall nrs}`` observed per arch.

    Returns:
        The policy text: a '@allowList' header followed by one
        '<name>;<tag>' line per syscall, in this order:
        - arm64 names, in nr order, tagged 'all' when the same name was also
          observed on arm and 'arm64' otherwise;
        - the remaining arm-only names, tagged 'arm';
        - the riscv64 names, tagged 'riscv64'.
        Within each group, names follow nr order.

    Raises:
        ValueError: if an observed nr has no name in *name_nr_table*.
    """
    syscall_name_dict = {}
    for arch in ('arm64', 'arm', 'riscv64'):
        # A missing table is treated as empty so the error below names the
        # offending nr/arch instead of failing with AttributeError.
        nr_to_name = name_nr_table.get(arch) or {}
        names = []
        for nr in sorted(arch_nr_table.get(arch) or ()):
            syscall_name = nr_to_name.get(nr)
            if not syscall_name:
                raise ValueError('syscall nr {} is not defined for {}'.format(nr, arch))
            names.append(syscall_name)
        syscall_name_dict[arch] = names

    # Set lookups keep the arm/arm64 merge linear instead of quadratic.
    arm_names = set(syscall_name_dict['arm'])
    arm64_names = set(syscall_name_dict['arm64'])

    # NOTE(review): names shared by arm and arm64 are tagged 'all'; confirm
    # how the policy consumer interprets 'all' (e.g. whether it also covers
    # riscv64).
    lines = ['@allowList']
    for func_name in syscall_name_dict['arm64']:
        tag = 'all' if func_name in arm_names else 'arm64'
        lines.append('{};{}'.format(func_name, tag))
    lines.extend('{};arm'.format(name) for name in syscall_name_dict['arm']
                 if name not in arm64_names)
    lines.extend('{};riscv64'.format(name) for name in syscall_name_dict['riscv64'])
    return '\n'.join(lines) + '\n'
72
73
def gen_output_file(filter_name, content):
    """Write *content* to the file ``<filter_name>.seccomp.policy``.

    The file is created if missing and truncated if present. On creation it
    is requested with read/write permission for owner and group only,
    subject to the process umask.
    """
    policy_path = filter_name + '.seccomp.policy'
    open_flags = os.O_CREAT | os.O_TRUNC | os.O_WRONLY
    file_mode = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP
    descriptor = os.open(policy_path, open_flags, file_mode)
    with os.fdopen(descriptor, 'w') as policy_file:
        policy_file.write(content)
79
80
def parse_file(file_name, arch_nr):
    """Feed every line of the audit log *file_name* to parse_line.

    The log is decoded as UTF-8, and undecodable bytes are replaced. As a
    result, the outcome does not depend on the host locale, and a stray
    non-UTF-8 byte in the log cannot abort the whole run.

    Args:
        file_name: Path of the audit log to read.
        arch_nr: Per-architecture syscall-number sets, updated in place.
    """
    with open(file_name, encoding='utf-8', errors='replace') as log_file:
        parse_line(log_file, arch_nr)
84
85
def converse_fuction_name_nr(dict_dst, dict_src):
    """Invert each per-architecture ``{name: nr}`` table into ``{nr: name}``.

    For every architecture key in *dict_src*, ``dict_dst[arch]`` is replaced
    by the inverted table. If two names share a nr, the one iterated last
    wins. *dict_dst* is modified in place and also returned.
    """
    for arch, name_to_nr in dict_src.items():
        dict_dst[arch] = {nr: name for name, nr in name_to_nr.items()}
    return dict_dst
94
95
def parse_audit_log_to_policy(args):
    """Build a seccomp allow-list policy from audit logs.

    Steps:
    1. Gather the files found in ``args.src_path``.
    2. Load the name->nr tables from every ``libsyscall_to_nr_*`` file
       (case-insensitive) via ``gen_policy.gen_syscall_nr_table``.
    3. Invert those tables to nr->name.
    4. Collect the syscall numbers from every ``*.audit.log`` file.
    5. Write the rendered policy to ``<args.filter_name>.seccomp.policy``.
    """
    file_list = extract_file_from_path(args.src_path)

    # Pass 1: accumulate the name->nr tables from the syscall table files.
    name_to_nr_by_arch = {}
    for path in file_list:
        base_name = path.rpartition('/')[2]
        if base_name.lower().startswith('libsyscall_to_nr_'):
            name_to_nr_by_arch = gen_policy.gen_syscall_nr_table(path, name_to_nr_by_arch)

    nr_to_name_by_arch = converse_fuction_name_nr({}, name_to_nr_by_arch)

    # Pass 2: collect the syscall numbers seen in the audit logs.
    arch_nr = {arch: set() for arch in ('arm', 'arm64', 'riscv64')}
    audit_logs = [path for path in file_list if path.lower().endswith('.audit.log')]
    for path in audit_logs:
        parse_file(path, arch_nr)

    policy_text = get_item_content(nr_to_name_by_arch, arch_nr)
    gen_output_file(args.filter_name, policy_text)
119
120
def extract_file_from_path(dir_path):
    """Return the regular files found directly inside each directory given.

    Args:
        dir_path: Iterable of directory paths (the repeated --src-path values).

    Returns:
        A list of '<dir>/<entry>' paths for regular files. Entries of each
        directory are sorted by name, so the processing order is
        reproducible. Paths that are not existing directories are skipped.
        If any path ends with '/', an error message is printed and an empty
        list is returned.
    """
    file_list = []
    for path in dir_path:
        # endswith() rather than path[-1], so an empty string does not raise
        # IndexError.
        if path.endswith('/'):
            print('input dir path can not end with /')
            return []

        if not os.path.isdir(path):
            continue

        # os.listdir order is arbitrary; sort for deterministic output.
        for item in sorted(os.listdir(path)):
            full_path = '{}/{}'.format(path, item)
            # Sub-directories are not inputs; skip them so later stages only
            # ever open files.
            if os.path.isfile(full_path):
                file_list.append(full_path)

    return file_list
134
135
def main():
    """Command-line entry point: parse the arguments and write the policy.

    Both options are required. Without them, the run would otherwise fail
    later with an unhelpful TypeError (iterating None / concatenating None).
    argparse exits with status 2 and a usage message when either is missing.
    """
    parser = argparse.ArgumentParser(
        description='Generates a seccomp-bpf policy')
    parser.add_argument('--src-path', action='append', required=True,
                        help='directory containing libsyscall_to_nr_* tables '
                             'and *.audit.log files; may be repeated')
    parser.add_argument('--filter-name', type=str, required=True,
                        help='output policy name; the result is written to '
                             '<filter-name>.seccomp.policy')

    args = parser.parse_args()
    parse_audit_log_to_policy(args)
147
148
if __name__ == '__main__':
    # Run the CLI; main()'s return value (None) is handed to sys.exit, so a
    # normal run exits with status 0.
    exit_status = main()
    sys.exit(exit_status)
151