1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::collections::VecDeque;
15 use std::future::Future;
16 use std::option::Option::Some;
17 use std::pin::Pin;
18 use std::sync::{Arc, Condvar, Mutex, MutexGuard, Weak};
19 use std::task::{Context, Poll};
20 use std::thread;
21 use std::time::Duration;
22 
23 use crate::builder::{CallbackHook, CommonBuilder};
24 use crate::error::{ErrorKind, ScheduleError};
25 use crate::executor::PlaceholderScheduler;
26 use crate::task;
27 use crate::task::{JoinHandle, TaskBuilder, VirtualTableType};
28 
/// Upper bound on how long dropping a `BlockPoolSpawner` waits for the worker
/// threads to acknowledge shutdown.
pub(crate) const BLOCKING_THREAD_QUIT_WAIT_TIME: Duration = Duration::from_millis(1_000);
/// Thread cap used when the builder leaves `max_blocking_pool_size` unset.
pub(crate) const DEFAULT_MAX_BLOCKING_POOL_SIZE: u8 = 16;
31 
32 #[derive(Clone)]
33 pub(crate) struct BlockPoolSpawner {
34     inner: Arc<Inner>,
35 }
36 
37 impl Drop for BlockPoolSpawner {
drop(&mut self)38     fn drop(&mut self) {
39         self.shutdown(BLOCKING_THREAD_QUIT_WAIT_TIME);
40     }
41 }
42 
43 impl BlockPoolSpawner {
new(builder: &CommonBuilder) -> BlockPoolSpawner44     pub fn new(builder: &CommonBuilder) -> BlockPoolSpawner {
45         let keep_alive_time = builder
46             .keep_alive_time
47             .unwrap_or(BLOCKING_THREAD_KEEP_ALIVE_TIME);
48         let max_thread_num = builder
49             .max_blocking_pool_size
50             .unwrap_or(DEFAULT_MAX_BLOCKING_POOL_SIZE);
51         BlockPoolSpawner {
52             inner: Arc::new(Inner {
53                 shared: Mutex::new(Shared {
54                     queue: VecDeque::new(),
55                     total_thread_num: 0,
56                     idle_thread_num: 0,
57                     notify_num: 0,
58                     current_permanent_thread_num: 0,
59                     shutdown: false,
60                     worker_id: 0,
61                     worker_threads: VecDeque::new(),
62                 }),
63                 condvar: Condvar::new(),
64                 shutdown_shared: Mutex::new(false),
65                 shutdown_condvar: Condvar::new(),
66                 stack_size: builder.stack_size,
67                 after_start: builder.after_start.clone(),
68                 before_stop: builder.before_stop.clone(),
69                 max_thread_num,
70                 keep_alive_time,
71                 max_permanent_thread_num: builder.blocking_permanent_thread_num,
72             }),
73         }
74     }
75 
shutdown(&mut self, timeout: Duration) -> bool76     pub fn shutdown(&mut self, timeout: Duration) -> bool {
77         let mut shared = self.inner.shared.lock().unwrap();
78 
79         if shared.shutdown {
80             return false;
81         }
82         self.inner.condvar.notify_all();
83         let workers = std::mem::take(&mut shared.worker_threads);
84         drop(shared);
85 
86         let shutdown_shared = self.inner.shutdown_shared.lock().unwrap();
87 
88         if *self
89             .inner
90             .shutdown_condvar
91             .wait_timeout(shutdown_shared, timeout)
92             .unwrap()
93             .0
94         {
95             for handle in workers {
96                 let _ = handle.1.join();
97             }
98             return true;
99         }
100         false
101     }
102 }
103 
/// Keep-alive used for non-permanent idle workers when the builder leaves
/// `keep_alive_time` unset.
const BLOCKING_THREAD_KEEP_ALIVE_TIME: Duration = Duration::from_millis(5_000);
105 
/// State shared, through an `Arc`, by every [`BlockPoolSpawner`] clone and by
/// every worker thread of the blocking pool.
///
/// The configuration fields are copied from the `CommonBuilder` in
/// `BlockPoolSpawner::new`; mutable bookkeeping lives in `shared`.
struct Inner {
    /// Mutable pool bookkeeping (task queue, thread counters, shutdown flag),
    /// all guarded by this single mutex
    shared: Mutex<Shared>,

    /// Paired with `shared`: idle workers park on it, `spawn` wakes one with
    /// `notify_one`, and `shutdown` wakes all with `notify_all`
    condvar: Condvar,

