1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::cell::RefCell; 15 use std::future::Future; 16 use std::mem::MaybeUninit; 17 use std::pin::Pin; 18 use std::sync::atomic::AtomicUsize; 19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; 20 use std::task::Poll::{Pending, Ready}; 21 use std::task::{Context, Poll}; 22 23 use crate::sync::atomic_waker::AtomicWaker; 24 use crate::sync::error::{RecvError, SendError, TryRecvError, TrySendError}; 25 use crate::sync::mpsc::Container; 26 use crate::sync::wake_list::WakerList; 27 28 /// The offset of the index. 29 const INDEX_SHIFT: usize = 1; 30 /// The flag marks that Array is closed. 31 const CLOSED: usize = 0b01; 32 33 pub(crate) struct Node<T> { 34 index: AtomicUsize, 35 value: RefCell<MaybeUninit<T>>, 36 } 37 38 /// Bounded lockless queue. 39 pub(crate) struct Array<T> { 40 head: RefCell<usize>, 41 tail: AtomicUsize, 42 capacity: usize, 43 rx_waker: AtomicWaker, 44 waiters: WakerList, 45 data: Box<[Node<T>]>, 46 } 47 48 unsafe impl<T: Send> Send for Array<T> {} 49 unsafe impl<T: Send> Sync for Array<T> {} 50 51 pub(crate) enum SendPosition { 52 Pos(usize), 53 Full, 54 Closed, 55 } 56 57 impl<T> Array<T> { new(capacity: usize) -> Array<T>58 pub(crate) fn new(capacity: usize) -> Array<T> { 59 assert!(capacity > 0, "Capacity cannot be zero."); 60 let data = (0..capacity) 61 .map(|i| Node { 62 index: AtomicUsize::new(i), 63 value: RefCell::new(MaybeUninit::uninit()), 64 }) 65 .collect(); 66 Array { 67 head: RefCell::new(0), 68 tail: AtomicUsize::new(0), 69 capacity, 70 rx_waker: AtomicWaker::new(), 71 waiters: WakerList::new(), 72 data, 73 } 74 } 75 prepare_send(&self) -> SendPosition76 fn prepare_send(&self) -> SendPosition { 77 let mut tail = self.tail.load(Acquire); 78 loop { 79 if tail & CLOSED == CLOSED { 80 return SendPosition::Closed; 81 } 82 let index = (tail >> INDEX_SHIFT) % self.capacity; 83 // index is bounded by capacity, unwrap is safe 84 let node = self.data.get(index).unwrap(); 85 let node_index = node.index.load(Acquire); 86 87 // Compare the index of the node with the tail to avoid senders in different 88 // cycles writing data to the same point at the same time. 89 if (tail >> INDEX_SHIFT) == node_index { 90 match self.tail.compare_exchange( 91 tail, 92 tail.wrapping_add(1 << INDEX_SHIFT), 93 AcqRel, 94 Acquire, 95 ) { 96 Ok(_) => return SendPosition::Pos(index), 97 Err(actual) => tail = actual, 98 } 99 } else { 100 return SendPosition::Full; 101 } 102 } 103 } 104 write(&self, index: usize, value: T)105 pub(crate) fn write(&self, index: usize, value: T) { 106 // index is bounded by capacity, unwrap is safe 107 let node = self.data.get(index).unwrap(); 108 node.value.borrow_mut().write(value); 109 110 // Mark that the node has data. 111 node.index.fetch_sub(1, Release); 112 self.rx_waker.wake(); 113 } 114 get_position(&self) -> SendPosition115 pub(crate) async fn get_position(&self) -> SendPosition { 116 Position { array: self }.await 117 } 118 try_send(&self, value: T) -> Result<(), TrySendError<T>>119 pub(crate) fn try_send(&self, value: T) -> Result<(), TrySendError<T>> { 120 match self.prepare_send() { 121 SendPosition::Pos(index) => { 122 self.write(index, value); 123 Ok(()) 124 } 125 SendPosition::Full => Err(TrySendError::Full(value)), 126 SendPosition::Closed => Err(TrySendError::Closed(value)), 127 } 128 } 129 send(&self, value: T) -> Result<(), SendError<T>>130 pub(crate) async fn send(&self, value: T) -> Result<(), SendError<T>> { 131 match self.get_position().await { 132 SendPosition::Pos(index) => { 133 self.write(index, value); 134 Ok(()) 135 } 136 SendPosition::Closed => Err(SendError(value)), 137 // If the array is full, the task will wait until it's available. 138 SendPosition::Full => unreachable!(), 139 } 140 } 141 try_recv(&self) -> Result<T, TryRecvError>142 pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> { 143 let head = *self.head.borrow(); 144 let index = head % self.capacity; 145 // index is bounded by capacity, unwrap is safe 146 let node = self.data.get(index).unwrap(); 147 let node_index = node.index.load(Acquire); 148 149 // Check whether the node has data. 150 if head == node_index.wrapping_add(1) { 151 let value = unsafe { node.value.as_ptr().read().assume_init() }; 152 // Adding one indicates that this point is empty, Adding <capacity> enables the 153 // corresponding tail node to write in. 154 node.index.fetch_add(self.capacity + 1, Release); 155 self.waiters.notify_one(); 156 self.head.replace(head + 1); 157 Ok(value) 158 } else { 159 let tail = self.tail.load(Acquire); 160 if tail & CLOSED == CLOSED { 161 Err(TryRecvError::Closed) 162 } else { 163 Err(TryRecvError::Empty) 164 } 165 } 166 } 167 poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>>168 pub(crate) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { 169 match self.try_recv() { 170 Ok(val) => return Ready(Ok(val)), 171 Err(TryRecvError::Closed) => return Ready(Err(RecvError)), 172 Err(TryRecvError::Empty) => {} 173 } 174 175 self.rx_waker.register_by_ref(cx.waker()); 176 177 match self.try_recv() { 178 Ok(val) => Ready(Ok(val)), 179 Err(TryRecvError::Closed) => Ready(Err(RecvError)), 180 Err(TryRecvError::Empty) => Pending, 181 } 182 } 183 capacity(&self) -> usize184 pub(crate) fn capacity(&self) -> usize { 185 self.capacity 186 } 187 } 188 189 impl<T> Container for Array<T> { close(&self)190 fn close(&self) { 191 self.tail.fetch_or(CLOSED, Release); 192 self.waiters.notify_all(); 193 self.rx_waker.wake(); 194 } 195 is_close(&self) -> bool196 fn is_close(&self) -> bool { 197 self.tail.load(Acquire) & CLOSED == CLOSED 198 } 199 len(&self) -> usize200 fn len(&self) -> usize { 201 let head = *self.head.borrow(); 202 let tail = self.tail.load(Acquire) >> INDEX_SHIFT; 203 tail - head 204 } 205 } 206 207 impl<T> Drop for Array<T> { drop(&mut self)208 fn drop(&mut self) { 209 let len = self.len(); 210 if len == 0 { 211 return; 212 } 213 let head = *self.head.borrow(); 214 for i in 0..len { 215 let mut index = head + i; 216 index %= self.capacity; 217 // index is bounded by capacity, unwrap is safe 218 let node = self.data.get(index).unwrap(); 219 unsafe { 220 node.value.borrow_mut().as_mut_ptr().drop_in_place(); 221 } 222 } 223 } 224 } 225 226 struct Position<'a, T> { 227 array: &'a Array<T>, 228 } 229 230 impl<T> Future for Position<'_, T> { 231 type Output = SendPosition; 232 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>233 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 234 match self.array.prepare_send() { 235 SendPosition::Pos(index) => return Ready(SendPosition::Pos(index)), 236 SendPosition::Closed => return Ready(SendPosition::Closed), 237 SendPosition::Full => {} 238 } 239 240 self.array.waiters.insert(cx.waker().clone()); 241 242 let tail = self.array.tail.load(Acquire); 243 let index = (tail >> INDEX_SHIFT) % self.array.capacity; 244 // index is bounded by capacity, unwrap is safe 245 let node = self.array.data.get(index).unwrap(); 246 let node_index = node.index.load(Acquire); 247 if (tail >> INDEX_SHIFT) == node_index || tail & CLOSED == CLOSED { 248 self.array.waiters.notify_one(); 249 } 250 Pending 251 } 252 } 253