#!/usr/bin/env python3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2023 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Find the syscall numbers used by each function of a disassembly listing.

The listing is scanned for ``svc`` instructions and for calls to ``syscall`` /
``__syscall_cp``; the numbers found are then propagated along branch
instructions from callee to caller until nothing changes any more.
"""

import argparse
import sys
import generate_code_from_policy as gen_policy


# Register operands looked for in the ``mov`` line that precedes an ``svc``
# instruction, keyed by target cpu.
# NOTE(review): the riscv64 entries are carried over unchanged; confirm that
# they match the listings being analysed (only lines containing ``svc`` and
# ``mov`` are ever inspected).
_SVC_NR_REGS = {
    'arm': ('r7,',),
    'arm64': ('x8,', 'w8,'),
    'riscv64': ('x5,',),
}

# Register operands looked for in the ``mov`` line that precedes a call to
# ``syscall`` / ``__syscall_cp``, keyed by target cpu.
_SYSCALL_NR_REGS = {
    'arm': ('r0,',),
    'arm64': ('x0,', 'w0,'),
    'riscv64': ('x17,',),
}


class LibcFuncUnit:
    """Syscall usage collected for one function of the listing."""

    def __init__(self, arch, addr, func_name, nr):
        # Copy the numbers: callers reuse and clear the set they pass in.
        self.nr = set(nr)
        self.func_name = func_name
        self.addr = addr
        # Names of the callees whose syscall numbers were merged into ``nr``.
        self.use_function = set()
        self.arch = arch

    def merge_nr(self, nr):
        """Add every syscall number of the iterable ``nr`` to this function."""
        self.nr.update(nr)

    def update_func_name(self, func_name):
        """Replace the recorded function name."""
        self.func_name = func_name

    def update_addr(self, addr):
        """Replace the recorded function address."""
        self.addr = addr

    def update_use_function(self, new_function):
        """Record ``new_function`` as a callee this function inherits from."""
        self.use_function.add(new_function)

    def print_info(self, name_nr_table_dict):
        """Print address, name, syscall names and callees of this function.

        ``name_nr_table_dict[self.arch]`` is used as a name -> number mapping
        and searched in reverse. A number that is missing from that mapping
        raises ``ValueError``.
        """
        keys = list(name_nr_table_dict.get(self.arch).keys())
        values = list(name_nr_table_dict.get(self.arch).values())
        nrs = [keys[values.index(nr_item)] for nr_item in self.nr]
        print('{}\t{}\t{} use function is {}'.format(self.addr, self.func_name, nrs, self.use_function))


def remove_head_zero(addr):
    """Return ``addr`` without its leading ``'0'`` characters."""
    return addr.lstrip('0')


def _nr_registers(table, arch):
    """Return the register operands of ``arch`` or raise ``ValueError``."""
    try:
        return table[arch]
    except KeyError:
        raise ValueError('unsupported target cpu: {!r}'.format(arch)) from None


def _parse_immediate(line):
    """Return ``(value, is_digit)`` for the immediate operand of ``line``.

    When the line carries a ``;`` the text starting at ``0x`` is converted,
    otherwise the text after the last ``#``. The conversion itself is done by
    ``gen_policy.str_convert_to_int``.
    """
    if ';' in line:
        text = line[line.find('0x'):]
    else:
        text = line[line.rfind('#') + 1:]
    return gen_policy.str_convert_to_int(text)


def _parse_label(line):
    """Return ``(addr, func_name)`` of a ``<addr> <name>:`` label line."""
    addr = remove_head_zero(line[:line.find(' ')])
    func_name = line[line.find('<') + 1: line.rfind('>')]
    return addr, func_name


def _index_by_addr(func_list):
    """Map each address to the position of its first entry in ``func_list``."""
    index = {}
    for pos, func in enumerate(func_list):
        index.setdefault(func.addr, pos)
    return index


def line_find_syscall_nr(line, nr_set, nr_last):
    """Extract the syscall number loaded by one ``mov`` line.

    The listing is walked bottom-up, so a ``movt`` (upper 16 bits) is met
    before the ``movw`` (lower 16 bits) that completes it.

    Args:
        line: the stripped ``mov`` line.
        nr_set: set that receives a completed syscall number.
        nr_last: upper half remembered from a previously seen ``movt``.

    Returns:
        ``(nr, is_find_nr, is_find_svc)``: the value to pass as ``nr_last``
        next time, whether a number was added to ``nr_set``, and whether the
        search for the ``mov`` belonging to the current ``svc`` goes on.
    """
    nr_tmp, is_digit = _parse_immediate(line)
    if not is_digit:
        # The register is loaded from something that is not an immediate.
        return nr_tmp, False, False

    if 'movt' in line:
        # Only the upper half so far; keep looking for the matching movw.
        return nr_tmp * 256 * 256, False, True

    if 'movw' in line:
        nr_tmp += nr_last
    nr_set.add(nr_tmp)
    return 0, True, False


def get_direct_use_syscall_of_svc(arch, lines, func_list):
    """Append a ``LibcFuncUnit`` for every function that issues ``svc``.

    The lines are walked bottom-up: an ``svc`` starts the search for the
    ``mov`` that loads the syscall register, and the next line ending in
    ``:`` is taken as the label of the enclosing function.

    Args:
        arch: target cpu, one of the keys of ``_SVC_NR_REGS``.
        lines: lines of the listing.
        func_list: list that receives the new ``LibcFuncUnit`` objects.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    regs = _nr_registers(_SVC_NR_REGS, arch)
    is_find_svc = False
    nr_set = set()
    nr = 0

    for line in reversed(lines):
        line = line.strip()
        if not line:
            # A blank line ends the current function: drop all pending state
            # so that nothing found here is attributed to another function.
            is_find_svc = False
            nr_set.clear()
            nr = 0
            continue

        if not is_find_svc and ('svc\t' in line or 'svc ' in line):
            is_find_svc = True
            continue

        if is_find_svc and 'mov' in line and any(reg in line for reg in regs):
            nr, _, is_find_svc = line_find_syscall_nr(line, nr_set, nr)
            continue

        # Gate on the collected numbers rather than on the outcome of the last
        # mov: a function with several svc keeps what was already found even
        # if one of them is loaded from a non-immediate.
        if nr_set and line.endswith(':'):
            addr, func_name = _parse_label(line)
            func_list.append(LibcFuncUnit(arch, addr, func_name, nr_set))
            nr_set.clear()
            is_find_svc = False
            nr = 0