    /// Set to `true` by an exiting worker that observed the shutdown flag;
    /// `shutdown` waits on it
    shutdown_shared: Mutex<bool>,

    /// Paired with `shutdown_shared` to signal that a worker acknowledged the
    /// shutdown
    shutdown_condvar: Condvar,

    /// Stack size for each worker thread; `None` leaves the
    /// `thread::Builder` default in place
    stack_size: Option<usize>,

    /// Hook run on each worker thread when it starts, before it takes tasks
    after_start: Option<CallbackHook>,

    /// Hook run on each worker thread as the last step before it exits
    before_stop: Option<CallbackHook>,

    /// Value of `total_thread_num` at which `spawn` stops creating new
    /// threads on demand
    max_thread_num: u8,

    /// How long a non-permanent idle worker waits for work before it is
    /// released
    keep_alive_time: Duration,

    /// Number of threads started by `create_permanent_threads`, and the
    /// number of workers that may park without a keep-alive timeout
    max_permanent_thread_num: u8,
}
138 
/// Mutable bookkeeping of the blocking pool, guarded by `Inner::shared`.
struct Shared {
    /// Tasks waiting for a worker, consumed in FIFO order
    queue: VecDeque<Task>,

    /// Number of worker threads currently counted as alive
    total_thread_num: u8,

    /// Number of workers parked in `wait` that no `spawn` call has claimed
    /// yet; `spawn` decrements it when it hands a task to a parked worker
    idle_thread_num: u8,

    /// Notifications issued by `spawn` (via `notify_one`) that no waiter has
    /// consumed yet; a wakeup that finds this at 0 is treated as spurious
    notify_num: u8,

    /// Number of workers currently parked in the permanent (timeout-free)
    /// wait
    current_permanent_thread_num: u8,

    /// Shutdown flag of the pool: once set, `spawn` refuses new tasks and
    /// workers stop waiting for work
    shutdown: bool,

    /// Id given to the next worker thread; it is also the suffix of the
    /// thread name and the key of the thread's entry in `worker_threads`
    worker_id: usize,

