1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2023 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import sys
20import argparse
21import audit_log_analysis as audit_policy
22import generate_code_from_policy as gen_policy
23
24
class MergePolicy:
    """Merge several seccomp ``.policy`` files into a single policy text.

    Syscall-number tables (``libsyscall_to_nr_*`` files) are loaded first so that
    entries can be ordered by syscall number. Policy files are then parsed into one
    ``gen_policy.SeccompPolicyParam`` per supported architecture. Each section is
    rendered as '@header' followed by 'entry;arches' lines, and the merged text is
    handed to ``audit_policy.gen_output_file``.
    """

    # Order in which architectures are visited when emitting entries. Architectures
    # in gen_policy.supported_architecture that are not listed here come afterwards.
    _ARCH_OUTPUT_ORDER = ('arm64', 'arm', 'riscv64')

    def __init__(self):
        # Section currently being parsed (e.g. 'allowList'); '' until the first
        # recognised '@' header has been seen.
        self.cur_parse_item = ''
        # Architectures used for entry lines that carry no ';arch' suffix. Nothing
        # in this class adds to it, so such lines are ignored unless a caller fills it.
        self.arches = set()
        # arch -> gen_policy.SeccompPolicyParam, filled in by merge_policy().
        self.seccomp_policy_param = dict()

    @staticmethod
    def get_item_content(name_nr_table, item_str, itme_dict):
        """Render one policy section, merging entries shared between architectures.

        Args:
            name_nr_table: arch -> {syscall name: syscall number}. It is used only
                to order entries within each architecture.
            item_str: section header written on the first line, e.g. '@allowList'.
            itme_dict: arch -> iterable of entries, each 'name' or 'name:args'.

        Returns:
            '' when no architecture has any entry. Otherwise returns the header
            followed by one 'entry;arches' line per distinct entry, each ending in
            '\\n'. ``arches`` is 'all' only when the entry exists for every
            supported architecture. Otherwise it is the comma-separated list of
            architectures that have the entry, which is the syntax parse_line()
            reads back (there, 'all' expands to every supported architecture).
        """
        supported = list(gen_policy.supported_architecture)
        ordered_arches = [arch for arch in MergePolicy._ARCH_OUTPUT_ORDER if arch in supported]
        ordered_arches += [arch for arch in supported if arch not in ordered_arches]

        # entry -> architectures containing it; the dict keeps first-seen order, so
        # output follows arm64 syscall order, then entries first seen on later arches.
        entry_arches = {}
        for arch in ordered_arches:
            nr_table = name_nr_table.get(arch) or {}

            def nr_sort_key(entry, table=nr_table):
                name = entry[:entry.find(':')].strip() if ':' in entry else entry
                nr = table.get(name)
                # Names missing from the table sort last instead of raising a
                # TypeError from comparing None with int.
                return (nr is None, 0 if nr is None else nr)

            # dict.fromkeys drops duplicate entries while keeping their order.
            for entry in sorted(dict.fromkeys(itme_dict.get(arch) or ()), key=nr_sort_key):
                entry_arches.setdefault(entry, []).append(arch)

        if not entry_arches:
            return ''

        all_arches = set(supported)
        lines = [item_str]
        for entry, arches in entry_arches.items():
            tag = 'all' if set(arches) == all_arches else ','.join(arches)
            lines.append('{};{}'.format(entry, tag))
        return '\n'.join(lines) + '\n'

    def update_parse_item(self, line):
        """Switch the current section when ``line`` ('@name') names a known one.

        Headers not in ``gen_policy.supported_parse_item`` are ignored and leave
        the previously active section unchanged.
        """
        item = line[1:]
        if item in gen_policy.supported_parse_item:
            self.cur_parse_item = item
            print('start deal with {}'.format(self.cur_parse_item))

    def parse_line(self, line):
        """Feed one entry line of the current section to the per-arch parameters.

        ``line`` is 'entry' or 'entry;arch[,arch...]'; all spaces are removed first.
        An arch list starting with 'all' selects every supported architecture.
        Lines without ';' go to the architectures in ``self.arches``.

        Raises:
            ValueError: if a named architecture has no SeccompPolicyParam. The
                check runs before any architecture is updated.
        """
        if not self.cur_parse_item:
            return
        line = line.replace(' ', '')
        pos = line.rfind(';')
        if pos < 0:
            entry, arches = line, self.arches
        else:
            entry = line[:pos]
            arches = line[pos + 1:].split(',')
            if arches[0] == 'all':
                arches = gen_policy.supported_architecture

        params = []
        for arch in arches:
            param = self.seccomp_policy_param.get(arch)
            if param is None:
                raise ValueError('unsupported architecture {!r} in policy line {!r}'.format(arch, line))
            params.append(param)
        for param in params:
            param.value_function.get(self.cur_parse_item)(entry)

    def parse_open_file(self, fp):
        """Parse policy text from an iterable of lines, such as an open file.

        Blank lines and '#' comments are skipped. '@' lines switch the current
        section. All other lines are entries of the current section and are
        ignored until a recognised section header has been seen.
        """
        for raw_line in fp:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            if line.startswith('@'):
                self.update_parse_item(line)
            elif self.cur_parse_item:
                self.parse_line(line)

    def parse_file(self, file_path):
        """Parse the policy file at ``file_path``, decoded as UTF-8."""
        # An explicit encoding keeps the result independent of the build host's locale.
        with open(file_path, encoding='utf-8') as fp:
            self.parse_open_file(fp)

    def merge_policy(self, args):
        """Load the inputs named in ``args`` and emit the merged policy.

        ``args.src_files`` lists ``libsyscall_to_nr_*`` tables (matched on the
        basename, case-insensitively) and ``*.policy`` files. All tables are loaded
        before any policy file is parsed, because parsing needs the per-arch
        SeccompPolicyParam objects built from them. The merged text is passed to
        ``audit_policy.gen_output_file(args.filter_name, content)``.
        """
        # argparse leaves an unused 'append' option as None.
        src_files = args.src_files or []

        function_name_nr_table_dict = {}
        for file_name in src_files:
            base_name = file_name.split('/')[-1]
            if base_name.lower().startswith('libsyscall_to_nr_'):
                gen_policy.gen_syscall_nr_table(file_name, function_name_nr_table_dict)

        for arch in gen_policy.supported_architecture:
            self.seccomp_policy_param[arch] = gen_policy.SeccompPolicyParam(
                arch, function_name_nr_table_dict.get(arch))

        for file_name in src_files:
            if file_name.lower().endswith('.policy'):
                self.parse_file(file_name)

        # (section header, SeccompPolicyParam attribute), in output order.
        sections = (
            ('@priority', 'priority'),
            ('@allowList', 'allow_list'),
            ('@priorityWithArgs', 'priority_with_args'),
            ('@allowListWithArgs', 'allow_list_with_args'),
            ('@blockList', 'blocklist'),
        )
        parts = []
        for header, attr_name in sections:
            per_arch = {arch: getattr(self.seccomp_policy_param.get(arch), attr_name)
                        for arch in gen_policy.supported_architecture}
            parts.append(self.get_item_content(function_name_nr_table_dict, header, per_arch))
        audit_policy.gen_output_file(args.filter_name, ''.join(parts))
146
147
def main():
    """Command-line entry point: merge syscall tables and policy files.

    ``--src-files`` may be repeated. Each value is either a ``libsyscall_to_nr_*``
    syscall-number table or a ``*.policy`` file. ``--filter-name`` is passed
    through to MergePolicy.merge_policy() as ``args.filter_name``.
    """
    parser = argparse.ArgumentParser(
        description='Generates a seccomp-bpf policy')
    # merge_policy() has nothing to merge without inputs, so a missing option
    # should fail with a usage error rather than a traceback later on.
    parser.add_argument('--src-files', type=str, action='append', required=True,
                        help='input libsyscall_to_nr files and policy files')

    parser.add_argument('--filter-name', type=str,
                        help='Name of seccomp bpf array generated by this script')

    args = parser.parse_args()

    generator = MergePolicy()
    generator.merge_policy(args)
161
162
if __name__ == '__main__':
    # Run the CLI and forward main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
165