def get_direct_use_syscall_of_syscall(arch, lines, func_list):
    """Record the numbers passed to ``syscall`` / ``__syscall_cp`` calls.

    The lines are walked bottom-up: a line mentioning ``<syscall>`` or
    ``<__syscall_cp>`` starts the search for the ``mov`` that loads the
    number register, and the next line ending in ``:`` is taken as the label
    of the enclosing function. Numbers are merged into the entry with the
    same address when there is one, otherwise a new entry is appended.

    Args:
        arch: target cpu, one of the keys of ``_SYSCALL_NR_REGS``.
        lines: lines of the listing.
        func_list: list of ``LibcFuncUnit`` objects, updated in place.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    regs = _nr_registers(_SYSCALL_NR_REGS, arch)
    index_by_addr = _index_by_addr(func_list)
    is_find_syscall = False
    nr_tmp = set()

    for line in reversed(lines):
        line = line.strip()
        if not line:
            # A blank line ends the current function: drop all pending state.
            is_find_syscall = False
            nr_tmp.clear()
            continue

        if not is_find_syscall and ('<syscall>' in line or '<__syscall_cp>' in line):
            is_find_syscall = True
            continue

        if is_find_syscall and 'mov' in line and any(reg in line for reg in regs):
            nr, is_digit = _parse_immediate(line)
            if is_digit:
                nr_tmp.add(nr)
            # The last write to the register before the call decides the
            # argument, so stop looking whether or not it was an immediate.
            is_find_syscall = False
            continue

        if nr_tmp and line.endswith(':'):
            addr, func_name = _parse_label(line)
            pos = index_by_addr.get(addr)
            if pos is None:
                index_by_addr[addr] = len(func_list)
                func_list.append(LibcFuncUnit(arch, addr, func_name, nr_tmp))
            else:
                func_list[pos].merge_nr(nr_tmp)

            nr_tmp.clear()
            is_find_syscall = False


def get_direct_use_syscall(arch, lines):
    """Return the functions of ``lines`` that use a syscall directly."""
    func_list = []
    get_direct_use_syscall_of_svc(arch, lines, func_list)
    get_direct_use_syscall_of_syscall(arch, lines, func_list)

    return func_list


def get_call_graph(arch, lines, func_list):
    """Propagate syscall numbers one step from callees to their callers.

    For every tab-separated instruction line whose mnemonic field contains
    ``b`` and whose operand field contains ``<``, the operand up to the first
    space is looked up as a callee address. When the callee is known, its
    numbers and its name are merged into the enclosing function, which is
    appended to ``func_list`` if it has no entry yet.

    Args:
        arch: target cpu stored in newly created entries.
        lines: lines of the listing.
        func_list: list of ``LibcFuncUnit`` objects, updated in place.
    """
    index_by_addr = _index_by_addr(func_list)
    is_find_function = False
    caller_addr = None
    caller_func_name = None

    for line in lines:
        line = line.strip()
        if not line:
            is_find_function = False
            continue

        if not is_find_function and '<' in line and '>:' in line:
            is_find_function = True
            caller_addr, caller_func_name = _parse_label(line)
            continue

        if not is_find_function:
            continue

        line_info = line.split('\t')
        if len(line_info) < 4:
            continue

        mnemonic, operand = line_info[2], line_info[3]
        if not ('b' in mnemonic and '<' in operand):
            continue

        callee_pos = index_by_addr.get(operand[:operand.find(' ')])
        if callee_pos is None:
            continue
        callee = func_list[callee_pos]

        caller_pos = index_by_addr.get(caller_addr)
        if caller_pos is None:
            index_by_addr[caller_addr] = len(func_list)
            func_list.append(LibcFuncUnit(arch, caller_addr, caller_func_name, callee.nr))
            caller = func_list[-1]
        else:
            caller = func_list[caller_pos]
            caller.merge_nr(callee.nr)
        caller.update_use_function(callee.func_name)


def _analysis_size(func_list):
    """Return a measure that grows whenever ``get_call_graph`` adds anything."""
    return (
        len(func_list),
        sum(len(func.nr) for func in func_list),
        sum(len(func.use_function) for func in func_list),
    )


def parse_file(arch, file_name):
    """Return the syscall usage of every function in the listing ``file_name``.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    with open(file_name) as fp:
        lines = fp.readlines()
    func_list = get_direct_use_syscall(arch, lines)

    # Iterate to a fixed point. Comparing only the number of functions is not
    # enough: numbers merged into a function that is already listed must still
    # reach that function's own callers on a later pass. The sets only grow
    # and are bounded, so the loop terminates.
    previous = None
    current = _analysis_size(func_list)
    while current != previous:
        previous = current
        get_call_graph(arch, lines, func_list)
        current = _analysis_size(func_list)

    return func_list