    /// `(worker_id, JoinHandle)` of registered workers. A worker released by
    /// its keep-alive removes its own entry; `shutdown` takes the whole list
    /// in order to join it
    worker_threads: VecDeque<(usize, thread::JoinHandle<()>)>,
}
165 
166 type Task = task::Task;
167 
168 // ===== impl BlockPoolSpawner =====
169 impl BlockPoolSpawner {
create_permanent_threads(&self) -> Result<(), ScheduleError>170     pub fn create_permanent_threads(&self) -> Result<(), ScheduleError> {
171         for _ in 0..self.inner.max_permanent_thread_num {
172             let mut shared = self.inner.shared.lock().unwrap();
173             shared.total_thread_num += 1;
174             let worker_id = shared.worker_id;
175             let mut builder = thread::Builder::new().name(format!("block-r-{worker_id}"));
176             if let Some(stack_size) = self.inner.stack_size {
177                 builder = builder.stack_size(stack_size);
178             }
179             let inner = self.inner.clone();
180             let join_handle = builder.spawn(move || inner.run(worker_id));
181             match join_handle {
182                 Ok(join_handle) => {
183                     shared.worker_threads.push_back((worker_id, join_handle));
184                     shared.worker_id += 1;
185                 }
186                 Err(err) => {
187                     return Err(ScheduleError::new(ErrorKind::BlockSpawnErr, err));
188                 }
189             }
190         }
191         Ok(())
192     }
193 
spawn_blocking<T, R>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<R> where T: FnOnce() -> R, T: Send + 'static, R: Send + 'static,194     pub(crate) fn spawn_blocking<T, R>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<R>
195     where
196         T: FnOnce() -> R,
197         T: Send + 'static,
198         R: Send + 'static,
199     {
200         let task = BlockingTask(Some(task));
201         let scheduler: Weak<PlaceholderScheduler> = Weak::new();
202         let (task, handle) = Task::create_task(builder, scheduler, task, VirtualTableType::Ylong);
203         self.spawn(task);
204         handle
205     }
206 
spawn(&self, task: Task)207     fn spawn(&self, task: Task) {
208         let mut shared = self.inner.shared.lock().unwrap();
209 
210         // if the shutdown flag is on, cancel the task
211         assert!(
212             !shared.shutdown,
213             "The blocking runtime has already been shutdown, cannot spawn tasks"
214         );
215 
216         shared.queue.push_back(task);
217         // there are idle threads, wake up one
218         if shared.idle_thread_num != 0 {
219             shared.idle_thread_num -= 1;
220             shared.notify_num += 1;
221             self.inner.condvar.notify_one();
222             return;
223         }
224 
225         if shared.total_thread_num == self.inner.max_thread_num {
226             // thread number has reached maximum, do nothing
227             return;
228         }
229         // there is no idle thread and the maximum thread number has not been reached,
230         // therefore create a new thread
231         shared.total_thread_num += 1;
232         // sets all required attributes for the thread
233         let worker_id = shared.worker_id;
234         let mut builder = thread::Builder::new().name(format!("block-{worker_id}"));
235         if let Some(stack_size) = self.inner.stack_size {
236             builder = builder.stack_size(stack_size);
237         }
238 
239         let inner = self.inner.clone();
240         let join_handle = builder.spawn(move || inner.run(worker_id));
241         match join_handle {
242             Ok(join_handle) => {
243                 shared.worker_threads.push_back((worker_id, join_handle));
244                 shared.worker_id += 1;
245             }
246             Err(e) => {
247                 panic!("os can't spawn worker thread: {e}");
248             }
249         }
250     }
251 }
252 
/// Outcome of a single keep-alive-bounded wait (`Inner::wait_temporary`).
enum WaitState {
    /// Spurious wakeup, or a wakeup once shutdown has begun; the caller
    /// re-checks the pool state
    Continue,
    /// A pending notification from `spawn` was claimed; stop waiting and look
    /// for work
    ExitWait,
    /// The keep-alive expired while the pool was running; the worker should
    /// exit
    Release,
}
258 
259 impl<'a> Inner {
260     // return true if it is not a spurious wakeup
wait_permanent(&'a self, mut shared: MutexGuard<'a, Shared>) -> (bool, MutexGuard<Shared>)261     fn wait_permanent(&'a self, mut shared: MutexGuard<'a, Shared>) -> (bool, MutexGuard<Shared>) {
262         shared.current_permanent_thread_num += 1;
263         shared = self.condvar.wait(shared).unwrap();
264         shared.current_permanent_thread_num -= 1;
265         // Combining a loop to prevent spurious wakeup of condvar, if there is a
266         // spurious wakeup, the `notify_num` will be 0 and the loop will continue.
267         if shared.notify_num != 0 {
268             shared.notify_num -= 1;
269             return (true, shared);
270         }
271         (false, shared)
272     }
273 
wait_temporary( &'a self, mut shared: MutexGuard<'a, Shared>, worker_id: usize, ) -> (WaitState, MutexGuard<Shared>)274     fn wait_temporary(
275         &'a self,
276         mut shared: MutexGuard<'a, Shared>,
277         worker_id: usize,
278     ) -> (WaitState, MutexGuard<Shared>) {
279         // if the thread is not permanent, set the keep-alive time for releasing
280         // the thread
281         let time_out_lock_res = self
282             .condvar
283             .wait_timeout(shared, self.keep_alive_time)
284             .unwrap();
285         shared = time_out_lock_res.0;
286         let timeout_result = time_out_lock_res.1;
287 
288         // Combining a loop to prevent spurious wakeup of condvar, if there is a
289         // spurious wakeup, the `notify_num` will be 0 and the loop will continue.
290         if shared.notify_num != 0 {
291             shared.notify_num -= 1;
292             return (WaitState::ExitWait, shared);
293         }
294         // expires, release the thread
295         if !shared.shutdown && timeout_result.timed_out() {
296             for (thread_id, thread) in shared.worker_threads.iter().enumerate() {
297                 if thread.0 == worker_id {
298                     shared.worker_threads.remove(thread_id);
299                     break;
300                 }
301             }
302             return (WaitState::Release, shared);
303         }
304         (WaitState::Continue, shared)
305     }
306 
307     // returns true if this thread should get released
wait( &'a self, mut shared: MutexGuard<'a, Shared>, worker_id: usize, ) -> (bool, MutexGuard<Shared>)308     fn wait(
309         &'a self,
310         mut shared: MutexGuard<'a, Shared>,
311         worker_id: usize,
312     ) -> (bool, MutexGuard<Shared>) {
313         shared.idle_thread_num += 1;
314         while !shared.shutdown {
315             // permanent waits, the thread keep alive until shutdown.
316             if shared.current_permanent_thread_num < self.max_permanent_thread_num {
317                 let (is_waked_up, guard) = self.wait_permanent(shared);
318                 shared = guard;
319                 if is_waked_up {
320                     break;
321                 }
322                 continue;
323             }
324             match self.wait_temporary(shared, worker_id) {
325                 (WaitState::ExitWait, guard) => {
326                     shared = guard;
327                     break;
328                 }
329                 (WaitState::Continue, guard) => shared = guard,
330                 (WaitState::Release, guard) => return (true, guard),
331             }
332         }
333         (false, shared)
334     }
335 
run(&self, worker_id: usize)336     fn run(&self, worker_id: usize) {
337         if let Some(f) = &self.after_start {
338             f()
339         }
340 
341         let mut shared = self.shared.lock().unwrap();
342         loop {
343             // get a task from the global queue
344             while let Some(task) = shared.queue.pop_front() {
345                 drop(shared);
346                 task.run();
347                 shared = self.shared.lock().unwrap();
348             }
349 
350             let (is_released, guard) = self.wait(shared, worker_id);
351             shared = guard;
352             // if this thread should get released, break
353             if is_released {
354                 break;
355             }
356             if shared.shutdown {
357                 // empty the tasks in the global queue
358                 while let Some(_task) = shared.queue.pop_front() {
359                     drop(shared);
360                     shared = self.shared.lock().unwrap();
361                 }
362                 break;
363             }
364         }
365 
366         // thread exit, thread num should be maintained correctly
367         shared.total_thread_num = shared
368             .total_thread_num
369             .checked_sub(1)
370             .expect("total thread num underflowed");
371         shared.idle_thread_num = shared
372             .idle_thread_num
373             .checked_sub(1)
374             .expect("idle thread num underflowed");
375 
376         let shutdown = shared.shutdown;
377         drop(shared);
378 
379         if shutdown {
380             *self.shutdown_shared.lock().unwrap() = true;
381             self.shutdown_condvar.notify_one();
382         }
383 
384         if let Some(f) = &self.before_stop {
385             f()
386         }
387     }
388 }
389 
/// Adapter that lets a blocking closure be scheduled as a [`Future`].
///
/// The closure runs to completion inside the first `poll`, which therefore
/// always returns `Poll::Ready`.
struct BlockingTask<T>(Option<T>);

