1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io;
15 use std::sync::Arc;
16 
17 use libc::{c_int, SIGFPE, SIGILL, SIGSEGV};
18 #[cfg(not(windows))]
19 use libc::{siginfo_t, SIGKILL, SIGSTOP};
20 
21 use crate::sig_map::SigMap;
22 
23 /// These signals should not be handled at all due to POSIX settings or their
24 /// specialness
25 #[cfg(windows)]
26 pub const SIGNAL_BLOCK_LIST: &[c_int] = &[SIGILL, SIGFPE, SIGSEGV];
27 
28 /// These signals should not be handled at all due to POSIX settings or their
29 /// specialness
30 #[cfg(not(windows))]
31 pub const SIGNAL_BLOCK_LIST: &[c_int] = &[SIGSEGV, SIGKILL, SIGSTOP, SIGILL, SIGFPE];
32 
33 #[cfg(windows)]
34 type Action = libc::sighandler_t;
35 #[cfg(not(windows))]
36 type Action = libc::sigaction;
37 
38 #[cfg(not(windows))]
39 use crate::unix::sig_handler;
40 #[cfg(windows)]
41 use crate::windows::sig_handler;
42 
43 #[cfg(windows)]
44 type ActionPtr = libc::sighandler_t;
45 #[cfg(not(windows))]
46 type ActionPtr = usize;
47 
48 #[allow(non_camel_case_types)]
49 #[cfg(windows)]
50 pub(crate) struct siginfo_t;
51 
52 type SigHandler = dyn Fn(&siginfo_t) + Send + Sync;
53 
54 #[derive(Clone)]
55 pub(crate) struct Signal {
56     pub(crate) old_act: Action,
57     pub(crate) new_act: Option<Arc<SigHandler>>,
58 }
59 
60 pub(crate) struct SigAction {
61     pub(crate) sig_num: c_int,
62     pub(crate) act: Action,
63 }
64 
65 impl Signal {
new(sig_num: c_int, new_act: Arc<SigHandler>) -> io::Result<Signal>66     pub(crate) fn new(sig_num: c_int, new_act: Arc<SigHandler>) -> io::Result<Signal> {
67         let old_act = Self::replace_sigaction(sig_num, sig_handler as ActionPtr)?;
68 
69         Ok(Signal {
70             old_act,
71             new_act: Some(new_act),
72         })
73     }
74 
register_action<F>(sig_num: c_int, handler: F) -> io::Result<()> where F: Fn(&siginfo_t) + Sync + Send + 'static,75     pub(super) unsafe fn register_action<F>(sig_num: c_int, handler: F) -> io::Result<()>
76     where
77         F: Fn(&siginfo_t) + Sync + Send + 'static,
78     {
79         if SIGNAL_BLOCK_LIST.contains(&sig_num) {
80             return Err(io::ErrorKind::InvalidInput.into());
81         }
82 
83         let sig_map = SigMap::get_instance();
84         let act = Arc::new(handler);
85         let mut write_guard = sig_map.data.write();
86         let mut new_map = write_guard.clone();
87 
88         if let Some(signal) = new_map.get_mut(&sig_num) {
89             if signal.new_act.is_some() {
90                 return Err(io::ErrorKind::AlreadyExists.into());
91             } else {
92                 signal.new_act = Some(act);
93             }
94         } else {
95             let old_act = SigAction::get_old_action(sig_num)?;
96             sig_map.race_old.write().store(Some(old_act));
97 
98             let signal = Signal::new(sig_num, act)?;
99             new_map.insert(sig_num, signal);
100         }
101         write_guard.store(new_map);
102         Ok(())
103     }
104 
deregister_action(sig_num: c_int) -> io::Result<()>105     pub(super) fn deregister_action(sig_num: c_int) -> io::Result<()> {
106         let sig_map = SigMap::get_instance();
107         let mut write_guard = sig_map.data.write();
108         let mut new_map = write_guard.clone();
109         if let Some(signal) = new_map.remove(&sig_num) {
110             #[cfg(not(windows))]
111             Self::replace_sigaction(sig_num, signal.old_act.sa_sigaction)?;
112             #[cfg(windows)]
113             Self::replace_sigaction(sig_num, signal.old_act)?;
114         }
115         write_guard.store(new_map);
116         Ok(())
117     }
118 
deregister_hook(sig_num: c_int) -> io::Result<()>119     pub(super) fn deregister_hook(sig_num: c_int) -> io::Result<()> {
120         let global = SigMap::get_instance();
121         let mut write_guard = global.data.write();
122         let mut signal_map = write_guard.clone();
123 
124         Self::replace_sigaction(sig_num, libc::SIG_DFL)?;
125 
126         signal_map.remove(&sig_num);
127         write_guard.store(signal_map);
128         Ok(())
129     }
130 }
131 
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::common::Signal;

    /// UT for signal creation
    ///
    /// # Brief
    /// 1. Create a new signal.
    /// 2. Check that the signal is initialized correctly, and that a clone
    ///    still holds the action after the original is dropped.
    ///
    /// NOTE(review): `Signal::new` installs `sig_handler` for `SIGINT` without
    /// adding a `SigMap` entry, and this test never restores the previous
    /// disposition. Tests run in parallel by default, so this process-wide
    /// `SIGINT` state is shared with `ut_signal_register`.
    #[test]
    #[cfg(target_os = "linux")]
    fn ut_signal_new() {
        // Never invoked here; the closure only has to have the handler type.
        let handler = Arc::new(|_info: &libc::siginfo_t| {});
        let signal = Signal::new(libc::SIGINT, handler)
            .expect("Signal::new should install the SIGINT handler");
        assert!(signal.new_act.is_some());

        let signal2 = signal.clone();
        drop(signal);
        assert!(signal2.new_act.is_some());
    }

    /// UT for signal register and deregister
    ///
    /// # Brief
    /// 1. Registers actions for two different signals; each increments its
    ///    own atomic counter.
    /// 2. Manually raises the two signals and checks that each registered
    ///    action behaves correctly.
    /// 3. Deregisters the actions of both signals.
    /// 4. Registers a new action for one of the signals.
    /// 5. Manually raises that signal and checks that the new action behaves
    ///    correctly.
    /// 6. Deregisters both signals' handler hooks and checks that both calls
    ///    succeed.
    #[test]
    fn ut_signal_register() {
        let value = Arc::new(AtomicUsize::new(0));
        let value_cpy = value.clone();

        let value2 = Arc::new(AtomicUsize::new(10));
        let value2_cpy = value2.clone();
        let value2_cpy2 = value2.clone();

        // SAFETY: the registered actions only perform atomic increments.
        let res = unsafe {
            Signal::register_action(libc::SIGINT, move |_| {
                value_cpy.fetch_add(1, Ordering::Relaxed);
            })
        };
        res.expect("registering a SIGINT action should succeed");

        // SAFETY: the registered action only performs an atomic increment.
        let res = unsafe {
            Signal::register_action(libc::SIGTERM, move |_| {
                value2_cpy.fetch_add(10, Ordering::Relaxed);
            })
        };
        res.expect("registering a SIGTERM action should succeed");
        assert_eq!(value.load(Ordering::Relaxed), 0);

        // SAFETY: `raise` has no memory-safety preconditions.
        assert_eq!(unsafe { libc::raise(libc::SIGINT) }, 0);
        assert_eq!(value.load(Ordering::Relaxed), 1);
        assert_eq!(value2.load(Ordering::Relaxed), 10);

        // SAFETY: `raise` has no memory-safety preconditions.
        assert_eq!(unsafe { libc::raise(libc::SIGTERM) }, 0);
        assert_eq!(value.load(Ordering::Relaxed), 1);
        assert_eq!(value2.load(Ordering::Relaxed), 20);

        Signal::deregister_action(libc::SIGTERM)
            .expect("deregistering the SIGTERM action should succeed");
        Signal::deregister_action(libc::SIGINT)
            .expect("deregistering the SIGINT action should succeed");

        // SAFETY: the registered action only performs an atomic increment.
        let res = unsafe {
            Signal::register_action(libc::SIGTERM, move |_| {
                value2_cpy2.fetch_add(20, Ordering::Relaxed);
            })
        };
        res.expect("re-registering a SIGTERM action should succeed");

        // SAFETY: `raise` has no memory-safety preconditions.
        assert_eq!(unsafe { libc::raise(libc::SIGTERM) }, 0);
        assert_eq!(value2.load(Ordering::Relaxed), 40);

        Signal::deregister_hook(libc::SIGTERM)
            .expect("deregistering the SIGTERM hook should succeed");
        Signal::deregister_hook(libc::SIGINT)
            .expect("deregistering the SIGINT hook should succeed");
    }
}
226