1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2023 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19
20import argparse
21import sys
22import generate_code_from_policy as gen_policy
23
24
class LibcFuncUnit:
    """One libc function together with the syscall numbers attributed to it.

    Attributes:
        arch: architecture key the unit belongs to (used to pick a table in
            ``print_info``).
        addr: address text of the function.
        func_name: name of the function.
        nr: set of syscall numbers collected for the function.
        use_function: set of callee names recorded via
            ``update_use_function``.
    """

    def __init__(self, arch, addr, func_name, nr):
        # Copy the numbers so the caller can keep reusing (or clearing) its
        # own collection without affecting this unit.
        self.nr = set(nr)
        self.func_name = func_name
        self.addr = addr
        self.use_function = set()
        self.arch = arch

    def merge_nr(self, nr):
        """Add every number in ``nr`` to this unit's syscall set."""
        self.nr |= nr

    def update_func_name(self, func_name):
        """Replace the stored function name."""
        self.func_name = func_name

    def update_addr(self, addr):
        """Replace the stored address."""
        self.addr = addr

    def update_use_function(self, new_function):
        """Record ``new_function`` in the set of used functions."""
        self.use_function.add(new_function)

    def print_info(self, name_nr_table_dict):
        """Print the address, name, syscall labels and used functions.

        Args:
            name_nr_table_dict: mapping of arch -> {label: number}. Each
                number in ``self.nr`` is printed as the first label that maps
                to it; a number without a label (or an arch without a table)
                is printed as the raw number instead of raising.
        """
        table = name_nr_table_dict.get(self.arch) or {}
        # Build the reverse lookup once; setdefault keeps the first label for
        # a duplicated number, matching a first-match linear search.
        label_by_nr = {}
        for label, number in table.items():
            label_by_nr.setdefault(number, label)
        nrs = [label_by_nr.get(nr_item, nr_item) for nr_item in self.nr]
        print('{}\t{}\t{} use function is {}'.format(self.addr, self.func_name, nrs, self.use_function))
51
52
def remove_head_zero(addr):
    """Return ``addr`` without its leading ``'0'`` characters.

    A string made only of zeros (or an empty string) yields ``''``.
    """
    return addr.lstrip('0')
60
61
def line_find_syscall_nr(line, nr_set, nr_last):
    """Parse the immediate of one mov-style line and record it in ``nr_set``.

    The text handed to ``gen_policy.str_convert_to_int`` starts at the first
    ``0x`` when the line contains ``;``, otherwise right after the last ``#``.

    Args:
        line: one stripped disassembly line.
        nr_set: set that receives a completed number (mutated in place).
        nr_last: value returned by a preceding ``movt`` line, added to the
            immediate of a ``movw`` line.

    Returns:
        A ``(nr, is_find_nr, is_find_svc)`` tuple:
        * conversion failed: ``(converted_value, False, False)``;
        * ``movt`` line: ``(immediate * 65536, False, True)`` and nothing is
          added to ``nr_set``;
        * any other line: the number is added to ``nr_set`` and
          ``(0, True, False)`` is returned.
    """
    if ';' in line:
        operand_text = line[line.find('0x'):]
    else:
        operand_text = line[line.rfind('#') + 1:]
    value, converted = gen_policy.str_convert_to_int(operand_text)

    if not converted:
        return value, False, False

    if 'movt' in line:
        # Upper 16 bits only: hand the partial value back and keep looking.
        return value * 256 * 256, False, True

    if 'movw' in line:
        value = nr_last + value
    nr_set.add(value)
    return 0, True, False
89
90
def get_direct_use_syscall_of_svc(arch, lines, func_list):
    """Collect functions that load an immediate syscall number before ``svc``.

    ``lines`` is scanned bottom-up. After an ``svc`` line, the next ``mov``
    line that names the architecture's register is handed to
    ``line_find_syscall_nr``. When a line ending in ``:`` (a function label)
    is reached, every number gathered since the previous label is recorded as
    one ``LibcFuncUnit`` appended to ``func_list``.

    Args:
        arch: ``'arm'``, ``'arm64'`` or ``'riscv64'``. Any other value has no
            register to match, so nothing is recorded.
        lines: disassembly text lines, in file order.
        func_list: list extended in place with the functions found.
    """
    # NOTE(review): the riscv64 entry matches 'x5,' after an 'svc' line;
    # syscall(2) documents 'ecall' with the number in a7 for riscv - confirm
    # against the disassembly this tool is fed.
    regs_by_arch = {
        'arm': ('r7,', 'r7, '),
        'arm64': ('x8,', 'w8,'),
        'riscv64': ('x5,', 'x5,'),
    }
    svc_regs = regs_by_arch.get(arch)
    if svc_regs is None:
        return

    is_find_svc = False
    nr_set = set()
    nr = 0
    for line in reversed(lines):
        line = line.strip()
        if not line:
            is_find_svc = False
            continue

        if line[-1] == ':':
            # Record whatever was gathered for this function, then always
            # drop it so it can never be attributed to another function.
            if nr_set:
                addr = remove_head_zero(line[:line.find(' ')])
                func_name = line[line.find('<') + 1: line.rfind('>')]
                func_list.append(LibcFuncUnit(arch, addr, func_name, nr_set))
            nr_set.clear()
            is_find_svc = False
            continue

        if not is_find_svc and ('svc\t' in line or 'svc ' in line):
            is_find_svc = True
            # A fresh svc must not inherit a half-built movt value.
            nr = 0
            continue

        if is_find_svc and 'mov' in line and any(reg in line for reg in svc_regs):
            nr, _, is_find_svc = line_find_syscall_nr(line, nr_set, nr)