// Only the closure is stored and it is never structurally pinned, so the
// wrapper may be moved freely even after being pinned.
impl<T> Unpin for BlockingTask<T> {}

impl<T, R> Future for BlockingTask<T>
where
    T: FnOnce() -> R,
{
    type Output = R;

    /// Runs the wrapped closure and yields its result.
    ///
    /// # Panics
    /// Panics if polled again after it has already completed.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match this.0.take() {
            Some(func) => Poll::Ready(func()),
            None => panic!("blocking tasks cannot be polled after finished"),
        }
    }
}
409 
#[cfg(test)]
mod test {
    use std::time::Duration;

    use crate::builder::RuntimeBuilder;
    use crate::executor::blocking_pool::{BlockPoolSpawner, DEFAULT_MAX_BLOCKING_POOL_SIZE};

    /// UT test cases for BlockPoolSpawner::new()
    ///
    /// # Brief
    /// 1. Build a pool with an explicit keep-alive time and check that the
    ///    stack size, keep-alive time and permanent thread number mirror the
    ///    builder, and that `max_thread_num` equals
    ///    `DEFAULT_MAX_BLOCKING_POOL_SIZE`.
    #[test]
    fn ut_blocking_pool_new() {
        let thread_pool_builder =
            RuntimeBuilder::new_multi_thread().keep_alive_time(Duration::from_secs(1));
        let blocking_pool = BlockPoolSpawner::new(&thread_pool_builder.common);
        assert_eq!(
            blocking_pool.inner.stack_size,
            thread_pool_builder.common.stack_size
        );
        assert_eq!(
            blocking_pool.inner.max_thread_num,
            DEFAULT_MAX_BLOCKING_POOL_SIZE
        );
        assert_eq!(
            blocking_pool.inner.keep_alive_time,
            thread_pool_builder.common.keep_alive_time.unwrap()
        );
        assert_eq!(
            blocking_pool.inner.max_permanent_thread_num,
            thread_pool_builder.common.blocking_permanent_thread_num
        );
    }

