1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cell::{RefCell, UnsafeCell};
15 use std::collections::{HashSet, LinkedList};
16 use std::future::Future;
17 use std::hash::{Hash, Hasher};
18 use std::mem;
19 use std::mem::ManuallyDrop;
20 use std::pin::Pin;
21 use std::sync::{Arc, Mutex};
22 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
23 
24 use crate::error::ScheduleError;
25 use crate::spawn::spawn_async;
26 use crate::task::join_handle::CancelHandle;
27 use crate::task::{JoinHandle, Qos, TaskBuilder};
28 
29 /// A collection of tasks get spawned on a Ylong runtime
30 ///
31 /// A `JoinSet` will take over the `JoinHandle`s of the tasks when spawning, and
32 /// it can asynchronously wait for the completion of some or all of the tasks
33 /// inside the set. However, `JoinSet` is unordered, which means that tasks'
34 /// results will be returned in the order of their completion.
35 ///
36 /// All the tasks spawned via a `JoinSet` must have the same return type.
37 ///
38 /// # Example
39 ///
40 /// ```
41 /// use ylong_runtime::task::JoinSet;
42 ///
43 /// async fn join_set_spawn() {
44 ///     let mut set = JoinSet::new();
45 ///     set.spawn(async move { 0 });
46 ///     let ret = set.join_next().await.unwrap().unwrap();
47 ///     assert_eq!(ret, 0)
48 /// }
49 /// ```
50 #[derive(Default)]
51 pub struct JoinSet<R> {
52     list: Arc<Mutex<JoinList<R>>>,
53     builder: TaskBuilder,
54 }
55 
// SAFETY (reviewer notes, to be confirmed): `JoinSet<R>` is not automatically
// `Send`/`Sync` because its `JoinList` holds `Arc<JoinEntry<R>>`, and
// `JoinEntry` contains an `UnsafeCell` and a `RefCell`. These manual impls
// rely on the following invariants visible in this file:
// - `JoinEntry::in_done` is only read or written while the `JoinList` mutex
//   is held (see `JoinEntry::wake_by_ref` and `JoinSet::poll_join_next`);
// - the `JoinHandle` inside each entry is only accessed either while the list
//   mutex is held, or from `JoinSet` methods taking `&mut self`;
// - `R: Send`, because task results of type `R` are produced by spawned tasks
//   and handed out through `join_next`.
unsafe impl<R: Send> Send for JoinSet<R> {}