129
130
def get_direct_use_syscall_of_syscall(arch, lines, func_list):
    """Collect functions that pass an immediate to ``syscall``/``__syscall_cp``.

    ``lines`` is scanned bottom-up. After a line referencing ``<syscall>`` or
    ``<__syscall_cp>``, the next ``mov`` line that names the architecture's
    register and carries a convertible immediate contributes that number.
    When a line ending in ``:`` (a function label) is reached, the numbers
    gathered since the previous label are merged into the entry of
    ``func_list`` with the same address, or appended as a new
    ``LibcFuncUnit``.

    Args:
        arch: ``'arm'``, ``'arm64'`` or ``'riscv64'``. Any other value has no
            register to match, so nothing is recorded.
        lines: disassembly text lines, in file order.
        func_list: list of ``LibcFuncUnit`` updated in place.
    """
    # NOTE(review): the riscv64 entry matches 'x17,' - confirm it is the
    # register carrying the first argument in the disassembly in use.
    regs_by_arch = {
        'arm': ('r0,', 'r0,'),
        'arm64': ('x0,', 'w0,'),
        'riscv64': ('x17,', 'x17,'),
    }
    syscall_regs = regs_by_arch.get(arch)
    if syscall_regs is None:
        return

    # Address -> position in func_list; the first entry wins on duplicates.
    index_by_addr = {}
    for position, func in enumerate(func_list):
        index_by_addr.setdefault(func.addr, position)

    is_find_syscall = False
    nr_tmp = set()
    for line in reversed(lines):
        line = line.strip()
        if not line:
            is_find_syscall = False
            continue

        # Labels are handled before the call-site test so that the label of
        # the syscall function itself ("<syscall>:") is not taken for a call.
        if line[-1] == ':':
            if nr_tmp:
                addr = remove_head_zero(line[:line.find(' ')])
                func_name = line[line.find('<') + 1: line.rfind('>')]
                position = index_by_addr.get(addr)
                if position is None:
                    func_list.append(LibcFuncUnit(arch, addr, func_name, nr_tmp))
                    index_by_addr[addr] = len(func_list) - 1
                else:
                    func_list[position].merge_nr(nr_tmp)
            # Always drop the numbers so they cannot reach another function.
            nr_tmp.clear()
            is_find_syscall = False
            continue

        if not is_find_syscall and ('<syscall>' in line or '<__syscall_cp>' in line):
            is_find_syscall = True
            continue

        if is_find_syscall and 'mov' in line and any(reg in line for reg in syscall_regs):
            if ';' in line:
                nr, is_digit = gen_policy.str_convert_to_int(line[line.find('0x'):])
            else:
                nr, is_digit = gen_policy.str_convert_to_int(line[line.rfind('#') + 1:])
            if is_digit:
                nr_tmp.add(nr)
                is_find_syscall = False
            # A non-immediate mov keeps the search going on earlier lines.
183
184
def get_direct_use_syscall(arch, lines):
    """Return the functions that use a syscall number directly.

    Runs the ``svc`` scan first and the ``syscall`` wrapper scan second over
    the same ``lines``; both fill one shared list, which is returned.
    """
    collected = []
    for scan in (get_direct_use_syscall_of_svc, get_direct_use_syscall_of_syscall):
        scan(arch, lines, collected)
    return collected