def get_syscall_map(arch, src_syscall_path, libc_path):
    """Return the per-function syscall usage of the listing ``libc_path``.

    Args:
        arch: target cpu.
        src_syscall_path: iterable of file paths; those whose file name starts
            with ``libsyscall_to_nr_`` (case-insensitive) are passed to
            ``gen_policy.gen_syscall_nr_table``.
        libc_path: path of the listing; it is only parsed when it ends with
            ``libc.asm`` (case-insensitive).

    Returns:
        A list of ``LibcFuncUnit``; empty when ``libc_path`` is not parsed.
    """
    # NOTE(review): this table is filled but not consumed by this function;
    # LibcFuncUnit.print_info takes a table of this kind - confirm the intent.
    function_name_nr_table_dict = {}
    for file_name in src_syscall_path:
        file_name_tmp = file_name.split('/')[-1]
        if not file_name_tmp.lower().startswith('libsyscall_to_nr_'):
            continue
        gen_policy.gen_syscall_nr_table(file_name, function_name_nr_table_dict)

    func_map = []
    if libc_path.lower().endswith('libc.asm'):
        func_map = parse_file(arch, libc_path)
    return func_map


def main():
    """Parse the command line and run the analysis."""
    parser = argparse.ArgumentParser(
        description='Finds the syscalls used by each function of a libc disassembly')
    parser.add_argument('--src-syscall-path', type=str, action='append', required=True,
                        help='Path of a libsyscall_to_nr_* file; may be given several times')
    parser.add_argument('--libc-asm-path', type=str, required=True,
                        help='Path of the libc disassembly; only files named *libc.asm are parsed')
    parser.add_argument('--target-cpu', type=str,
                        help='Target cpu of the disassembly: arm, arm64 or riscv64')

    args = parser.parse_args()
    get_syscall_map(args.target_cpu, args.src_syscall_path, args.libc_asm_path)


if __name__ == '__main__':
    sys.exit(main())