1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::sync::atomic::{AtomicUsize, Ordering};
15 use std::sync::Mutex;
16 
/// Bookkeeping for parked worker threads: tracks which workers are idle and
/// maintains packed counters of active/searching workers so that wake-ups can
/// be throttled (see `pop_worker` / `try_inc_searching_num`).
pub(crate) struct Sleeper {
    // Packed atomic counters: active-worker count and searching-worker count.
    record: Record,
    // Indexes of workers that are currently parked (order not significant —
    // entries are removed with `swap_remove`).
    idle_list: Mutex<Vec<usize>>,
    // Total number of workers managed by this sleeper; `idle_list` and
    // `wake_by_search` are sized to this.
    num_workers: usize,
    // wake_by_search[i] is set when worker `i` is popped with
    // `last_search == true`, i.e. it was woken to go searching for tasks.
    pub(crate) wake_by_search: Mutex<Vec<bool>>,
}
23 
24 impl Sleeper {
new(num_workers: usize) -> Self25     pub fn new(num_workers: usize) -> Self {
26         Sleeper {
27             record: Record::new(num_workers),
28             idle_list: Mutex::new(Vec::with_capacity(num_workers)),
29             num_workers,
30             wake_by_search: Mutex::new(vec![false; num_workers]),
31         }
32     }
33 
is_parked(&self, worker_index: &usize) -> bool34     pub fn is_parked(&self, worker_index: &usize) -> bool {
35         let idle_list = self.idle_list.lock().unwrap();
36         idle_list.contains(worker_index)
37     }
38 
pop_worker_by_id(&self, worker_index: &usize)39     pub fn pop_worker_by_id(&self, worker_index: &usize) {
40         let mut idle_list = self.idle_list.lock().unwrap();
41 
42         for i in 0..idle_list.len() {
43             if &idle_list[i] == worker_index {
44                 idle_list.swap_remove(i);
45                 self.record.inc_active_num();
46                 break;
47             }
48         }
49     }
50 
pop_worker(&self, last_search: bool) -> Option<usize>51     pub fn pop_worker(&self, last_search: bool) -> Option<usize> {
52         let (active_num, searching_num) = self.record.load_state();
53         if active_num >= self.num_workers || searching_num > 0 {
54             return None;
55         }
56 
57         let mut idle_list = self.idle_list.lock().unwrap();
58 
59         let res = idle_list.pop();
60         drop(idle_list);
61         if let Some(worker_idx) = res.as_ref() {
62             if last_search {
63                 let mut search_list = self.wake_by_search.lock().unwrap();
64                 search_list[*worker_idx] = true;
65             }
66             self.record.inc_active_num();
67         }
68 
69         res
70     }
71 
72     // return true if it's the last thread going to sleep.
push_worker(&self, worker_index: usize) -> bool73     pub fn push_worker(&self, worker_index: usize) -> bool {
74         let mut idle_list = self.idle_list.lock().unwrap();
75 
76         idle_list.push(worker_index);
77         self.record.dec_active_num()
78     }
79 
80     #[inline]
inc_searching_num(&self)81     pub fn inc_searching_num(&self) {
82         self.record.inc_searching_num();
83     }
84 
try_inc_searching_num(&self) -> bool85     pub fn try_inc_searching_num(&self) -> bool {
86         let (active_num, searching_num) = self.record.load_state();
87 
88         if searching_num * 2 < active_num {
89             // increment searching worker number
90             self.inc_searching_num();
91             return true;
92         }
93         false
94     }
95 
96     // return true if it's the last searching thread
97     #[inline]
dec_searching_num(&self) -> bool98     pub fn dec_searching_num(&self) -> bool {
99         self.record.dec_searching_num()
100     }
101 }
102 
// Width of the searching-worker counter; the active-worker counter sits
// directly above it in the same word.
const ACTIVE_WORKER_SHIFT: usize = 16;
const SEARCHING_MASK: usize = (1 << ACTIVE_WORKER_SHIFT) - 1;
const ACTIVE_MASK: usize = !SEARCHING_MASK;

/// Two counters packed into one atomic word so both can be read or updated
/// with a single atomic operation:
///
///        32 bits          16 bits       16 bits
/// |-------------------| working num | searching num|
struct Record(AtomicUsize);

impl Record {
    /// Starts with `active_num` active workers and zero searching workers.
    fn new(active_num: usize) -> Self {
        Record(AtomicUsize::new(active_num << ACTIVE_WORKER_SHIFT))
    }

    /// Decrements the searching counter. Returns true if this was the last
    /// searching thread (the counter held exactly 1 before the decrement).
    fn dec_searching_num(&self) -> bool {
        let prev = self.0.fetch_sub(1, Ordering::SeqCst);
        prev & SEARCHING_MASK == 1
    }

    /// Increments the searching counter (low bits, so a plain add works).
    fn inc_searching_num(&self) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }

    /// Increments the active counter (adds one unit in the high field).
    fn inc_active_num(&self) {
        self.0.fetch_add(1 << ACTIVE_WORKER_SHIFT, Ordering::SeqCst);
    }

    /// Decrements the active counter. Returns true if no active workers
    /// remain afterwards.
    fn dec_active_num(&self) -> bool {
        let prev = self.0.fetch_sub(1 << ACTIVE_WORKER_SHIFT, Ordering::SeqCst);
        // Computed from the pre-decrement value; underflows (and panics in
        // debug builds) if called with no active workers, same as before.
        let remaining = ((prev & ACTIVE_MASK) >> ACTIVE_WORKER_SHIFT) - 1;
        remaining == 0
    }

    /// Returns `(active_num, searching_num)` from one atomic load.
    fn load_state(&self) -> (usize, usize) {
        let bits = self.0.load(Ordering::SeqCst);
        (
            (bits & ACTIVE_MASK) >> ACTIVE_WORKER_SHIFT,
            bits & SEARCHING_MASK,
        )
    }
}
148