    /// UT test cases for BlockPoolSpawner::shutdown()
    ///
    /// # Brief
    /// 1. When `shared.shutdown` is already true, `shutdown` returns false
    ///    straight away.
    /// 2. When `shared.shutdown` is false and a helper thread sets
    ///    `shutdown_shared` and notifies `shutdown_condvar`, `shutdown`
    ///    returns true.
    /// 3. When `shared.shutdown` is set to true before the call, `shutdown`
    ///    returns false even though a helper thread notifies
    ///    `shutdown_condvar`.
    #[test]
    fn ut_blocking_pool_shutdown() {
        let thread_pool_builder = RuntimeBuilder::new_multi_thread();
        let mut blocking_pool = BlockPoolSpawner::new(&thread_pool_builder.common);
        blocking_pool.inner.shared.lock().unwrap().shutdown = true;
        assert!(!blocking_pool.shutdown(Duration::from_secs(3)));

        let thread_pool_builder = RuntimeBuilder::new_multi_thread();
        let mut blocking_pool = BlockPoolSpawner::new(&thread_pool_builder.common);
        let spawner_inner_clone = blocking_pool.inner.clone();
        let _thread = std::thread::spawn(move || {
            *spawner_inner_clone.shutdown_shared.lock().unwrap() = true;
            spawner_inner_clone.shutdown_condvar.notify_one();
        });
        assert!(blocking_pool.shutdown(Duration::from_secs(3)));

        let thread_pool_builder = RuntimeBuilder::new_multi_thread();
        let mut blocking_pool = BlockPoolSpawner::new(&thread_pool_builder.common);
        let spawner_inner_clone = blocking_pool.inner.clone();
        let _thread = std::thread::spawn(move || {
            spawner_inner_clone.shutdown_condvar.notify_one();
        });

        blocking_pool.inner.shared.lock().unwrap().shutdown = true;
        assert!(!blocking_pool.shutdown(Duration::from_secs(0)));
    }

    /// UT test cases for BlockPoolSpawner::create_permanent_threads()
    ///
    /// # Brief
    /// 1. With `blocking_permanent_thread_num(4)`, creating the permanent
    ///    threads succeeds and advances `worker_id` to 4.
    /// 2. `worker_id` equals the configured permanent thread number, and the
    ///    first registered worker thread is named "block-r-0".
    /// 3. Same as 2 but with a custom `worker_name`: blocking threads are
    ///    still named "block-r-<id>", independent of the worker name.
    #[test]
    fn ut_blocking_pool_spawner_create_permanent_threads() {
        let thread_pool_builder =
            RuntimeBuilder::new_multi_thread().blocking_permanent_thread_num(4);
        let blocking_pool = BlockPoolSpawner::new(&thread_pool_builder.common);
        assert!(blocking_pool.create_permanent_threads().is_ok());
        assert_eq!(blocking_pool.inner.shared.lock().unwrap().worker_id, 4);

        let thread_pool_builder =
            RuntimeBuilder::new_multi_thread().blocking_permanent_thread_num(4);
        let common = RuntimeBuilder::new_multi_thread().blocking_permanent_thread_num(4);
        let blocking_pool = BlockPoolSpawner::new(&common.common);
        assert!(blocking_pool.create_permanent_threads().is_ok());
        assert_eq!(
            blocking_pool.inner.shared.lock().unwrap().worker_id,
            thread_pool_builder.common.blocking_permanent_thread_num as usize
        );
        assert_eq!(
            blocking_pool
                .inner
                .shared
                .lock()
                .unwrap()
                .worker_threads
                .pop_front()
                .unwrap()
                .1
                .thread()
                .name()
                .unwrap(),
            "block-r-0"
        );

        let thread_pool_builder = RuntimeBuilder::new_multi_thread()
            .blocking_permanent_thread_num(4)
            .worker_name(String::from("test"));
        let common = RuntimeBuilder::new_multi_thread()
            .blocking_permanent_thread_num(4)
            .worker_name(String::from("test"));
        let blocking_pool = BlockPoolSpawner::new(&common.common);
        assert!(blocking_pool.create_permanent_threads().is_ok());
        assert_eq!(
            blocking_pool.inner.shared.lock().unwrap().worker_id,
            thread_pool_builder.common.blocking_permanent_thread_num as usize
        );
        assert_eq!(
            blocking_pool
                .inner
                .shared
                .lock()
                .unwrap()
                .worker_threads
                .pop_front()
                .unwrap()
                .1
                .thread()
                .name()
                .unwrap(),
            "block-r-0"
        );
    }
}
551