1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
//! Implements the `split` function for I/O streams, which divides a stream
//! into a `Reader` half and a `Writer` half.
15 
16 use std::io;
17 use std::io::IoSlice;
18 use std::pin::Pin;
19 use std::sync::Arc;
20 use std::task::{Context, Poll};
21 
22 use ylong_runtime::io::{AsyncRead, AsyncWrite, ReadBuf};
23 use ylong_runtime::sync::{Mutex, MutexGuard};
24 
/// Unwraps a `Poll` value inside a function that itself returns `Poll`.
///
/// Evaluates to the inner value when the expression is `Poll::Ready`;
/// otherwise returns `Poll::Pending` from the enclosing function.
macro_rules! ready {
    ($poll:expr $(,)?) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
33 
34 pub(crate) struct Reader<T> {
35     inner: Arc<InnerLock<T>>,
36 }
37 
38 pub(crate) struct Writer<T> {
39     inner: Arc<InnerLock<T>>,
40 }
41 
42 struct InnerLock<T> {
43     stream: Mutex<T>,
44     is_write_vectored: bool,
45 }
46 
47 struct StreamGuard<'a, T> {
48     inner: MutexGuard<'a, T>,
49 }
50 
split<T>(stream: T) -> (Reader<T>, Writer<T>) where T: AsyncRead + AsyncWrite,51 pub(crate) fn split<T>(stream: T) -> (Reader<T>, Writer<T>)
52 where
53     T: AsyncRead + AsyncWrite,
54 {
55     let is_write_vectored = stream.is_write_vectored();
56     let inner = Arc::new(InnerLock {
57         stream: Mutex::new(stream),
58         is_write_vectored,
59     });
60 
61     let rd = Reader {
62         inner: inner.clone(),
63     };
64 
65     let wr = Writer { inner };
66 
67     (rd, wr)
68 }
69 
70 impl<T: AsyncRead> AsyncRead for Reader<T> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>71     fn poll_read(
72         self: Pin<&mut Self>,
73         cx: &mut Context<'_>,
74         buf: &mut ReadBuf<'_>,
75     ) -> Poll<io::Result<()>> {
76         let mut guard = ready!(self.inner.get_lock(cx));
77         guard.stream().poll_read(cx, buf)
78     }
79 }
80 
81 impl<T: AsyncWrite> AsyncWrite for Writer<T> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>82     fn poll_write(
83         self: Pin<&mut Self>,
84         cx: &mut Context<'_>,
85         buf: &[u8],
86     ) -> Poll<Result<usize, io::Error>> {
87         let mut inner = ready!(self.inner.get_lock(cx));
88         inner.stream().poll_write(cx, buf)
89     }
90 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<std::io::Result<usize>>91     fn poll_write_vectored(
92         self: Pin<&mut Self>,
93         cx: &mut Context<'_>,
94         bufs: &[IoSlice<'_>],
95     ) -> Poll<std::io::Result<usize>> {
96         let mut inner = ready!(self.inner.get_lock(cx));
97         inner.stream().poll_write_vectored(cx, bufs)
98     }
99 
is_write_vectored(&self) -> bool100     fn is_write_vectored(&self) -> bool {
101         self.inner.is_write_vectored
102     }
103 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>104     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
105         let mut inner = ready!(self.inner.get_lock(cx));
106         inner.stream().poll_flush(cx)
107     }
108 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>109     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
110         let mut inner = ready!(self.inner.get_lock(cx));
111         inner.stream().poll_shutdown(cx)
112     }
113 }
114 
115 impl<'a, T> StreamGuard<'a, T> {
stream(&mut self) -> Pin<&mut T>116     fn stream(&mut self) -> Pin<&mut T> {
117         // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
118         // exclusion.
119         unsafe { Pin::new_unchecked(&mut *self.inner) }
120     }
121 }
122 
123 impl<T> InnerLock<T> {
get_lock(&self, cx: &mut Context<'_>) -> Poll<StreamGuard<T>>124     fn get_lock(&self, cx: &mut Context<'_>) -> Poll<StreamGuard<T>> {
125         match self.stream.try_lock() {
126             Ok(guard) => Poll::Ready(StreamGuard { inner: guard }),
127             Err(_) => {
128                 std::thread::yield_now();
129                 cx.waker().wake_by_ref();
130 
131                 Poll::Pending
132             }
133         }
134     }
135 }
136