1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::future::Future;
15 use std::pin::Pin;
16 use std::task::{Context, Poll};
17 
/// Polls the pinned future `$fut` once with the context `$cx`.
///
/// * `Poll::Pending` sets the boolean place `$is_pending` to `true`.
/// * `Poll::Ready(Err(e))` makes the enclosing function or closure return
///   `Poll::Ready(Err(e))` immediately.
/// * Any other ready value is ignored.
macro_rules! poll_return_if_err {
    ($fut: expr, $is_pending: expr, $cx: expr) => {{
        let polled = $fut.as_mut().poll($cx);
        if polled.is_pending() {
            $is_pending = true;
        } else if let Poll::Ready(Err(err)) = polled {
            return Poll::Ready(Err(err));
        }
    }};
}
27 
try_join3<F1, F2, F3, R1, R2, R3, E>( fut1: F1, fut2: F2, fut3: F3, ) -> Result<(R1, R2, R3), E> where F1: Future<Output = Result<R1, E>>, F2: Future<Output = Result<R2, E>>, F3: Future<Output = Result<R3, E>>,28 pub(crate) async fn try_join3<F1, F2, F3, R1, R2, R3, E>(
29     fut1: F1,
30     fut2: F2,
31     fut3: F3,
32 ) -> Result<(R1, R2, R3), E>
33 where
34     F1: Future<Output = Result<R1, E>>,
35     F2: Future<Output = Result<R2, E>>,
36     F3: Future<Output = Result<R3, E>>,
37 {
38     let mut fut1 = future_done(fut1);
39     let mut fut2 = future_done(fut2);
40     let mut fut3 = future_done(fut3);
41 
42     crate::futures::poll_fn(move |cx| {
43         let mut is_pending = false;
44 
45         let mut fut1 = unsafe { Pin::new_unchecked(&mut fut1) };
46         poll_return_if_err!(fut1, is_pending, cx);
47 
48         let mut fut2 = unsafe { Pin::new_unchecked(&mut fut2) };
49         poll_return_if_err!(fut2, is_pending, cx);
50 
51         let mut fut3 = unsafe { Pin::new_unchecked(&mut fut3) };
52         poll_return_if_err!(fut3, is_pending, cx);
53 
54         if is_pending {
55             Poll::Pending
56         } else {
57             // All fut should have a ready(Ok(res)) result here
58             Poll::Ready(Ok((
59                 fut1.take_output().unwrap_or_else(|_| unreachable!()),
60                 fut2.take_output().unwrap_or_else(|_| unreachable!()),
61                 fut3.take_output().unwrap_or_else(|_| unreachable!()),
62             )))
63         }
64     })
65     .await
66 }
67 
/// Wrapper that drives a fallible future to completion and keeps its
/// successful output until it is taken with [`FutureDone::take_output`].
///
/// Polling a `FutureDone` (see its `Future` impl) moves it through these
/// states:
/// * `Pending(fut)` -> `Ready(Ok(value))` when `fut` resolves to `Ok(value)`;
///   the wrapper itself then resolves to `Ok(())`.
/// * `Pending(fut)` -> `None` when `fut` resolves to `Err(e)`; the error is
///   handed back from `poll`, and the finished future is dropped.
/// * `Ready(output)` -> `None` when [`FutureDone::take_output`] moves the
///   output out.
pub(crate) enum FutureDone<F: Future> {
    /// The inner future has not completed yet.
    Pending(F),
    /// The inner future completed; holds its output until it is taken.
    Ready(F::Output),
    /// The output has been taken, or the inner future failed.
    None,
}

/// Wraps `future` in a [`FutureDone`] that has not been polled yet.
pub(crate) fn future_done<F: Future>(future: F) -> FutureDone<F> {
    FutureDone::Pending(future)
}

// Pinning is structural only for the inner future `F`. The stored output is
// never pinned (it is only ever moved out by value in `take_output`), so the
// wrapper is `Unpin` whenever `F` is, regardless of `F::Output`.
impl<F: Future + Unpin> Unpin for FutureDone<F> {}

impl<F: Future> FutureDone<F> {
    /// Moves the stored output out, leaving `FutureDone::None` behind.
    ///
    /// # Panics
    ///
    /// Panics if the inner future has not completed yet, or if the output was
    /// already taken. The state is checked before anything is moved, so a
    /// pinned, still-pending future is never moved out of its place.
    pub(crate) fn take_output(self: Pin<&mut Self>) -> F::Output {
        // SAFETY: the only value moved out of `self` is the output held by
        // the `Ready` variant, which is never pinned. A pinned `Pending(F)`
        // is left untouched because we bail out before `mem::replace`.
        let inner = unsafe { self.get_unchecked_mut() };
        if !matches!(*inner, FutureDone::Ready(_)) {
            unreachable!()
        }
        match std::mem::replace(inner, FutureDone::None) {
            FutureDone::Ready(output) => output,
            _ => unreachable!(),
        }
    }
}

