1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::error::Error; 15 use std::fmt::{Debug, Display, Formatter}; 16 use std::future::Future; 17 use std::pin::Pin; 18 use std::sync::atomic::AtomicUsize; 19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; 20 use std::task::Poll::{Pending, Ready}; 21 use std::task::{Context, Poll}; 22 23 use crate::sync::wake_list::WakerList; 24 25 /// Maximum capacity of `Semaphore`. 26 const MAX_PERMITS: usize = usize::MAX >> 1; 27 /// The least significant bit that marks the number of permits. 28 const PERMIT_SHIFT: usize = 1; 29 /// The flag marks that Semaphore is closed. 30 const CLOSED: usize = 1; 31 32 pub(crate) struct SemaphoreInner { 33 permits: AtomicUsize, 34 waker_list: WakerList, 35 } 36 37 pub(crate) struct Permit<'a> { 38 semaphore: &'a SemaphoreInner, 39 waker_index: Option<usize>, 40 enqueue: bool, 41 } 42 43 /// Error returned by `Semaphore`. 44 #[derive(Debug, Eq, PartialEq)] 45 pub enum SemaphoreError { 46 /// The number of Permits is overflowed. 47 Overflow, 48 /// Semaphore doesn't have enough permits. 49 Empty, 50 /// Semaphore was closed. 51 Closed, 52 } 53 54 impl Display for SemaphoreError { fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 56 match self { 57 SemaphoreError::Overflow => write!(f, "permit overflow MAX_PERMITS : {MAX_PERMITS}"), 58 SemaphoreError::Empty => write!(f, "no permits available"), 59 SemaphoreError::Closed => write!(f, "semaphore has been closed"), 60 } 61 } 62 } 63 64 impl Error for SemaphoreError {} 65 66 impl SemaphoreInner { new(permits: usize) -> Result<SemaphoreInner, SemaphoreError>67 pub(crate) fn new(permits: usize) -> Result<SemaphoreInner, SemaphoreError> { 68 if permits >= MAX_PERMITS { 69 return Err(SemaphoreError::Overflow); 70 } 71 Ok(SemaphoreInner { 72 permits: AtomicUsize::new(permits << PERMIT_SHIFT), 73 waker_list: WakerList::new(), 74 }) 75 } 76 current_permits(&self) -> usize77 pub(crate) fn current_permits(&self) -> usize { 78 self.permits.load(Acquire) >> PERMIT_SHIFT 79 } 80 release(&self)81 pub(crate) fn release(&self) { 82 // Get the lock first to ensure the atomicity of the two operations. 83 let mut waker_list = self.waker_list.lock(); 84 if !waker_list.notify_one() { 85 let prev = self.permits.fetch_add(1 << PERMIT_SHIFT, Release); 86 assert!( 87 (prev >> PERMIT_SHIFT) < MAX_PERMITS, 88 "the number of permits will overflow the capacity after addition" 89 ); 90 } 91 } 92 release_notify(&self)93 pub(crate) fn release_notify(&self) { 94 // Get the lock first to ensure the atomicity of the two operations. 95 let mut waker_list = self.waker_list.lock(); 96 if !waker_list.notify_one() { 97 self.permits.store(1 << PERMIT_SHIFT, Release); 98 } 99 } 100 release_multi(&self, mut permits: usize)101 pub(crate) fn release_multi(&self, mut permits: usize) { 102 let mut waker_list = self.waker_list.lock(); 103 while permits > 0 && waker_list.notify_one() { 104 permits -= 1; 105 } 106 let prev = self.permits.fetch_add(permits << PERMIT_SHIFT, Release); 107 assert!( 108 (prev >> PERMIT_SHIFT) < MAX_PERMITS, 109 "the number of permits will overflow the capacity after addition" 110 ); 111 } 112 release_all(&self)113 pub(crate) fn release_all(&self) { 114 self.waker_list.notify_all(); 115 } 116 try_acquire(&self) -> Result<(), SemaphoreError>117 pub(crate) fn try_acquire(&self) -> Result<(), SemaphoreError> { 118 let mut curr = self.permits.load(Acquire); 119 loop { 120 if curr & CLOSED == CLOSED { 121 return Err(SemaphoreError::Closed); 122 } 123 124 if curr == 0 { 125 return Err(SemaphoreError::Empty); 126 } 127 128 match self 129 .permits 130 .compare_exchange(curr, curr - (1 << PERMIT_SHIFT), AcqRel, Acquire) 131 { 132 Ok(_) => { 133 return Ok(()); 134 } 135 Err(actual) => { 136 curr = actual; 137 } 138 } 139 } 140 } 141 is_closed(&self) -> bool142 pub(crate) fn is_closed(&self) -> bool { 143 self.permits.load(Acquire) & CLOSED == CLOSED 144 } 145 close(&self)146 pub(crate) fn close(&self) { 147 // Get the lock first to ensure the atomicity of the two operations. 148 let mut waker_list = self.waker_list.lock(); 149 self.permits.fetch_or(CLOSED, Release); 150 waker_list.notify_all(); 151 } 152 acquire(&self) -> Permit<'_>153 pub(crate) fn acquire(&self) -> Permit<'_> { 154 Permit::new(self) 155 } 156 update_permit( &self, enqueue: &mut bool, curr: &mut usize, permit_num: usize, ) -> Option<Poll<Result<(), SemaphoreError>>>157 fn update_permit( 158 &self, 159 enqueue: &mut bool, 160 curr: &mut usize, 161 permit_num: usize, 162 ) -> Option<Poll<Result<(), SemaphoreError>>> { 163 match self 164 .permits 165 .compare_exchange(*curr, *curr - permit_num, AcqRel, Acquire) 166 { 167 Ok(_) => { 168 if *enqueue { 169 self.release(); 170 return Some(Pending); 171 } 172 return Some(Ready(Ok(()))); 173 } 174 Err(actual) => *curr = actual, 175 } 176 None 177 } 178 poll_acquire( &self, cx: &mut Context<'_>, waker_index: &mut Option<usize>, enqueue: &mut bool, ) -> Poll<Result<(), SemaphoreError>>179 fn poll_acquire( 180 &self, 181 cx: &mut Context<'_>, 182 waker_index: &mut Option<usize>, 183 enqueue: &mut bool, 184 ) -> Poll<Result<(), SemaphoreError>> { 185 let mut curr = self.permits.load(Acquire); 186 if curr & CLOSED == CLOSED { 187 return Ready(Err(SemaphoreError::Closed)); 188 } else if *enqueue { 189 *enqueue = false; 190 return Ready(Ok(())); 191 } 192 let permit_num = 1 << PERMIT_SHIFT; 193 loop { 194 if curr & CLOSED == CLOSED { 195 return Ready(Err(SemaphoreError::Closed)); 196 } 197 if curr >= permit_num { 198 if let Some(res) = self.update_permit(enqueue, &mut curr, permit_num) { 199 return res; 200 } 201 } else if !(*enqueue) { 202 *waker_index = Some(self.waker_list.insert(cx.waker().clone())); 203 *enqueue = true; 204 curr = self.permits.load(Acquire); 205 } else { 206 return Pending; 207 } 208 } 209 } 210 } 211 212 impl Debug for SemaphoreInner { fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result213 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 214 f.debug_struct("Semaphore") 215 .field("permits", &self.current_permits()) 216 .finish() 217 } 218 } 219 220 impl<'a> Permit<'a> { new(semaphore: &'a SemaphoreInner) -> Permit221 fn new(semaphore: &'a SemaphoreInner) -> Permit { 222 Permit { 223 semaphore, 224 waker_index: None, 225 enqueue: false, 226 } 227 } 228 } 229 230 impl Future for Permit<'_> { 231 type Output = Result<(), SemaphoreError>; 232 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>233 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 234 let (semaphore, waker_index, enqueue) = unsafe { 235 let me = self.get_unchecked_mut(); 236 (me.semaphore, &mut me.waker_index, &mut me.enqueue) 237 }; 238 239 semaphore.poll_acquire(cx, waker_index, enqueue) 240 } 241 } 242 243 impl Drop for Permit<'_> { drop(&mut self)244 fn drop(&mut self) { 245 if self.enqueue { 246 // if `enqueue` is true, `waker_index` must be `Some(_)`. 247 let _ = self.semaphore.waker_list.remove(self.waker_index.unwrap()); 248 } 249 } 250 } 251