// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14 use std::cell::{RefCell, UnsafeCell};
15 use std::collections::{HashSet, LinkedList};
16 use std::future::Future;
17 use std::hash::{Hash, Hasher};
18 use std::mem;
19 use std::mem::ManuallyDrop;
20 use std::pin::Pin;
21 use std::sync::{Arc, Mutex};
22 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
23
24 use crate::error::ScheduleError;
25 use crate::spawn::spawn_async;
26 use crate::task::join_handle::CancelHandle;
27 use crate::task::{JoinHandle, Qos, TaskBuilder};
28
29 /// A collection of tasks get spawned on a Ylong runtime
30 ///
31 /// A `JoinSet` will take over the `JoinHandle`s of the tasks when spawning, and
32 /// it can asynchronously wait for the completion of some or all of the tasks
33 /// inside the set. However, `JoinSet` is unordered, which means that tasks'
34 /// results will be returned in the order of their completion.
35 ///
36 /// All the tasks spawned via a `JoinSet` must have the same return type.
37 ///
38 /// # Example
39 ///
40 /// ```
41 /// use ylong_runtime::task::JoinSet;
42 ///
43 /// async fn join_set_spawn() {
44 /// let mut set = JoinSet::new();
45 /// set.spawn(async move { 0 });
46 /// let ret = set.join_next().await.unwrap().unwrap();
47 /// assert_eq!(ret, 0)
48 /// }
49 /// ```
50 #[derive(Default)]
51 pub struct JoinSet<R> {
52 list: Arc<Mutex<JoinList<R>>>,
53 builder: TaskBuilder,
54 }
55
56 unsafe impl<R: Send> Send for JoinSet<R> {}
57
58 unsafe impl<R: Send> Sync for JoinSet<R> {}
59
60 pub(crate) struct JoinList<R> {
61 // Contains tasks not ready for polling
62 wait_list: HashSet<Arc<JoinEntry<R>>>,
63 // Contains tasks ready for polling
64 done_list: LinkedList<Arc<JoinEntry<R>>>,
65 // Waker of JoinSet, a ready task will wake the JoinSet it belongs to
66 waker: Option<Waker>,
67 len: usize,
68 }
69
70 impl<R> Default for JoinList<R> {
default() -> Self71 fn default() -> Self {
72 JoinList {
73 wait_list: HashSet::default(),
74 done_list: LinkedList::default(),
75 waker: None,
76 len: 0,
77 }
78 }
79 }
80
81 pub(crate) struct JoinEntry<R> {
82 // The JoinHandle of the task
83 handle: UnsafeCell<ManuallyDrop<JoinHandle<R>>>,
84 // The JoinList this task belongs to
85 list: Arc<Mutex<JoinList<R>>>,
86 // A flag to indicate which list this task is in.
87 // `true` means the entry is in the done list.
88 in_done: RefCell<bool>,
89 }
90
91 impl<R> JoinSet<R> {
92 /// Creates a new JoinSet.
new() -> Self93 pub fn new() -> Self {
94 Self {
95 list: Default::default(),
96 builder: Default::default(),
97 }
98 }
99 }
100
101 impl<R> PartialEq<Self> for JoinEntry<R> {
eq(&self, other: &Self) -> bool102 fn eq(&self, other: &Self) -> bool {
103 unsafe { (*(self.handle.get())).raw.eq(&(*(other.handle.get())).raw) }
104 }
105 }
106
107 impl<R> Eq for JoinEntry<R> {}
108
109 impl<R> Hash for JoinEntry<R> {
hash<H: Hasher>(&self, state: &mut H)110 fn hash<H: Hasher>(&self, state: &mut H) {
111 unsafe { (*self.handle.get()).raw.hash(state) }
112 }
113 }
114
115 impl<R> JoinEntry<R> {
116 // When waking a JoinEntry, the entry will get popped out of the wait list and
117 // pushed into the ready list. The corresponding in_done flag will also be
118 // changed. Safety: it will take the list's lock before moving the entry, so
119 // it's concurrently safe.
wake_by_ref(entry: &Arc<JoinEntry<R>>)120 fn wake_by_ref(entry: &Arc<JoinEntry<R>>) {
121 let mut list = entry.list.lock().unwrap();
122 if !entry.in_done.replace(true) {
123 // We couldn't find the entry, meaning that the JoinSet has been dropped
124 // already. In this case, there is no need to push the entry back to
125 // the done list.
126 if !list.wait_list.remove(entry) {
127 return;
128 }
129 list.done_list.push_back(entry.clone());
130 // Wake the JoinSet if an waker is set
131 if let Some(waker) = list.waker.take() {
132 drop(list);
133 waker.wake();
134 }
135 }
136 }
137 }
138
139 impl<R> JoinSet<R> {
140 /// Spawns a task via a `JoinSet` onto a Ylong runtime. The task will start
141 /// immediately when `spawn` is called.
142 ///
143 /// # Panics
144 /// This method panics when calling outside of Ylong runtime.
145 ///
146 /// # Examples
147 ///
148 /// ```
149 /// use ylong_runtime::task::JoinSet;
150 /// ylong_runtime::block_on(async move {
151 /// let mut set = JoinSet::new();
152 /// let cancel_handle = set.spawn(async move { 1 });
153 /// cancel_handle.cancel();
154 /// });
155 /// ```
spawn<T>(&mut self, task: T) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,156 pub fn spawn<T>(&mut self, task: T) -> CancelHandle
157 where
158 T: Future<Output = R> + Send + 'static,
159 R: Send + 'static,
160 {
161 self.spawn_inner(task, None)
162 }
163
spawn_inner<T>(&mut self, task: T, builder: Option<&TaskBuilder>) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,164 fn spawn_inner<T>(&mut self, task: T, builder: Option<&TaskBuilder>) -> CancelHandle
165 where
166 T: Future<Output = R> + Send + 'static,
167 R: Send + 'static,
168 {
169 let handle = match builder {
170 None => spawn_async(&self.builder, task),
171 Some(builder) => builder.spawn(task),
172 };
173 let cancel = handle.get_cancel_handle();
174 let entry = Arc::new(JoinEntry {
175 handle: UnsafeCell::new(ManuallyDrop::new(handle)),
176 list: self.list.clone(),
177 in_done: RefCell::new(false),
178 });
179 let mut list = self.list.lock().unwrap();
180 list.len += 1;
181 list.wait_list.insert(entry.clone());
182 drop(list);
183 let waker = entry_into_waker(&entry);
184 unsafe {
185 (*entry.handle.get()).set_waker(&waker);
186 }
187 cancel
188 }
189
190 /// Waits until one task inside the `JoinSet` completes and returns its
191 /// output.
192 ///
193 /// Returns `None` if there is no task inside the set.
194 ///
195 /// # Examples
196 ///
197 /// ```
198 /// use ylong_runtime::task::JoinSet;
199 /// ylong_runtime::block_on(async move {
200 /// let mut set = JoinSet::new();
201 /// set.spawn(async move { 1 });
202 /// let ret = set.join_next().await.unwrap().unwrap();
203 /// assert_eq!(ret, 1);
204 /// // no more task, so this `join_next` will return none
205 /// let ret = set.join_next().await;
206 /// assert!(ret.is_none());
207 /// });
208 /// ```
join_next(&mut self) -> Option<Result<R, ScheduleError>>209 pub async fn join_next(&mut self) -> Option<Result<R, ScheduleError>> {
210 use crate::futures::poll_fn;
211 poll_fn(|cx| self.poll_join_next(cx)).await
212 }
213
214 /// Waits for all tasks inside the set to finish.
join_all(&mut self) -> Result<(), ScheduleError>215 pub async fn join_all(&mut self) -> Result<(), ScheduleError> {
216 // todo: take the lock only once
217 let count = self.list.lock().unwrap().len;
218 for _ in 0..count {
219 match self.join_next().await {
220 None => return Ok(()),
221 Some(Ok(_)) => {}
222 Some(Err(e)) => return Err(e),
223 }
224 }
225 Ok(())
226 }
227
228 /// Cancels every tasks inside the JoinSet.
229 ///
230 /// If [`JoinSet::join_next`] is called after calling `cancel_all`, then it
231 /// would return `TaskCanceled` error.
232 ///
233 /// # Examples
234 /// ```
235 /// use ylong_runtime::task::JoinSet;
236 /// ylong_runtime::block_on(async move {
237 /// let mut set = JoinSet::new();
238 /// set.spawn(async move { 1 });
239 /// set.cancel_all();
240 /// });
241 /// ```
cancel_all(&mut self)242 pub fn cancel_all(&mut self) {
243 let list = self.list.lock().unwrap();
244 for item in &list.done_list {
245 unsafe { (*item.handle.get()).cancel() }
246 }
247 for item in &list.wait_list {
248 unsafe { (*item.handle.get()).cancel() }
249 }
250 }
251
252 /// Cancels every tasks inside and clears all entries inside
253 ///
254 /// # Examples
255 /// ```
256 /// use ylong_runtime::task::JoinSet;
257 /// ylong_runtime::block_on(async move {
258 /// let mut set = JoinSet::new();
259 /// set.spawn(async move { 1 });
260 /// set.shutdown();
261 /// });
262 /// ```
shutdown(&mut self)263 pub async fn shutdown(&mut self) {
264 self.cancel_all();
265 while self.join_next().await.is_some() {}
266 }
267
268 /// Creates a builder that configures task attributes. This builder could
269 /// spawn tasks with its attributes onto the JoinSet.
270 ///
271 /// # Examples
272 /// ```
273 /// use ylong_runtime::task::JoinSet;
274 /// ylong_runtime::block_on(async move {
275 /// let mut set = JoinSet::new();
276 /// let mut builder = set.build_task().name("hello".into());
277 /// builder.spawn(async move { 1 });
278 /// });
279 /// ```
build_task(&mut self) -> Builder<'_, R>280 pub fn build_task(&mut self) -> Builder<'_, R> {
281 Builder::new(self)
282 }
283
poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<R, ScheduleError>>>284 fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<R, ScheduleError>>> {
285 let mut list = self.list.lock().unwrap();
286
287 // quick path: check if the set is empty, return none if true
288 if list.len == 0 {
289 return Poll::Ready(None);
290 }
291
292 // set the joinset's waker if it's not set
293 let is_same_waker = match list.waker.as_ref() {
294 None => false,
295 Some(waker) => cx.waker().will_wake(waker),
296 };
297
298 if !is_same_waker {
299 list.waker = Some(cx.waker().clone());
300 }
301
302 // pop a ready task from the done list and poll it
303 if let Some(entry) = list.done_list.pop_front() {
304 drop(list);
305 let waker = entry_into_waker(&entry);
306 let mut ctx = Context::from_waker(&waker);
307 // We have to dereference the JoinHandle from the UnsafeCell in order to poll
308 // it. The lifetime of the handle is valid here since it's wrapped
309 // by a ManuallyDrop. It will only get dropped when the task returns
310 // ready, and by the time, the entry is also dropped, and could
311 // never be popped from the done list once again.
312 unsafe {
313 match Pin::new(&mut **(entry.handle.get())).poll(&mut ctx) {
314 Poll::Ready(res) => {
315 let mut list = self.list.lock().unwrap();
316 list.len -= 1;
317 drop(list);
318 // drop the JoinHandle and return it's result
319 drop(ManuallyDrop::take(&mut *entry.handle.get()));
320 Poll::Ready(Some(res))
321 }
322 Poll::Pending => {
323 let mut list = self.list.lock().unwrap();
324 // The future hasn't finished, push it back to wait-list
325 let _ = entry.in_done.replace(false);
326 list.wait_list.insert(entry);
327 drop(list);
328 cx.waker().wake_by_ref();
329 Poll::Pending
330 }
331 }
332 }
333 } else {
334 // there is no task, return none
335 if list.len == 0 {
336 Poll::Ready(None)
337 } else {
338 // no ready task, return pending
339 Poll::Pending
340 }
341 }
342 }
343 }
344
345 /// A TaskBuilder for tasks that get spawned on a specific JoinSet
346 pub struct Builder<'a, R> {
347 builder: TaskBuilder,
348 set: &'a mut JoinSet<R>,
349 }
350
351 impl<'a, R> Builder<'a, R> {
new(set: &'a mut JoinSet<R>) -> Builder<'a, R>352 pub(crate) fn new(set: &'a mut JoinSet<R>) -> Builder<'a, R> {
353 Builder {
354 builder: TaskBuilder::new(),
355 set,
356 }
357 }
358
359 /// Sets the name for the tasks that are going to get spawned by this
360 /// JoinSet Builder
361 ///
362 /// # Examples
363 /// ```
364 /// use ylong_runtime::task::JoinSet;
365 /// ylong_runtime::block_on(async move {
366 /// let mut set = JoinSet::new();
367 /// let mut builder = set.build_task().name("hello".into());
368 /// builder.spawn(async move { 1 });
369 /// });
370 /// ```
name(self, name: String) -> Self371 pub fn name(self, name: String) -> Self {
372 let builder = self.builder.name(name);
373 Self {
374 builder,
375 set: self.set,
376 }
377 }
378
379 /// Sets the QOS for the tasks that are going to get spawned by this
380 /// JoinSet Builder
381 ///
382 /// # Examples
383 /// ```
384 /// use ylong_runtime::task::{JoinSet, Qos};
385 /// ylong_runtime::block_on(async move {
386 /// let mut set = JoinSet::new();
387 /// let mut builder = set.build_task().qos(Qos::UserInitiated);
388 /// builder.spawn(async move { 1 });
389 /// });
390 /// ```
qos(self, qos: Qos) -> Self391 pub fn qos(self, qos: Qos) -> Self {
392 let builder = self.builder.qos(qos);
393 Self {
394 builder,
395 set: self.set,
396 }
397 }
398
399 /// Spawns a task via a `JoinSet` onto a Ylong runtime. The task will start
400 /// immediately when `spawn` is called.
401 ///
402 /// # Panics
403 /// This method panics when calling outside of Ylong runtime.
404 /// # Examples
405 /// ```
406 /// use ylong_runtime::task::JoinSet;
407 /// ylong_runtime::block_on(async move {
408 /// let mut set = JoinSet::new();
409 /// let mut builder = set.build_task();
410 /// builder.spawn(async move { 1 });
411 /// });
412 /// ```
spawn<T>(&mut self, task: T) -> CancelHandle where T: Future<Output = R> + Send + 'static, R: Send + 'static,413 pub fn spawn<T>(&mut self, task: T) -> CancelHandle
414 where
415 T: Future<Output = R> + Send + 'static,
416 R: Send + 'static,
417 {
418 self.set.spawn_inner(task, Some(&self.builder))
419 }
420 }
421
422 /// Cancels all task inside, and frees all corresponding JoinHandle.
423 impl<R> Drop for JoinSet<R> {
drop(&mut self)424 fn drop(&mut self) {
425 let mut list = self.list.lock().unwrap();
426
427 for item in &list.done_list {
428 unsafe {
429 (*item.handle.get()).cancel();
430 drop(ManuallyDrop::take(&mut *item.handle.get()));
431 }
432 }
433 for item in &list.wait_list {
434 unsafe {
435 (*item.handle.get()).cancel();
436 drop(ManuallyDrop::take(&mut *item.handle.get()));
437 }
438 }
439 // pop every entry inside to reduce the ref count of the list
440 while list.done_list.pop_back().is_some() {}
441 list.wait_list.drain();
442 }
443 }
444
445 // Gets the vtable of the entry waker
get_entry_waker_table<R>() -> &'static RawWakerVTable446 fn get_entry_waker_table<R>() -> &'static RawWakerVTable {
447 &RawWakerVTable::new(
448 clone_entry::<R>,
449 wake_entry::<R>,
450 wake_entry_ref::<R>,
451 drop_entry::<R>,
452 )
453 }
454
455 // Converts a entry reference into a Waker
entry_into_waker<R>(entry: &Arc<JoinEntry<R>>) -> Waker456 fn entry_into_waker<R>(entry: &Arc<JoinEntry<R>>) -> Waker {
457 let cpy = entry.clone();
458 let data = Arc::into_raw(cpy).cast::<()>();
459 unsafe { Waker::from_raw(RawWaker::new(data, get_entry_waker_table::<R>())) }
460 }
461
clone_entry<R>(data: *const ()) -> RawWaker462 unsafe fn clone_entry<R>(data: *const ()) -> RawWaker {
463 // First increment the arc counter
464 let entry = Arc::from_raw(data.cast::<JoinEntry<R>>());
465 mem::forget(entry.clone());
466 // Construct the new waker
467 let data = Arc::into_raw(entry).cast::<()>();
468 RawWaker::new(data, get_entry_waker_table::<R>())
469 }
470
wake_entry<R>(data: *const ())471 unsafe fn wake_entry<R>(data: *const ()) {
472 let entry = Arc::from_raw(data.cast::<JoinEntry<R>>());
473 JoinEntry::wake_by_ref(&entry);
474 }
475
wake_entry_ref<R>(data: *const ())476 unsafe fn wake_entry_ref<R>(data: *const ()) {
477 let entry = ManuallyDrop::new(Arc::from_raw(data.cast::<JoinEntry<R>>()));
478 JoinEntry::wake_by_ref(&entry);
479 }
480
drop_entry<R>(data: *const ())481 unsafe fn drop_entry<R>(data: *const ()) {
482 drop(Arc::from_raw(data.cast::<JoinEntry<R>>()))
483 }
484