// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::future::Future;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};

use crate::sync::atomic_waker::AtomicWaker;
use crate::sync::error::{RecvError, SendError, TryRecvError, TrySendError};
use crate::sync::mpsc::Container;
use crate::sync::wake_list::WakerList;

/// Number of bits the logical index is shifted left within `tail`; bit 0 is
/// reserved for the `CLOSED` flag.
const INDEX_SHIFT: usize = 1;
/// The flag that marks the `Array` as closed.
const CLOSED: usize = 0b01;

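/// A single slot of the array.
///
/// `index` encodes the slot's state. For the slot at physical position `i`,
/// `index == i + n * capacity` means the slot is empty and ready for the
/// sender claiming that logical position in cycle `n`; after a write the
/// sender subtracts one, signalling that the slot holds data; after a read
/// the receiver adds `capacity + 1`, handing the slot to the next cycle.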
pub(crate) struct Node<T> {
    index: AtomicUsize,
    value: RefCell<MaybeUninit<T>>,
}

/// Bounded lockless queue.
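///
/// `head` is only accessed by the single receiver and therefore lives in a
/// `RefCell`, while `tail` is shared among senders and is atomic. Bit 0 of
/// `tail` carries the `CLOSED` flag; the remaining bits carry the logical
/// tail index (see `INDEX_SHIFT`).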
pub(crate) struct Array<T> {
    head: RefCell<usize>,
    tail: AtomicUsize,
    capacity: usize,
    rx_waker: AtomicWaker,
    waiters: WakerList,
    data: Box<[Node<T>]>,
}

unsafe impl<T: Send> Send for Array<T> {}
unsafe impl<T: Send> Sync for Array<T> {}

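/// The result of attempting to claim a slot for sending.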
pub(crate) enum SendPosition {
    Pos(usize),
    Full,
    Closed,
}

impl<T> Array<T> {
    pub(crate) fn new(capacity: usize) -> Array<T> {
        assert!(capacity > 0, "Capacity cannot be zero.");
        let data = (0..capacity)
            .map(|i| Node {
                index: AtomicUsize::new(i),
                value: RefCell::new(MaybeUninit::uninit()),
            })
            .collect();
        Array {
            head: RefCell::new(0),
            tail: AtomicUsize::new(0),
            capacity,
            rx_waker: AtomicWaker::new(),
            waiters: WakerList::new(),
            data,
        }
    }

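    /// Tries to claim the current tail slot for a send.
    ///
    /// The slot can only be claimed when its `index` matches the logical
    /// tail, which rules out senders from different cycles writing to the
    /// same slot; on contention the CAS on `tail` is retried with the
    /// freshly observed value.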
    fn prepare_send(&self) -> SendPosition {
        let mut tail = self.tail.load(Acquire);
        loop {
            if tail & CLOSED == CLOSED {
                return SendPosition::Closed;
            }
            let index = (tail >> INDEX_SHIFT) % self.capacity;
            // index is bounded by capacity, unwrap is safe
            let node = self.data.get(index).unwrap();
            let node_index = node.index.load(Acquire);

            // Compare the index of the node with the tail to avoid senders in
            // different cycles writing data to the same slot at the same time.
            if (tail >> INDEX_SHIFT) == node_index {
                match self.tail.compare_exchange(
                    tail,
                    tail.wrapping_add(1 << INDEX_SHIFT),
                    AcqRel,
                    Acquire,
                ) {
                    Ok(_) => return SendPosition::Pos(index),
                    Err(actual) => tail = actual,
                }
            } else {
                return SendPosition::Full;
            }
        }
    }

    pub(crate) fn write(&self, index: usize, value: T) {
        // index is bounded by capacity, unwrap is safe
        let node = self.data.get(index).unwrap();
        node.value.borrow_mut().write(value);

        // Mark that the node has data by decrementing its index, then wake
        // the receiver if it is waiting.
        node.index.fetch_sub(1, Release);
        self.rx_waker.wake();
    }

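    /// Waits until a slot can be claimed for sending and returns its
    /// position, or returns `SendPosition::Closed` if the channel closes
    /// first; `SendPosition::Full` is never returned from here.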
    pub(crate) async fn get_position(&self) -> SendPosition {
        Position { array: self }.await
    }

    pub(crate) fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        match self.prepare_send() {
            SendPosition::Pos(index) => {
                self.write(index, value);
                Ok(())
            }
            SendPosition::Full => Err(TrySendError::Full(value)),
            SendPosition::Closed => Err(TrySendError::Closed(value)),
        }
    }

    pub(crate) async fn send(&self, value: T) -> Result<(), SendError<T>> {
        match self.get_position().await {
            SendPosition::Pos(index) => {
                self.write(index, value);
                Ok(())
            }
            SendPosition::Closed => Err(SendError(value)),
            // If the array is full, the task waits until a slot becomes
            // available, so `get_position` never resolves to `Full`.
            SendPosition::Full => unreachable!(),
        }
    }

    pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
        let head = *self.head.borrow();
        let index = head % self.capacity;
        // index is bounded by capacity, unwrap is safe
        let node = self.data.get(index).unwrap();
        let node_index = node.index.load(Acquire);

        // Check whether the node has data.
        if head == node_index.wrapping_add(1) {
            let value = unsafe { node.value.as_ptr().read().assume_init() };
            // Adding one marks this slot as empty; adding `capacity` on top
            // of that lets the sender of the corresponding tail position in
            // the next cycle write in.
            node.index.fetch_add(self.capacity + 1, Release);
            self.waiters.notify_one();
            self.head.replace(head + 1);
            Ok(value)
        } else {
            let tail = self.tail.load(Acquire);
            if tail & CLOSED == CLOSED {
                Err(TryRecvError::Closed)
            } else {
                Err(TryRecvError::Empty)
            }
        }
    }

    pub(crate) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.try_recv() {
            Ok(val) => return Ready(Ok(val)),
            Err(TryRecvError::Closed) => return Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => {}
        }

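        // Register the waker before checking again: a sender may write a
        // value between the failed `try_recv` above and the registration,
        // and its wake-up call would otherwise be lost.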
        self.rx_waker.register_by_ref(cx.waker());

        match self.try_recv() {
            Ok(val) => Ready(Ok(val)),
            Err(TryRecvError::Closed) => Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => Pending,
        }
    }

    pub(crate) fn capacity(&self) -> usize {
        self.capacity
    }
}

impl<T> Container for Array<T> {
    fn close(&self) {
        self.tail.fetch_or(CLOSED, Release);
        self.waiters.notify_all();
        self.rx_waker.wake();
    }

    fn is_close(&self) -> bool {
        self.tail.load(Acquire) & CLOSED == CLOSED
    }

    fn len(&self) -> usize {
        let head = *self.head.borrow();
        let tail = self.tail.load(Acquire) >> INDEX_SHIFT;
        tail - head
    }
}

impl<T> Drop for Array<T> {
    fn drop(&mut self) {
        let len = self.len();
        if len == 0 {
            return;
        }
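        // Drop every value that was sent but never received; `MaybeUninit`
        // does not run destructors on its own, so skipping this would leak
        // the remaining elements.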
        let head = *self.head.borrow();
        for i in 0..len {
            let mut index = head + i;
            index %= self.capacity;
            // index is bounded by capacity, unwrap is safe
            let node = self.data.get(index).unwrap();
            unsafe {
                node.value.borrow_mut().as_mut_ptr().drop_in_place();
            }
        }
    }
}

struct Position<'a, T> {
    array: &'a Array<T>,
}

impl<T> Future for Position<'_, T> {
    type Output = SendPosition;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.array.prepare_send() {
            SendPosition::Pos(index) => return Ready(SendPosition::Pos(index)),
            SendPosition::Closed => return Ready(SendPosition::Closed),
            SendPosition::Full => {}
        }

        self.array.waiters.insert(cx.waker().clone());

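        // Re-check after queuing the waker: the channel may have closed, or
        // the receiver may have freed a slot and already consumed the
        // matching notification, between the failed `prepare_send` above and
        // the insertion. Waking one waiter (possibly this task) forces a
        // retry in that case.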
        let tail = self.array.tail.load(Acquire);
        let index = (tail >> INDEX_SHIFT) % self.array.capacity;
        // index is bounded by capacity, unwrap is safe
        let node = self.array.data.get(index).unwrap();
        let node_index = node.index.load(Acquire);
        if (tail >> INDEX_SHIFT) == node_index || tail & CLOSED == CLOSED {
            self.array.waiters.notify_one();
        }
        Pending
    }
}
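
// A minimal sanity sketch of the synchronous half of the API, assuming the
// crate's usual `#[cfg(test)]` setup; the module and test names here are
// illustrative. It exercises `try_send`/`try_recv` wrap-around on a
// capacity-2 array.
#[cfg(test)]
mod array_sketch_tests {
    use super::*;

    #[test]
    fn try_send_try_recv_roundtrip() {
        let array = Array::new(2);
        array.try_send(1).unwrap();
        array.try_send(2).unwrap();
        // The array is full: the value is handed back inside the error.
        assert!(matches!(array.try_send(3), Err(TrySendError::Full(3))));

        assert_eq!(array.try_recv().unwrap(), 1);
        // A slot was freed, so sending succeeds again and wraps around.
        array.try_send(3).unwrap();
        assert_eq!(array.try_recv().unwrap(), 2);
        assert_eq!(array.try_recv().unwrap(), 3);
        assert!(matches!(array.try_recv(), Err(TryRecvError::Empty)));
    }
}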