1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::fmt;
15 use std::io::{Error, IoSlice, Read, Result, Write};
16 use std::net::Shutdown;
17 use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
18 use std::os::unix::net;
19 use std::path::Path;
20 use std::pin::Pin;
21 use std::task::{Context, Poll};
22 
23 use ylong_io::{Interest, Source};
24 
25 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
26 use crate::net::{AsyncSource, Ready};
27 
/// A non-blocking UDS Stream between two local sockets.
///
/// The type is a thin wrapper around an `AsyncSource` holding a
/// `ylong_io::UnixStream`.
pub struct UnixStream {
    // The wrapped `ylong_io::UnixStream`; the methods and trait impls of
    // `UnixStream` in this file forward their work to this field.
    source: AsyncSource<ylong_io::UnixStream>,
}
32 
33 impl UnixStream {
34     /// Creates a new `UnixStream` from `ylong_io::UnixStream`
new(stream: ylong_io::UnixStream) -> Result<Self>35     pub(crate) fn new(stream: ylong_io::UnixStream) -> Result<Self> {
36         Ok(UnixStream {
37             source: AsyncSource::new(stream, None)?,
38         })
39     }
40 
41     /// Opens a UDS connection to a remote host asynchronously.
42     ///
43     /// # Example
44     /// ```no_run
45     /// use std::io;
46     ///
47     /// use ylong_runtime::net::UnixStream;
48     ///
49     /// async fn io_func() -> io::Result<()> {
50     ///     let mut stream = UnixStream::connect("/tmp/sock").await?;
51     ///     Ok(())
52     /// }
53     /// ```
connect<P: AsRef<Path>>(path: P) -> Result<UnixStream>54     pub async fn connect<P: AsRef<Path>>(path: P) -> Result<UnixStream> {
55         let stream = UnixStream::new(ylong_io::UnixStream::connect(path)?)?;
56 
57         stream
58             .source
59             .async_process(
60                 // Wait until the stream is writable
61                 Interest::WRITABLE,
62                 || Ok(()),
63             )
64             .await?;
65 
66         if let Some(e) = stream.source.take_error()? {
67             return Err(e);
68         }
69 
70         Ok(stream)
71     }
72 
73     /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`.
74     ///
75     /// # Examples
76     /// ```no_run
77     /// use std::error::Error;
78     /// use std::os::unix::net::UnixStream as StdUnixStream;
79     ///
80     /// use ylong_runtime::net::UnixStream;
81     ///
82     /// async fn dox() -> Result<(), Box<dyn Error>> {
83     ///     let std_stream = StdUnixStream::connect("/socket/path")?;
84     ///     std_stream.set_nonblocking(true)?;
85     ///     let stream = UnixStream::from_std(std_stream)?;
86     ///     Ok(())
87     /// }
88     /// ```
from_std(listener: net::UnixStream) -> Result<UnixStream>89     pub fn from_std(listener: net::UnixStream) -> Result<UnixStream> {
90         let stream = ylong_io::UnixStream::from_std(listener);
91         Ok(UnixStream {
92             source: AsyncSource::new(stream, None)?,
93         })
94     }
95 
96     /// Creates an unnamed pair of connected sockets.
97     /// Returns two `UnixStream`s which are connected to each other.
98     ///
99     /// # Examples
100     /// ```no_run
101     /// use ylong_runtime::net::UnixStream;
102     ///
103     /// let (stream1, stream2) = match UnixStream::pair() {
104     ///     Ok((stream1, stream2)) => (stream1, stream2),
105     ///     Err(err) => {
106     ///         println!("Couldn't create a pair of sockets: {err:?}");
107     ///         return;
108     ///     }
109     /// };
110     /// ```
pair() -> Result<(UnixStream, UnixStream)>111     pub fn pair() -> Result<(UnixStream, UnixStream)> {
112         let (stream1, stream2) = ylong_io::UnixStream::pair()?;
113         let stream1 = UnixStream::new(stream1)?;
114         let stream2 = UnixStream::new(stream2)?;
115 
116         Ok((stream1, stream2))
117     }
118 
119     /// Waits for any of the requested ready states.
120     ///
121     /// # Examples
122     /// ```no_run
123     /// use std::error::Error;
124     /// use std::io::ErrorKind;
125     ///
126     /// use ylong_io::Interest;
127     /// use ylong_runtime::net::UnixStream;
128     ///
129     /// async fn test() -> Result<(), Box<dyn Error>> {
130     ///     let stream = UnixStream::connect("/socket/path").await?;
131     ///
132     ///     loop {
133     ///         let ready = stream
134     ///             .ready(Interest::READABLE | Interest::WRITABLE)
135     ///             .await?;
136     ///
137     ///         if ready.is_readable() {
138     ///             let mut data = vec![0; 128];
139     ///             match stream.try_read(&mut data) {
140     ///                 Ok(n) => {
141     ///                     println!("read {} bytes", n);
142     ///                 }
143     ///                 Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
144     ///                     continue;
145     ///                 }
146     ///                 Err(e) => {
147     ///                     return Err(e.into());
148     ///                 }
149     ///             }
150     ///         }
151     ///
152     ///         if ready.is_writable() {
153     ///             match stream.try_write(b"hello world") {
154     ///                 Ok(n) => {
155     ///                     println!("write {} bytes", n);
156     ///                 }
157     ///                 Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
158     ///                     continue;
159     ///                 }
160     ///                 Err(e) => {
161     ///                     return Err(e.into());
162     ///                 }
163     ///             }
164     ///         }
165     ///     }
166     /// }
167     /// ```
ready(&self, interest: Interest) -> Result<Ready>168     pub async fn ready(&self, interest: Interest) -> Result<Ready> {
169         let event = self.source.entry.readiness(interest).await?;
170         Ok(event.ready)
171     }
172 
173     /// Waits for `Interest::READABLE` requested ready states.
174     ///
175     /// # Examples
176     /// ```no_run
177     /// use std::error::Error;
178     /// use std::io::ErrorKind;
179     ///
180     /// use ylong_io::Interest;
181     /// use ylong_runtime::net::UnixStream;
182     ///
183     /// async fn test() -> Result<(), Box<dyn Error>> {
184     ///     let stream = UnixStream::connect("/socket/path").await?;
185     ///     loop {
186     ///         stream.readable().await?;
187     ///         let mut data = vec![0; 128];
188     ///         match stream.try_read(&mut data) {
189     ///             Ok(n) => {
190     ///                 println!("read {} bytes", n);
191     ///             }
192     ///             Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
193     ///                 continue;
194     ///             }
195     ///             Err(e) => {
196     ///                 return Err(e.into());
197     ///             }
198     ///         }
199     ///     }
200     /// }
201     /// ```
readable(&self) -> Result<()>202     pub async fn readable(&self) -> Result<()> {
203         self.ready(Interest::READABLE).await?;
204         Ok(())
205     }
206 
207     /// Trys to read stream.
208     /// This method will immediately return the result.
209     /// If it is currently unavailable, it will return `WouldBlock`
210     ///
211     /// # Examples
212     /// ```no_run
213     /// use std::error::Error;
214     /// use std::io::ErrorKind;
215     ///
216     /// use ylong_io::Interest;
217     /// use ylong_runtime::net::UnixStream;
218     ///
219     /// async fn test() -> Result<(), Box<dyn Error>> {
220     ///     let stream = UnixStream::connect("/socket/path").await?;
221     ///     loop {
222     ///         stream.readable().await?;
223     ///         let mut data = vec![0; 128];
224     ///         match stream.try_read(&mut data) {
225     ///             Ok(n) => {
226     ///                 println!("read {} bytes", n);
227     ///             }
228     ///             Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
229     ///                 continue;
230     ///             }
231     ///             Err(e) => {
232     ///                 return Err(e.into());
233     ///             }
234     ///         }
235     ///     }
236     /// }
237     /// ```
try_read(&self, buf: &mut [u8]) -> Result<usize>238     pub fn try_read(&self, buf: &mut [u8]) -> Result<usize> {
239         self.source
240             .try_io(Interest::READABLE, || (&*self.source).read(buf))
241     }
242 
243     /// Waits for `Interest::WRITABLE` requested ready states.
244     ///
245     /// # Examples
246     /// ```no_run
247     /// use std::error::Error;
248     /// use std::io::ErrorKind;
249     ///
250     /// use ylong_io::Interest;
251     /// use ylong_runtime::net::UnixStream;
252     ///
253     /// async fn test() -> Result<(), Box<dyn Error>> {
254     ///     let stream = UnixStream::connect("/socket/path").await?;
255     ///     loop {
256     ///         stream.writable().await?;
257     ///         match stream.try_write(b"hello world") {
258     ///             Ok(n) => {
259     ///                 println!("write {} bytes", n);
260     ///             }
261     ///             Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
262     ///                 continue;
263     ///             }
264     ///             Err(e) => {
265     ///                 return Err(e.into());
266     ///             }
267     ///         }
268     ///     }
269     /// }
270     /// ```
writable(&self) -> Result<()>271     pub async fn writable(&self) -> Result<()> {
272         self.ready(Interest::WRITABLE).await?;
273         Ok(())
274     }
275 
276     /// Trys to write stream.
277     /// This method will immediately return the result.
278     /// If it is currently unavailable, it will return `WouldBlock`
279     ///
280     /// # Examples
281     /// ```no_run
282     /// use std::error::Error;
283     /// use std::io::ErrorKind;
284     ///
285     /// use ylong_io::Interest;
286     /// use ylong_runtime::net::UnixStream;
287     ///
288     /// async fn test() -> Result<(), Box<dyn Error>> {
289     ///     let stream = UnixStream::connect("/socket/path").await?;
290     ///     loop {
291     ///         stream.writable().await?;
292     ///         match stream.try_write(b"hello world") {
293     ///             Ok(n) => {
294     ///                 println!("write {} bytes", n);
295     ///             }
296     ///             Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
297     ///                 continue;
298     ///             }
299     ///             Err(e) => {
300     ///                 return Err(e.into());
301     ///             }
302     ///         }
303     ///     }
304     /// }
305     /// ```
try_write(&self, buf: &[u8]) -> Result<usize>306     pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
307         self.source
308             .try_io(Interest::WRITABLE, || (&*self.source).write(buf))
309     }
310 
311     /// Returns the error of the `SO_ERROR` option.
312     ///
313     /// # Examples
314     /// ```no_run
315     /// use std::io::Result;
316     ///
317     /// use ylong_runtime::net::UnixStream;
318     ///
319     /// async fn test() -> Result<()> {
320     ///     let socket = UnixStream::connect("/tmp/sock").await?;
321     ///     if let Ok(Some(err)) = socket.take_error() {
322     ///         println!("get error: {err:?}");
323     ///     }
324     ///     Ok(())
325     /// }
326     /// ```
take_error(&self) -> Result<Option<Error>>327     pub fn take_error(&self) -> Result<Option<Error>> {
328         self.source.take_error()
329     }
330 
331     /// Shutdown UnixStream
shutdown(&self, how: Shutdown) -> Result<()>332     pub fn shutdown(&self, how: Shutdown) -> Result<()> {
333         self.source.shutdown(how)
334     }
335 }
336 
// Asynchronous reads on `UnixStream` are forwarded to the wrapped source.
impl AsyncRead for UnixStream {
    /// Polls the wrapped source for data to place into `buf`.
    ///
    /// The call is passed straight to the `poll_read` of `self.source` with
    /// the same `cx` and `buf`, and its result is returned unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        self.source.poll_read(cx, buf)
    }
}
346 
347 impl AsyncWrite for UnixStream {
poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>348     fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
349         self.source.poll_write(cx, buf)
350     }
351 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>352     fn poll_write_vectored(
353         self: Pin<&mut Self>,
354         cx: &mut Context<'_>,
355         bufs: &[IoSlice<'_>],
356     ) -> Poll<Result<usize>> {
357         self.source.poll_write_vectored(cx, bufs)
358     }
359 
is_write_vectored(&self) -> bool360     fn is_write_vectored(&self) -> bool {
361         true
362     }
363 
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>>364     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
365         Poll::Ready(Ok(()))
366     }
367 
poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>>368     fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
369         self.source.shutdown(std::net::Shutdown::Write)?;
370         Poll::Ready(Ok(()))
371     }
372 }
373 
// The debug representation of `UnixStream` is that of its wrapped source.
impl fmt::Debug for UnixStream {
    /// Formats the stream by forwarding to the `fmt` of `self.source`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.source.fmt(f)
    }
}
379 
impl AsRawFd for UnixStream {
    /// Returns the raw file descriptor reported by the wrapped source's
    /// `get_fd`.
    fn as_raw_fd(&self) -> RawFd {
        self.source.get_fd()
    }
}
385 
386 impl AsFd for UnixStream {
as_fd(&self) -> BorrowedFd<'_>387     fn as_fd(&self) -> BorrowedFd<'_> {
388         unsafe { BorrowedFd::borrow_raw(self.as_raw_fd()) }
389     }
390 }
391 
#[cfg(test)]
mod test {
    use std::io;
    use std::os::fd::{AsFd, AsRawFd};

