1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::error::Error;
15 use std::fmt::{Debug, Display, Formatter};
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::sync::atomic::AtomicUsize;
19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
20 use std::task::Poll::{Pending, Ready};
21 use std::task::{Context, Poll};
22 
23 use crate::sync::wake_list::WakerList;
24 
/// Upper bound on the permit count; the count is stored shifted left by `PERMIT_SHIFT`,
/// so this is the largest value that still fits in a `usize`.
const MAX_PERMITS: usize = usize::MAX / 2;
/// Number of bits the permit count is shifted by inside the packed counter; the bits
/// below it are reserved for flags.
const PERMIT_SHIFT: usize = 1;
/// Flag stored in bit 0 of the packed counter, set once the semaphore is closed.
const CLOSED: usize = 0b1;
31 
32 pub(crate) struct SemaphoreInner {
33     permits: AtomicUsize,
34     waker_list: WakerList,
35 }
36 
37 pub(crate) struct Permit<'a> {
38     semaphore: &'a SemaphoreInner,
39     waker_index: Option<usize>,
40     enqueue: bool,
41 }
42 
/// Errors reported by `Semaphore` operations.
#[derive(PartialEq, Eq, Debug)]
pub enum SemaphoreError {
    /// The requested number of permits exceeds the supported maximum.
    Overflow,
    /// No permit is available right now.
    Empty,
    /// The semaphore has been closed.
    Closed,
}
53 
54 impl Display for SemaphoreError {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result55     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56         match self {
57             SemaphoreError::Overflow => write!(f, "permit overflow MAX_PERMITS : {MAX_PERMITS}"),
58             SemaphoreError::Empty => write!(f, "no permits available"),
59             SemaphoreError::Closed => write!(f, "semaphore has been closed"),
60         }
61     }
62 }
63 
64 impl Error for SemaphoreError {}
65 
66 impl SemaphoreInner {
new(permits: usize) -> Result<SemaphoreInner, SemaphoreError>67     pub(crate) fn new(permits: usize) -> Result<SemaphoreInner, SemaphoreError> {
68         if permits >= MAX_PERMITS {
69             return Err(SemaphoreError::Overflow);
70         }
71         Ok(SemaphoreInner {
72             permits: AtomicUsize::new(permits << PERMIT_SHIFT),
73             waker_list: WakerList::new(),
74         })
75     }
76 
current_permits(&self) -> usize77     pub(crate) fn current_permits(&self) -> usize {
78         self.permits.load(Acquire) >> PERMIT_SHIFT
79     }
80 
release(&self)81     pub(crate) fn release(&self) {
82         // Get the lock first to ensure the atomicity of the two operations.
83         let mut waker_list = self.waker_list.lock();
84         if !waker_list.notify_one() {
85             let prev = self.permits.fetch_add(1 << PERMIT_SHIFT, Release);
86             assert!(
87                 (prev >> PERMIT_SHIFT) < MAX_PERMITS,
88                 "the number of permits will overflow the capacity after addition"
89             );
90         }
91     }
92 
release_notify(&self)93     pub(crate) fn release_notify(&self) {
94         // Get the lock first to ensure the atomicity of the two operations.
95         let mut waker_list = self.waker_list.lock();
96         if !waker_list.notify_one() {
97             self.permits.store(1 << PERMIT_SHIFT, Release);
98         }
99     }
100 
release_multi(&self, mut permits: usize)101     pub(crate) fn release_multi(&self, mut permits: usize) {
102         let mut waker_list = self.waker_list.lock();
103         while permits > 0 && waker_list.notify_one() {
104             permits -= 1;
105         }
106         let prev = self.permits.fetch_add(permits << PERMIT_SHIFT, Release);
107         assert!(
108             (prev >> PERMIT_SHIFT) < MAX_PERMITS,
109             "the number of permits will overflow the capacity after addition"
110         );
111     }
112 
release_all(&self)113     pub(crate) fn release_all(&self) {
114         self.waker_list.notify_all();
115     }
116 
try_acquire(&self) -> Result<(), SemaphoreError>117     pub(crate) fn try_acquire(&self) -> Result<(), SemaphoreError> {
118         let mut curr = self.permits.load(Acquire);
119         loop {
120             if curr & CLOSED == CLOSED {
121                 return Err(SemaphoreError::Closed);
122             }
123 
124             if curr == 0 {
125                 return Err(SemaphoreError::Empty);
126             }
127 
128             match self
129                 .permits
130                 .compare_exchange(curr, curr - (1 << PERMIT_SHIFT), AcqRel, Acquire)
131             {
132                 Ok(_) => {
133                     return Ok(());
134                 }
135                 Err(actual) => {
136                     curr = actual;
137                 }
138             }
139         }
140     }
141 
is_closed(&self) -> bool142     pub(crate) fn is_closed(&self) -> bool {
143         self.permits.load(Acquire) & CLOSED == CLOSED
144     }
145 
close(&self)146     pub(crate) fn close(&self) {
147         // Get the lock first to ensure the atomicity of the two operations.
148         let mut waker_list = self.waker_list.lock();
149         self.permits.fetch_or(CLOSED, Release);
150         waker_list.notify_all();
151     }
152 
acquire(&self) -> Permit<'_>153     pub(crate) fn acquire(&self) -> Permit<'_> {
154         Permit::new(self)
155     }
156 
update_permit( &self, enqueue: &mut bool, curr: &mut usize, permit_num: usize, ) -> Option<Poll<Result<(), SemaphoreError>>>157     fn update_permit(
158         &self,
159         enqueue: &mut bool,
160         curr: &mut usize,
161         permit_num: usize,
162     ) -> Option<Poll<Result<(), SemaphoreError>>> {
163         match self
164             .permits
165             .compare_exchange(*curr, *curr - permit_num, AcqRel, Acquire)
166         {
167             Ok(_) => {
168                 if *enqueue {
169                     self.release();
170                     return Some(Pending);
171                 }
172                 return Some(Ready(Ok(())));
173             }
174             Err(actual) => *curr = actual,
175         }
176         None
177     }
178 
poll_acquire( &self, cx: &mut Context<'_>, waker_index: &mut Option<usize>, enqueue: &mut bool, ) -> Poll<Result<(), SemaphoreError>>179     fn poll_acquire(
180         &self,
181         cx: &mut Context<'_>,
182         waker_index: &mut Option<usize>,
183         enqueue: &mut bool,
184     ) -> Poll<Result<(), SemaphoreError>> {
185         let mut curr = self.permits.load(Acquire);
186         if curr & CLOSED == CLOSED {
187             return Ready(Err(SemaphoreError::Closed));
188         } else if *enqueue {
189             *enqueue = false;
190             return Ready(Ok(()));
191         }
192         let permit_num = 1 << PERMIT_SHIFT;
193         loop {
194             if curr & CLOSED == CLOSED {
195                 return Ready(Err(SemaphoreError::Closed));
196             }
197             if curr >= permit_num {
198                 if let Some(res) = self.update_permit(enqueue, &mut curr, permit_num) {
199                     return res;
200                 }
201             } else if !(*enqueue) {
202                 *waker_index = Some(self.waker_list.insert(cx.waker().clone()));
203                 *enqueue = true;
204                 curr = self.permits.load(Acquire);
205             } else {
206                 return Pending;
207             }
208         }
209     }
210 }
211 
212 impl Debug for SemaphoreInner {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result213     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214         f.debug_struct("Semaphore")
215             .field("permits", &self.current_permits())
216             .finish()
217     }
218 }
219 
220 impl<'a> Permit<'a> {
new(semaphore: &'a SemaphoreInner) -> Permit221     fn new(semaphore: &'a SemaphoreInner) -> Permit {
222         Permit {
223             semaphore,
224             waker_index: None,
225             enqueue: false,
226         }
227     }
228 }
229 
230 impl Future for Permit<'_> {
231     type Output = Result<(), SemaphoreError>;
232 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>233     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
234         let (semaphore, waker_index, enqueue) = unsafe {
235             let me = self.get_unchecked_mut();
236             (me.semaphore, &mut me.waker_index, &mut me.enqueue)
237         };
238 
239         semaphore.poll_acquire(cx, waker_index, enqueue)
240     }
241 }
242 
243 impl Drop for Permit<'_> {
drop(&mut self)244     fn drop(&mut self) {
245         if self.enqueue {
246             // if `enqueue` is true, `waker_index` must be `Some(_)`.
247             let _ = self.semaphore.waker_list.remove(self.waker_index.unwrap());
248         }
249     }
250 }
251