// Presumably sound for the same reasons: every `JoinSet` method in this file
// takes `&mut self`, so a shared `&JoinSet<R>` gives no access to the entries.
unsafe impl<R: Send> Sync for JoinSet<R> {}
59 
/// Shared bookkeeping of a `JoinSet`. It lives behind the `Mutex` in
/// `JoinSet::list` and is also referenced by every `JoinEntry` through
/// `JoinEntry::list`.
pub(crate) struct JoinList<R> {
    // Contains tasks not ready for polling. `JoinEntry::wake_by_ref` moves an
    // entry from here to `done_list` when the task's waker fires.
    wait_list: HashSet<Arc<JoinEntry<R>>>,
    // Contains tasks ready for polling, in the order they were woken.
    // `JoinSet::poll_join_next` pops entries from the front.
    done_list: LinkedList<Arc<JoinEntry<R>>>,
    // Waker of JoinSet, a ready task will wake the JoinSet it belongs to.
    // It is taken (set to `None`) when used and stored again on the next poll.
    waker: Option<Waker>,
    // Number of spawned tasks whose result has not been returned by
    // `poll_join_next` yet (incremented on spawn, decremented on `Ready`).
    len: usize,
}
69 
70 impl<R> Default for JoinList<R> {
default() -> Self71     fn default() -> Self {
72         JoinList {
73             wait_list: HashSet::default(),
74             done_list: LinkedList::default(),
75             waker: None,
76             len: 0,
77         }
78     }
79 }
80 
/// One task tracked by a `JoinSet`. It is shared through `Arc` between the
/// set's `JoinList` and the entry wakers created by `entry_into_waker`.
pub(crate) struct JoinEntry<R> {
    // The JoinHandle of the task. Wrapped in `ManuallyDrop` because the handle
    // is taken out explicitly: by `JoinSet::poll_join_next` once the task is
    // `Ready`, or by `JoinSet`'s `Drop` impl. Wrapped in `UnsafeCell` because
    // it is polled through a shared `Arc`.
    handle: UnsafeCell<ManuallyDrop<JoinHandle<R>>>,
    // The JoinList this task belongs to
    list: Arc<Mutex<JoinList<R>>>,
    // A flag to indicate which list this task is in.
    // `false` means the entry is in the wait list. `true` means it is not:
    // it is in the done list, is currently being polled by `poll_join_next`,
    // or has already returned its result.
    // Only read or written while `list` is locked, which is what makes the
    // non-thread-safe `RefCell` acceptable here.
    in_done: RefCell<bool>,
}
90 
91 impl<R> JoinSet<R> {
92     /// Creates a new JoinSet.
new() -> Self93     pub fn new() -> Self {
94         Self {
95             list: Default::default(),
96             builder: Default::default(),
97         }
98     }
99 }
100 
101 impl<R> PartialEq<Self> for JoinEntry<R> {
eq(&self, other: &Self) -> bool102     fn eq(&self, other: &Self) -> bool {
103         unsafe { (*(self.handle.get())).raw.eq(&(*(other.handle.get())).raw) }
104     }
105 }
106 
107 impl<R> Eq for JoinEntry<R> {}
108 
109 impl<R> Hash for JoinEntry<R> {
hash<H: Hasher>(&self, state: &mut H)110     fn hash<H: Hasher>(&self, state: &mut H) {
111         unsafe { (*self.handle.get()).raw.hash(state) }
112     }
113 }
114 
115 impl<R> JoinEntry<R> {
116     // When waking a JoinEntry, the entry will get popped out of the wait list and
117     // pushed into the ready list. The corresponding in_done flag will also be
118     // changed. Safety: it will take the list's lock before moving the entry, so
119     // it's concurrently safe.
wake_by_ref(entry: &Arc<JoinEntry<R>>)120     fn wake_by_ref(entry: &Arc<JoinEntry<R>>) {
121         let mut list = entry.list.lock().unwrap();
122         if !entry.in_done.replace(true) {
123             // We couldn't find the entry, meaning that the JoinSet has been dropped
124             // already. In this case, there is no need to push the entry back to
125             // the done list.
126             if !list.wait_list.remove(entry) {
127                 return;
128             }
129             list.done_list.push_back(entry.clone());
130             // Wake the JoinSet if an waker is set
131             if let Some(waker) = list.waker.take() {
132                 drop(list);
133                 waker.wake();
134             }
135         }
136     }
137 }
138 
139 impl<R> JoinSet<R> {
140     /// Spawns a task via a `JoinSet` onto a Ylong runtime. The task will start
141     /// immediately when `spawn` is called.
142     ///
143     /// # Panics
144     /// This method panics when calling outside of Ylong runtime.
145     ///
146     /// # Examples
147     ///
148     /// ```
149     /// use ylong_runtime::task::JoinSet;
150     /// ylong_runtime::block_on(async move {
151     ///     let mut set = JoinSet::new();
152     ///     let cancel_handle = set.spawn(async move { 1 });
153     ///     cancel_handle.cancel();
154     /// });
155     /// ```
spawn<T>(&mut self, task: T) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,156     pub fn spawn<T>(&mut self, task: T) -> CancelHandle
157     where
158         T: Future<Output = R> + Send + 'static,
159         R: Send + 'static,
160     {
161         self.spawn_inner(task, None)
162     }
163 
spawn_inner<T>(&mut self, task: T, builder: Option<&TaskBuilder>) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,164     fn spawn_inner<T>(&mut self, task: T, builder: Option<&TaskBuilder>) -> CancelHandle
165     where
166         T: Future<Output = R> + Send + 'static,
167         R: Send + 'static,
168     {
169         let handle = match builder {
170             None => spawn_async(&self.builder, task),
171             Some(builder) => builder.spawn(task),
172         };
173         let cancel = handle.get_cancel_handle();
174         let entry = Arc::new(JoinEntry {
175             handle: UnsafeCell::new(ManuallyDrop::new(handle)),
176             list: self.list.clone(),
177             in_done: RefCell::new(false),
178         });
179         let mut list = self.list.lock().unwrap();
180         list.len += 1;
181         list.wait_list.insert(entry.clone());
182         drop(list);
183         let waker = entry_into_waker(&entry);
184         unsafe {
185             (*entry.handle.get()).set_waker(&waker);
186         }
187         cancel
188     }
189 
190     /// Waits until one task inside the `JoinSet` completes and returns its
191     /// output.
192     ///
193     /// Returns `None` if there is no task inside the set.
194     ///
195     /// # Examples
196     ///
197     /// ```
198     /// use ylong_runtime::task::JoinSet;
199     /// ylong_runtime::block_on(async move {
200     ///     let mut set = JoinSet::new();
201     ///     set.spawn(async move { 1 });
202     ///     let ret = set.join_next().await.unwrap().unwrap();
203     ///     assert_eq!(ret, 1);
204     ///     // no more task, so this `join_next` will return none
205     ///     let ret = set.join_next().await;
206     ///     assert!(ret.is_none());
207     /// });
208     /// ```
join_next(&mut self) -> Option<Result<R, ScheduleError>>209     pub async fn join_next(&mut self) -> Option<Result<R, ScheduleError>> {
210         use crate::futures::poll_fn;
211         poll_fn(|cx| self.poll_join_next(cx)).await
212     }
213 
214     /// Waits for all tasks inside the set to finish.
join_all(&mut self) -> Result<(), ScheduleError>215     pub async fn join_all(&mut self) -> Result<(), ScheduleError> {
216         // todo: take the lock only once
217         let count = self.list.lock().unwrap().len;
218         for _ in 0..count {
219             match self.join_next().await {
220                 None => return Ok(()),
221                 Some(Ok(_)) => {}
222                 Some(Err(e)) => return Err(e),
223             }
224         }
225         Ok(())
226     }
227 
228     /// Cancels every tasks inside the JoinSet.
229     ///
230     /// If [`JoinSet::join_next`] is called after calling `cancel_all`, then it
231     /// would return `TaskCanceled` error.
232     ///
233     /// # Examples
234     /// ```
235     /// use ylong_runtime::task::JoinSet;
236     /// ylong_runtime::block_on(async move {
237     ///     let mut set = JoinSet::new();
238     ///     set.spawn(async move { 1 });
239     ///     set.cancel_all();
240     /// });
241     /// ```
cancel_all(&mut self)242     pub fn cancel_all(&mut self) {
243         let list = self.list.lock().unwrap();
244         for item in &list.done_list {
245             unsafe { (*item.handle.get()).cancel() }
246         }
247         for item in &list.wait_list {
248             unsafe { (*item.handle.get()).cancel() }
249         }
250     }
251 
252     /// Cancels every tasks inside and clears all entries inside
253     ///
254     /// # Examples
255     /// ```
256     /// use ylong_runtime::task::JoinSet;
257     /// ylong_runtime::block_on(async move {
258     ///     let mut set = JoinSet::new();
259     ///     set.spawn(async move { 1 });
260     ///     set.shutdown();
261     /// });
262     /// ```
shutdown(&mut self)263     pub async fn shutdown(&mut self) {
264         self.cancel_all();
265         while self.join_next().await.is_some() {}
266     }
267 
268     /// Creates a builder that configures task attributes. This builder could
269     /// spawn tasks with its attributes onto the JoinSet.
270     ///
271     /// # Examples
272     /// ```
273     /// use ylong_runtime::task::JoinSet;
274     /// ylong_runtime::block_on(async move {
275     ///     let mut set = JoinSet::new();
276     ///     let mut builder = set.build_task().name("hello".into());
277     ///     builder.spawn(async move { 1 });
278     /// });
279     /// ```
build_task(&mut self) -> Builder<'_, R>280     pub fn build_task(&mut self) -> Builder<'_, R> {
281         Builder::new(self)
282     }
283 
poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<R, ScheduleError>>>284     fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<R, ScheduleError>>> {
285         let mut list = self.list.lock().unwrap();
286 
287         // quick path: check if the set is empty, return none if true
288         if list.len == 0 {
289             return Poll::Ready(None);
290         }
291 
292         // set the joinset's waker if it's not set
293         let is_same_waker = match list.waker.as_ref() {
294             None => false,
295             Some(waker) => cx.waker().will_wake(waker),
296         };
297 
298         if !is_same_waker {
299             list.waker = Some(cx.waker().clone());
300         }
301 
302         // pop a ready task from the done list and poll it
303         if let Some(entry) = list.done_list.pop_front() {
304             drop(list);
305             let waker = entry_into_waker(&entry);
306             let mut ctx = Context::from_waker(&waker);
307             // We have to dereference the JoinHandle from the UnsafeCell in order to poll
308             // it. The lifetime of the handle is valid here since it's wrapped
309             // by a ManuallyDrop. It will only get dropped when the task returns
310             // ready, and by the time, the entry is also dropped, and could
311             // never be popped from the done list once again.
312             unsafe {
313                 match Pin::new(&mut **(entry.handle.get())).poll(&mut ctx) {
314                     Poll::Ready(res) => {
315                         let mut list = self.list.lock().unwrap();
316                         list.len -= 1;
317                         drop(list);
318                         // drop the JoinHandle and return it's result
319                         drop(ManuallyDrop::take(&mut *entry.handle.get()));
320                         Poll::Ready(Some(res))
321                     }
322                     Poll::Pending => {
323                         let mut list = self.list.lock().unwrap();
324                         // The future hasn't finished, push it back to wait-list
325                         let _ = entry.in_done.replace(false);
326                         list.wait_list.insert(entry);
327                         drop(list);
328                         cx.waker().wake_by_ref();
329                         Poll::Pending
330                     }
331                 }
332             }
333         } else {
334             // there is no task, return none
335             if list.len == 0 {
336                 Poll::Ready(None)
337             } else {
338                 // no ready task, return pending
339                 Poll::Pending
340             }
341         }
342     }
343 }
344 
/// A TaskBuilder for tasks that get spawned on a specific JoinSet
///
/// Created by `JoinSet::build_task`; it mutably borrows the set for as long as
/// the builder is alive.
pub struct Builder<'a, R> {
    // Task attributes (e.g. name, QoS) applied to every task spawned through
    // this builder.
    builder: TaskBuilder,
    // The set the spawned tasks are added to.
    set: &'a mut JoinSet<R>,
}
350 
351 impl<'a, R> Builder<'a, R> {
new(set: &'a mut JoinSet<R>) -> Builder<'a, R>352     pub(crate) fn new(set: &'a mut JoinSet<R>) -> Builder<'a, R> {
353         Builder {
354             builder: TaskBuilder::new(),
355             set,
356         }
357     }
358 
359     /// Sets the name for the tasks that are going to get spawned by this
360     /// JoinSet Builder
361     ///
362     /// # Examples
363     /// ```
364     /// use ylong_runtime::task::JoinSet;
365     /// ylong_runtime::block_on(async move {
366     ///     let mut set = JoinSet::new();
367     ///     let mut builder = set.build_task().name("hello".into());
368     ///     builder.spawn(async move { 1 });
369     /// });
370     /// ```
name(self, name: String) -> Self371     pub fn name(self, name: String) -> Self {
372         let builder = self.builder.name(name);
373         Self {
374             builder,
375             set: self.set,
376         }
377     }
378 
379     /// Sets the QOS for the tasks that are going to get spawned by this
380     /// JoinSet Builder
381     ///
382     /// # Examples
383     /// ```
384     /// use ylong_runtime::task::{JoinSet, Qos};
385     /// ylong_runtime::block_on(async move {
386     ///     let mut set = JoinSet::new();
387     ///     let mut builder = set.build_task().qos(Qos::UserInitiated);
388     ///     builder.spawn(async move { 1 });
389     /// });
390     /// ```
qos(self, qos: Qos) -> Self391     pub fn qos(self, qos: Qos) -> Self {
392         let builder = self.builder.qos(qos);
393         Self {
394             builder,
395             set: self.set,
396         }
397     }
398 
399     /// Spawns a task via a `JoinSet` onto a Ylong runtime. The task will start
400     /// immediately when `spawn` is called.
401     ///
402     /// # Panics
403     /// This method panics when calling outside of Ylong runtime.
404     /// # Examples
405     /// ```
406     /// use ylong_runtime::task::JoinSet;
407     /// ylong_runtime::block_on(async move {
408     ///     let mut set = JoinSet::new();
409     ///     let mut builder = set.build_task();
410     ///     builder.spawn(async move { 1 });
411     /// });
412     /// ```
spawn<T>(&mut self, task: T) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,413     pub fn spawn<T>(&mut self, task: T) -> CancelHandle
414     where
415         T: Future<Output = R> + Send + 'static,
416         R: Send + 'static,
417     {
418         self.set.spawn_inner(task, Some(&self.builder))
419     }
420 }
421 
422 /// Cancels all task inside, and frees all corresponding JoinHandle.
423 impl<R> Drop for JoinSet<R> {
drop(&mut self)424     fn drop(&mut self) {
425         let mut list = self.list.lock().unwrap();
426 
427         for item in &list.done_list {
428             unsafe {
429                 (*item.handle.get()).cancel();
430                 drop(ManuallyDrop::take(&mut *item.handle.get()));
431             }
432         }
433         for item in &list.wait_list {
434             unsafe {
435                 (*item.handle.get()).cancel();
436                 drop(ManuallyDrop::take(&mut *item.handle.get()));
437             }
438         }
439         // pop every entry inside to reduce the ref count of the list
440         while list.done_list.pop_back().is_some() {}
441         list.wait_list.drain();
442     }
443 }
444 
// Gets the vtable of the entry waker: the `RawWakerVTable` shared by every
// waker built by `entry_into_waker` for result type `R`.
// The `&RawWakerVTable::new(...)` expression is promoted to a `'static`
// constant, which is what allows returning `&'static` without a `static` item
// (a `static` could not mention the generic parameter `R`).
fn get_entry_waker_table<R>() -> &'static RawWakerVTable {
    &RawWakerVTable::new(
        clone_entry::<R>,
        wake_entry::<R>,
        wake_entry_ref::<R>,
        drop_entry::<R>,
    )
}
454 
455 // Converts a entry reference into a Waker
entry_into_waker<R>(entry: &Arc<JoinEntry<R>>) -> Waker456 fn entry_into_waker<R>(entry: &Arc<JoinEntry<R>>) -> Waker {
457     let cpy = entry.clone();
458     let data = Arc::into_raw(cpy).cast::<()>();
459     unsafe { Waker::from_raw(RawWaker::new(data, get_entry_waker_table::<R>())) }
460 }
461 
clone_entry<R>(data: *const ()) -> RawWaker462 unsafe fn clone_entry<R>(data: *const ()) -> RawWaker {
463     // First increment the arc counter
464     let entry = Arc::from_raw(data.cast::<JoinEntry<R>>());
465     mem::forget(entry.clone());
466     // Construct the new waker
467     let data = Arc::into_raw(entry).cast::<()>();
468     RawWaker::new(data, get_entry_waker_table::<R>())
469 }
470 
wake_entry<R>(data: *const ())471 unsafe fn wake_entry<R>(data: *const ()) {
472     let entry = Arc::from_raw(data.cast::<JoinEntry<R>>());
473     JoinEntry::wake_by_ref(&entry);
474 }
475 
wake_entry_ref<R>(data: *const ())476 unsafe fn wake_entry_ref<R>(data: *const ()) {
477     let entry = ManuallyDrop::new(Arc::from_raw(data.cast::<JoinEntry<R>>()));
478     JoinEntry::wake_by_ref(&entry);
479 }
480 
drop_entry<R>(data: *const ())481 unsafe fn drop_entry<R>(data: *const ()) {
482     drop(Arc::from_raw(data.cast::<JoinEntry<R>>()))
483 }
484