191
192
def get_call_graph(arch, lines, func_list):
    """Propagate syscall numbers one step from callees to their callers.

    ``lines`` is read top-down. A line containing ``<`` and ``>:`` starts a
    function (the caller); a blank line ends it. Inside a function, a
    tab-separated line whose third field contains ``b`` and whose fourth
    field contains ``<`` is treated as a branch. If the text before the first
    space of the fourth field is the address of a function in ``func_list``,
    that function's numbers and name are merged into the caller's entry. A
    caller not yet in ``func_list`` is appended.

    Args:
        arch: architecture key stored on newly created units.
        lines: disassembly text lines, in file order.
        func_list: list of ``LibcFuncUnit`` updated in place.
    """
    # Address -> position in func_list, so each branch line costs one dict
    # lookup instead of a linear search; the first entry wins on duplicates.
    index_by_addr = {}
    for position, func in enumerate(func_list):
        index_by_addr.setdefault(func.addr, position)

    is_find_function = False
    caller_addr = None
    caller_func_name = None
    for line in lines:
        line = line.strip()
        if not line:
            is_find_function = False
            continue

        if not is_find_function:
            if '<' in line and '>:' in line:
                is_find_function = True
                caller_addr = remove_head_zero(line[:line.find(' ')])
                caller_func_name = line[line.find('<') + 1: line.rfind('>')]
            continue

        line_info = line.split('\t')
        if len(line_info) < 4:
            continue
        mnemonic, operand = line_info[2], line_info[3]
        if 'b' not in mnemonic or '<' not in operand:
            continue

        callee_index = index_by_addr.get(operand[:operand.find(' ')])
        if callee_index is None:
            continue
        callee = func_list[callee_index]

        caller_index = index_by_addr.get(caller_addr)
        if caller_index is None:
            func_list.append(LibcFuncUnit(arch, caller_addr, caller_func_name, callee.nr))
            index_by_addr[caller_addr] = len(func_list) - 1
            caller = func_list[-1]
        else:
            caller = func_list[caller_index]
            caller.merge_nr(callee.nr)
        caller.update_use_function(callee.func_name)
231
232
def parse_file(arch, file_name):
    """Read a disassembly file and return its functions with syscall numbers.

    Direct users are collected first; ``get_call_graph`` is then repeated
    until a pass changes nothing. Progress is measured by both the number of
    functions and the total count of syscall numbers: a pass can add numbers
    to functions already in the list without adding a new function, and
    those numbers still have to reach their callers in a later pass.

    Args:
        arch: architecture key forwarded to the scanners.
        file_name: path of the text file to read.

    Returns:
        The list of ``LibcFuncUnit`` built from the file.
    """
    with open(file_name) as fp:
        lines = fp.readlines()

    func_list = get_direct_use_syscall(arch, lines)

    def _progress():
        # Both figures only ever grow, so an unchanged pair is a fixed point.
        return len(func_list), sum(len(func.nr) for func in func_list)

    previous = None
    current = _progress()
    while current != previous:
        previous = current
        get_call_graph(arch, lines, func_list)
        current = _progress()

    return func_list
246
247
def get_syscall_map(arch, src_syscall_path, libc_path):
    """Load the syscall tables and parse the libc disassembly.

    Args:
        arch: architecture key forwarded to ``parse_file``.
        src_syscall_path: iterable of file paths, or ``None``. Only files
            whose base name starts with ``libsyscall_to_nr_`` (compared in
            lower case) are passed to ``gen_policy.gen_syscall_nr_table``.
        libc_path: path of the disassembly, or ``None``. It is parsed only
            when its name ends with ``libc.asm`` (compared in lower case).

    Returns:
        The list produced by ``parse_file``, or an empty list when
        ``libc_path`` is missing or has a different name.
    """
    # The table is filled here but not consulted by this function.
    function_name_nr_table_dict = {}
    # argparse hands over None when the option was never given.
    for file_name in src_syscall_path or ():
        base_name = file_name.split('/')[-1]
        if not base_name.lower().startswith('libsyscall_to_nr_'):
            continue
        gen_policy.gen_syscall_nr_table(file_name, function_name_nr_table_dict)

    if libc_path and libc_path.lower().endswith('libc.asm'):
        return parse_file(arch, libc_path)
    return []
259
260
def main():
    """Parse the command line and run ``get_syscall_map``.

    All three options are mandatory: the code below them cannot work with a
    missing value, so argparse reports the omission up front.
    """
    parser = argparse.ArgumentParser(
        description='Maps libc functions to the syscall numbers they use')
    parser.add_argument('--src-syscall-path', type=str, action='append',
                        required=True,
                        help='Path of a libsyscall_to_nr_* table file; '
                             'may be given several times')
    parser.add_argument('--libc-asm-path', type=str, required=True,
                        help='Path of the libc disassembly; only names '
                             'ending in libc.asm are parsed')
    parser.add_argument('--target-cpu', type=str, required=True,
                        help='Target architecture: arm, arm64 or riscv64')

    args = parser.parse_args()
    get_syscall_map(args.target_cpu, args.src_syscall_path, args.libc_asm_path)
273
274
# Script entry point: main() returns None, so sys.exit() reports status 0
# unless argparse or an uncaught exception ends the run earlier.
if __name__ == '__main__':
    sys.exit(main())
277