    use crate::io::{AsyncReadExt, AsyncWriteExt};
    use crate::net::UnixStream;

    /// Uds UnixStream conversion test case.
    ///
    /// # Title
    /// ut_uds_stream_baisc_test
    ///
    /// # Brief
    /// 1. Create a std UnixStream with `pair()`.
    /// 2. Convert the std UnixStream into a ylong_runtime UnixStream.
    /// 3. Check that the descriptors are valid and `take_error` succeeds.
    #[test]
    fn ut_uds_stream_baisc_test() {
        let (std_stream, _) = std::os::unix::net::UnixStream::pair().unwrap();
        let task = crate::spawn(async {
            let converted = UnixStream::from_std(std_stream);
            assert!(converted.is_ok());
            let stream = converted.unwrap();
            assert!(stream.as_fd().as_raw_fd() >= 0);
            assert!(stream.as_raw_fd() >= 0);
            assert!(stream.take_error().is_ok());
        });
        crate::block_on(task).unwrap();
    }

    /// Uds UnixStream `pair()` test case.
    ///
    /// # Title
    /// ut_uds_stream_pair_test
    ///
    /// # Brief
    /// 1. Create a server and a client with `pair()`.
    /// 2. The server sends a message and the client receives it.
    /// 3. Check that the received bytes and length match the message.
    #[test]
    fn ut_uds_stream_pair_test() {
        let task = crate::spawn(async {
            let (mut sender, mut receiver) = UnixStream::pair().unwrap();

            sender.write_all(b"hello").await.unwrap();
            sender.flush().await.unwrap();

            let mut received = [0_u8; 5];
            let count = receiver.read(&mut received).await.unwrap();
            assert_eq!(std::str::from_utf8(&received).unwrap(), "hello");
            assert_eq!(count, "hello".len());
        });
        crate::block_on(task).unwrap();
    }

