1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
//! Implements the `split` function for I/O streams, which divides a stream
//! into a `Reader` half and a `Writer` half.
15
16 use std::io;
17 use std::io::IoSlice;
18 use std::pin::Pin;
19 use std::sync::Arc;
20 use std::task::{Context, Poll};
21
22 use ylong_runtime::io::{AsyncRead, AsyncWrite, ReadBuf};
23 use ylong_runtime::sync::{Mutex, MutexGuard};
24
/// Unwraps a `Poll` expression inside a poll function.
///
/// Evaluates to the inner value when the expression is `Poll::Ready`;
/// otherwise returns `Poll::Pending` from the enclosing function. A trailing
/// comma after the expression is accepted.
macro_rules! ready {
    ($poll:expr $(,)?) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
33
34 pub(crate) struct Reader<T> {
35 inner: Arc<InnerLock<T>>,
36 }
37
38 pub(crate) struct Writer<T> {
39 inner: Arc<InnerLock<T>>,
40 }
41
42 struct InnerLock<T> {
43 stream: Mutex<T>,
44 is_write_vectored: bool,
45 }
46
47 struct StreamGuard<'a, T> {
48 inner: MutexGuard<'a, T>,
49 }
50
split<T>(stream: T) -> (Reader<T>, Writer<T>) where T: AsyncRead + AsyncWrite,51 pub(crate) fn split<T>(stream: T) -> (Reader<T>, Writer<T>)
52 where
53 T: AsyncRead + AsyncWrite,
54 {
55 let is_write_vectored = stream.is_write_vectored();
56 let inner = Arc::new(InnerLock {
57 stream: Mutex::new(stream),
58 is_write_vectored,
59 });
60
61 let rd = Reader {
62 inner: inner.clone(),
63 };
64
65 let wr = Writer { inner };
66
67 (rd, wr)
68 }
69
70 impl<T: AsyncRead> AsyncRead for Reader<T> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>71 fn poll_read(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 buf: &mut ReadBuf<'_>,
75 ) -> Poll<io::Result<()>> {
76 let mut guard = ready!(self.inner.get_lock(cx));
77 guard.stream().poll_read(cx, buf)
78 }
79 }
80
81 impl<T: AsyncWrite> AsyncWrite for Writer<T> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>82 fn poll_write(
83 self: Pin<&mut Self>,
84 cx: &mut Context<'_>,
85 buf: &[u8],
86 ) -> Poll<Result<usize, io::Error>> {
87 let mut inner = ready!(self.inner.get_lock(cx));
88 inner.stream().poll_write(cx, buf)
89 }
90
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<std::io::Result<usize>>91 fn poll_write_vectored(
92 self: Pin<&mut Self>,
93 cx: &mut Context<'_>,
94 bufs: &[IoSlice<'_>],
95 ) -> Poll<std::io::Result<usize>> {
96 let mut inner = ready!(self.inner.get_lock(cx));
97 inner.stream().poll_write_vectored(cx, bufs)
98 }
99
is_write_vectored(&self) -> bool100 fn is_write_vectored(&self) -> bool {
101 self.inner.is_write_vectored
102 }
103
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>104 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
105 let mut inner = ready!(self.inner.get_lock(cx));
106 inner.stream().poll_flush(cx)
107 }
108
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>109 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
110 let mut inner = ready!(self.inner.get_lock(cx));
111 inner.stream().poll_shutdown(cx)
112 }
113 }
114
115 impl<'a, T> StreamGuard<'a, T> {
stream(&mut self) -> Pin<&mut T>116 fn stream(&mut self) -> Pin<&mut T> {
117 // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
118 // exclusion.
119 unsafe { Pin::new_unchecked(&mut *self.inner) }
120 }
121 }
122
123 impl<T> InnerLock<T> {
get_lock(&self, cx: &mut Context<'_>) -> Poll<StreamGuard<T>>124 fn get_lock(&self, cx: &mut Context<'_>) -> Poll<StreamGuard<T>> {
125 match self.stream.try_lock() {
126 Ok(guard) => Poll::Ready(StreamGuard { inner: guard }),
127 Err(_) => {
128 std::thread::yield_now();
129 cx.waker().wake_by_ref();
130
131 Poll::Pending
132 }
133 }
134 }
135 }
136