// Copyright (c) 2023 Huawei Device Co., Ltd. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! implement `split` fn for io, split it into `Reader` half and `Writer` half. use std::io; use std::io::IoSlice; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use ylong_runtime::io::{AsyncRead, AsyncWrite, ReadBuf}; use ylong_runtime::sync::{Mutex, MutexGuard}; macro_rules! ready { ($e:expr $(,)?) => { match $e { std::task::Poll::Ready(t) => t, std::task::Poll::Pending => return std::task::Poll::Pending, } }; } pub(crate) struct Reader { inner: Arc>, } pub(crate) struct Writer { inner: Arc>, } struct InnerLock { stream: Mutex, is_write_vectored: bool, } struct StreamGuard<'a, T> { inner: MutexGuard<'a, T>, } pub(crate) fn split(stream: T) -> (Reader, Writer) where T: AsyncRead + AsyncWrite, { let is_write_vectored = stream.is_write_vectored(); let inner = Arc::new(InnerLock { stream: Mutex::new(stream), is_write_vectored, }); let rd = Reader { inner: inner.clone(), }; let wr = Writer { inner }; (rd, wr) } impl AsyncRead for Reader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let mut guard = ready!(self.inner.get_lock(cx)); guard.stream().poll_read(cx, buf) } } impl AsyncWrite for Writer { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let mut inner = ready!(self.inner.get_lock(cx)); inner.stream().poll_write(cx, buf) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { let mut inner = ready!(self.inner.get_lock(cx)); inner.stream().poll_write_vectored(cx, bufs) } fn is_write_vectored(&self) -> bool { self.inner.is_write_vectored } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut inner = ready!(self.inner.get_lock(cx)); inner.stream().poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut inner = ready!(self.inner.get_lock(cx)); inner.stream().poll_shutdown(cx) } } impl<'a, T> StreamGuard<'a, T> { fn stream(&mut self) -> Pin<&mut T> { // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual // exclusion. unsafe { Pin::new_unchecked(&mut *self.inner) } } } impl InnerLock { fn get_lock(&self, cx: &mut Context<'_>) -> Poll> { match self.stream.try_lock() { Ok(guard) => Poll::Ready(StreamGuard { inner: guard }), Err(_) => { std::thread::yield_now(); cx.waker().wake_by_ref(); Poll::Pending } } } }