impl<E, R, F: Future<Output = Result<R, E>>> Future for FutureDone<F> {
    type Output = Result<(), E>;

    /// Polls the inner future if it is still pending.
    ///
    /// Resolves to `Ok(())` once the inner future has produced `Ok`; the
    /// value is kept for [`FutureDone::take_output`]. An `Err(e)` is
    /// forwarded directly, and the wrapper moves to the `None` state so the
    /// finished inner future is dropped right away.
    ///
    /// # Panics
    ///
    /// Panics if polled after the output was taken or after an error was
    /// returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the inner future is only re-pinned in place here and is
        // never moved; `Pin::set` below drops it in place.
        let polled = match unsafe { self.as_mut().get_unchecked_mut() } {
            FutureDone::Pending(fut) => unsafe { Pin::new_unchecked(fut) }.poll(cx),
            FutureDone::Ready(_) => return Poll::Ready(Ok(())),
            FutureDone::None => panic!("FutureDone output has gone"),
        };

        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(value)) => {
                self.set(FutureDone::Ready(Ok(value)));
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(err)) => {
                // Drop the completed future now instead of keeping it in
                // `Pending`, where a later poll would resume a finished future.
                self.set(FutureDone::None);
                Poll::Ready(Err(err))
            }
        }
    }
}
114 
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Poll;

    use crate::process::try_join3::{future_done, try_join3};

    /// UT test cases for `try_join3()`.
    ///
    /// # Brief
    /// 1. Create 3 futures, exactly one of which returns an error.
    /// 2. Check that try_join3() returns that error, whichever position it is in.
    #[test]
    fn ut_try_join_error_test() {
        async fn ok() -> Result<(), &'static str> {
            Ok(())
        }
        async fn err() -> Result<(), &'static str> {
            Err("test")
        }
        let handle = crate::spawn(async {
            let fut1 = err();
            let fut2 = ok();
            let fut3 = ok();
            let res = try_join3(fut1, fut2, fut3).await;
            assert_eq!(res, Err("test"));

            let fut1 = ok();
            let fut2 = err();
            let fut3 = ok();
            let res = try_join3(fut1, fut2, fut3).await;
            assert_eq!(res, Err("test"));

            let fut1 = ok();
            let fut2 = ok();
            let fut3 = err();
            let res = try_join3(fut1, fut2, fut3).await;
            assert_eq!(res, Err("test"));
        });
        crate::block_on(handle).unwrap();
    }

    /// UT test cases for `try_join3()`.
    ///
    /// # Brief
    /// 1. Create 3 futures that all return `Ok`, each with a different output type.
    /// 2. Check that try_join3() returns every output, in argument order.
    #[test]
    fn ut_try_join_ok_test() {
        async fn num() -> Result<u8, &'static str> {
            Ok(1)
        }
        async fn text() -> Result<&'static str, &'static str> {
            Ok("two")
        }
        async fn flag() -> Result<bool, &'static str> {
            Ok(true)
        }
        let handle = crate::spawn(async {
            let res = try_join3(num(), text(), flag()).await;
            assert_eq!(res, Ok((1, "two", true)));
        });
        crate::block_on(handle).unwrap();
    }

    /// UT test cases for `FutureDone`.
    ///
    /// # Brief
    /// 1. Create FutureDone with future_done().
    /// 2. Poll it to completion and check the taken output.
    #[test]
    fn ut_future_done_test() {
        let handle = crate::spawn(async {
            let fut = async { Ok(1) };
            let mut fut = future_done(fut);

            crate::futures::poll_fn(move |cx| {
                // SAFETY: `fut` is owned by this closure, which lives inside the
                // awaited `poll_fn` future and is not moved while being polled.
                let mut fut = unsafe { Pin::new_unchecked(&mut fut) };

                if fut.as_mut().poll(cx).is_pending() {
                    Poll::Pending
                } else {
                    let output: Result<i32, i32> = fut.as_mut().take_output();
                    assert!(output.is_ok());
                    Poll::Ready(output.unwrap())
                }
            })
            .await;
        });
        crate::block_on(handle).unwrap();
    }
}
184