1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::future::Future;
15 use std::pin::Pin;
16 use std::task::{Context, Poll};
17
/// Polls a pinned fallible future once and short-circuits on failure.
///
/// * `Ready(Err(e))` makes the *enclosing* function or closure return
///   `Poll::Ready(Err(e))` immediately.
/// * `Ready(Ok(_))` is ignored here; the caller retrieves the value later.
/// * `Pending` sets the `$pending` flag to `true` so the caller knows it must
///   not resolve yet.
macro_rules! poll_return_if_err {
    ($future:expr, $pending:expr, $ctx:expr) => {
        match $future.as_mut().poll($ctx) {
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(_)) => {}
            Poll::Pending => $pending = true,
        }
    };
}
27
try_join3<F1, F2, F3, R1, R2, R3, E>( fut1: F1, fut2: F2, fut3: F3, ) -> Result<(R1, R2, R3), E> where F1: Future<Output = Result<R1, E>>, F2: Future<Output = Result<R2, E>>, F3: Future<Output = Result<R3, E>>,28 pub(crate) async fn try_join3<F1, F2, F3, R1, R2, R3, E>(
29 fut1: F1,
30 fut2: F2,
31 fut3: F3,
32 ) -> Result<(R1, R2, R3), E>
33 where
34 F1: Future<Output = Result<R1, E>>,
35 F2: Future<Output = Result<R2, E>>,
36 F3: Future<Output = Result<R3, E>>,
37 {
38 let mut fut1 = future_done(fut1);
39 let mut fut2 = future_done(fut2);
40 let mut fut3 = future_done(fut3);
41
42 crate::futures::poll_fn(move |cx| {
43 let mut is_pending = false;
44
45 let mut fut1 = unsafe { Pin::new_unchecked(&mut fut1) };
46 poll_return_if_err!(fut1, is_pending, cx);
47
48 let mut fut2 = unsafe { Pin::new_unchecked(&mut fut2) };
49 poll_return_if_err!(fut2, is_pending, cx);
50
51 let mut fut3 = unsafe { Pin::new_unchecked(&mut fut3) };
52 poll_return_if_err!(fut3, is_pending, cx);
53
54 if is_pending {
55 Poll::Pending
56 } else {
57 // All fut should have a ready(Ok(res)) result here
58 Poll::Ready(Ok((
59 fut1.take_output().unwrap_or_else(|_| unreachable!()),
60 fut2.take_output().unwrap_or_else(|_| unreachable!()),
61 fut3.take_output().unwrap_or_else(|_| unreachable!()),
62 )))
63 }
64 })
65 .await
66 }
67
/// A future whose successful output is kept aside instead of being returned.
///
/// * `Pending` – the wrapped future has not produced its output yet.
/// * `Ready` – the wrapped future finished and its output is stored here.
/// * `None` – the stored output has already been taken.
pub(crate) enum FutureDone<F>
where
    F: Future,
{
    Pending(F),
    Ready(F::Output),
    None,
}
73
future_done<F: Future>(future: F) -> FutureDone<F>74 pub(crate) fn future_done<F: Future>(future: F) -> FutureDone<F> {
75 FutureDone::Pending(future)
76 }
77
78 impl<F: Future + Unpin> Unpin for FutureDone<F> {}
79
80 impl<F: Future> FutureDone<F> {
take_output(self: Pin<&mut Self>) -> F::Output81 pub(crate) fn take_output(self: Pin<&mut Self>) -> F::Output {
82 // Safety: inner data never move.
83 unsafe {
84 let inner = self.get_unchecked_mut();
85 if let FutureDone::Ready(output) = std::mem::replace(inner, FutureDone::None) {
86 return output;
87 }
88 unreachable!()
89 }
90 }
91 }
92
93 impl<E, R, F: Future<Output = Result<R, E>>> Future for FutureDone<F> {
94 type Output = Result<(), E>;
95
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>96 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 // Safety: inner data never move.
98 unsafe {
99 match self.as_mut().get_unchecked_mut() {
100 FutureDone::Pending(fut) => match Pin::new_unchecked(fut).poll(cx) {
101 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
102 Poll::Ready(Ok(res)) => {
103 self.set(FutureDone::Ready(Ok(res)));
104 Poll::Ready(Ok(()))
105 }
106 Poll::Pending => Poll::Pending,
107 },
108 FutureDone::Ready(_) => Poll::Ready(Ok(())),
109 FutureDone::None => panic!("FutureDone output has gone"),
110 }
111 }
112 }
113 }
114
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Poll;

    use crate::process::try_join3::{future_done, try_join3};

    /// UT test cases for `try_join3()` when every future succeeds.
    ///
    /// # Brief
    /// 1. Create 3 futures that all return `Ok`.
    /// 2. Check that `try_join3()` returns their outputs in argument order.
    #[test]
    fn ut_try_join_ok_test() {
        async fn ok(value: u32) -> Result<u32, &'static str> {
            Ok(value)
        }
        let handle = crate::spawn(async {
            let res = try_join3(ok(1), ok(2), ok(3)).await;
            assert_eq!(res, Ok((1, 2, 3)));
        });
        crate::block_on(handle).unwrap();
    }

    /// UT test cases for `try_join3()` when a future fails.
    ///
    /// # Brief
    /// 1. Create 3 futures, one of which returns `Err`.
    /// 2. Check that `try_join3()` returns exactly that error, whatever the
    ///    position of the failing future.
    /// 3. When several futures fail, check that the earliest argument's
    ///    error is returned.
    #[test]
    fn ut_try_join_error_test() {
        async fn ok() -> Result<(), &'static str> {
            Ok(())
        }
        async fn fail(msg: &'static str) -> Result<(), &'static str> {
            Err(msg)
        }
        let handle = crate::spawn(async {
            assert_eq!(try_join3(fail("test"), ok(), ok()).await, Err("test"));
            assert_eq!(try_join3(ok(), fail("test"), ok()).await, Err("test"));
            assert_eq!(try_join3(ok(), ok(), fail("test")).await, Err("test"));

            assert_eq!(
                try_join3(fail("first"), fail("second"), ok()).await,
                Err("first")
            );
        });
        crate::block_on(handle).unwrap();
    }

    /// UT test cases for `FutureDone`.
    ///
    /// # Brief
    /// 1. Create FutureDone with future_done().
    /// 2. Poll it to completion and check the stored output.
    #[test]
    fn ut_future_done_test() {
        let handle = crate::spawn(async {
            let mut fut = future_done(async { Ok::<i32, i32>(1) });
            // SAFETY: `fut` is shadowed immediately and lives in this async
            // block's state, which stays pinned while it is being polled.
            let mut fut = unsafe { Pin::new_unchecked(&mut fut) };

            let output = crate::futures::poll_fn(move |cx| {
                if fut.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                Poll::Ready(fut.as_mut().take_output())
            })
            .await;
            assert_eq!(output, Ok(1));
        });
        crate::block_on(handle).unwrap();
    }
}
184