    /// Uds UnixStream try_xxx() test case.
    ///
    /// # Title
    /// ut_uds_stream_try_test
    ///
    /// # Brief
    /// 1. Create a server and a client with `pair()`.
    /// 2. The server sends a message with `writable()` and `try_write()`.
    /// 3. The client receives it with `readable()` and `try_read()`.
    /// 4. Check that the result is correct.
    #[test]
    fn ut_uds_stream_try_test() {
        const MESSAGE: &[u8] = b"hello";

        let task = crate::spawn(async {
            let (sender, receiver) = UnixStream::pair().unwrap();

            // Retry the write until it stops reporting `WouldBlock`.
            loop {
                sender.writable().await.unwrap();
                match sender.try_write(MESSAGE) {
                    Ok(written) => {
                        assert_eq!(written, MESSAGE.len());
                        break;
                    }
                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                    Err(e) => panic!("{e:?}"),
                }
            }

            // Retry the read the same way, with a fresh buffer per attempt.
            loop {
                receiver.readable().await.unwrap();
                let mut received = vec![0; MESSAGE.len()];
                match receiver.try_read(&mut received) {
                    Ok(count) => {
                        assert_eq!(count, MESSAGE.len());
                        assert_eq!(std::str::from_utf8(&received).unwrap(), "hello");
                        break;
                    }
                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                    Err(e) => panic!("{e:?}"),
                }
            }
        });
        crate::block_on(task).unwrap();
    }
}
489