1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::fmt::Debug;
15 use std::io::{self, IoSlice, IoSliceMut, Read, Write};
16 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
17 use std::os::unix::net;
18 use std::path::Path;
19 
20 use crate::source::Fd;
21 use crate::{Interest, Selector, Source, Token};
22 
/// A non-blocking UDS Stream between two local sockets.
///
/// Thin wrapper around [`std::os::unix::net::UnixStream`] that implements
/// [`Source`] so the stream's file descriptor can be registered with a
/// [`Selector`]. All reads and writes are delegated to the wrapped stream.
///
/// NOTE(review): wrapping an existing stream (via `from_std` or
/// `from_raw_fd`) does not itself switch the socket to non-blocking mode —
/// confirm that callers do so.
pub struct UnixStream {
    // The wrapped standard-library stream that all I/O is forwarded to.
    pub(crate) inner: net::UnixStream,
}
27 
28 impl UnixStream {
29     /// Connects to the specific socket.
30     ///
31     /// # Examples
32     /// ```no_run
33     /// use ylong_io::UnixStream;
34     ///
35     /// if let Ok(sock) = UnixStream::connect("/tmp/sock") {
36     ///     println!("socket connection succeeds");
37     /// };
38     /// ```
connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream>39     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
40         super::socket::connect(path.as_ref()).map(UnixStream::from_std)
41     }
42 
43     /// Creates a new `UnixStream` from a standard `net::UnixStream`
44     ///
45     /// # Examples
46     /// ```no_run
47     /// use std::os::unix::net::UnixStream;
48     ///
49     /// use ylong_io::UnixStream as YlongUnixStream;
50     ///
51     /// if let Ok(stream) = UnixStream::connect("/path/to/the/socket") {
52     ///     println!("socket binds successfully");
53     ///     let stream = YlongUnixStream::from_std(stream);
54     /// };
55     /// ```
from_std(stream: net::UnixStream) -> UnixStream56     pub fn from_std(stream: net::UnixStream) -> UnixStream {
57         UnixStream { inner: stream }
58     }
59 
60     /// Creates an unnamed pair of connected sockets.
61     /// Returns two `UnixStream`s which are connected to each other.
62     ///
63     /// # Examples
64     /// ```no_run
65     /// use ylong_io::UnixStream;
66     ///
67     /// if let Ok((stream1, stream2)) = UnixStream::pair() {
68     ///     println!("unix socket pair created successfully");
69     /// };
70     /// ```
pair() -> io::Result<(UnixStream, UnixStream)>71     pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
72         super::socket::stream_pair().map(|(stream1, stream2)| {
73             (UnixStream::from_std(stream1), UnixStream::from_std(stream2))
74         })
75     }
76 
77     /// Creates a new independently owned handle to the underlying socket.
78     ///
79     /// # Examples
80     /// ```no_run
81     /// use ylong_io::UnixStream;
82     ///
83     /// fn test() -> std::io::Result<()> {
84     ///     let socket = UnixStream::connect("/tmp/sock")?;
85     ///     let sock_copy = socket.try_clone().expect("try_clone() fail");
86     ///     Ok(())
87     /// }
88     /// ```
try_clone(&self) -> io::Result<UnixStream>89     pub fn try_clone(&self) -> io::Result<UnixStream> {
90         Ok(Self::from_std(self.inner.try_clone()?))
91     }
92 
93     /// Returns the local socket address of this UnixStream.
94     ///
95     /// # Examples
96     /// ```no_run
97     /// use ylong_io::UnixStream;
98     ///
99     /// fn test() -> std::io::Result<()> {
100     ///     let socket = UnixStream::connect("/tmp/sock")?;
101     ///     let addr = socket.local_addr().expect("get local_addr() fail");
102     ///     Ok(())
103     /// }
104     /// ```
local_addr(&self) -> io::Result<net::SocketAddr>105     pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
106         self.inner.local_addr()
107     }
108 
109     /// Returns the local socket address of this UnixStream's peer.
110     ///
111     /// # Examples
112     /// ```no_run
113     /// use ylong_io::UnixStream;
114     ///
115     /// fn test() -> std::io::Result<()> {
116     ///     let socket = UnixStream::connect("/tmp/sock")?;
117     ///     let addr = socket.peer_addr().expect("get peer_addr() fail");
118     ///     Ok(())
119     /// }
120     /// ```
peer_addr(&self) -> io::Result<net::SocketAddr>121     pub fn peer_addr(&self) -> io::Result<net::SocketAddr> {
122         self.inner.peer_addr()
123     }
124 
125     /// Returns the error of the `SO_ERROR` option.
126     ///
127     /// # Examples
128     /// ```no_run
129     /// use ylong_io::UnixStream;
130     ///
131     /// fn test() -> std::io::Result<()> {
132     ///     let socket = UnixStream::connect("/tmp/sock")?;
133     ///     if let Ok(Some(err)) = socket.take_error() {
134     ///         println!("get error: {err:?}");
135     ///     }
136     ///     Ok(())
137     /// }
138     /// ```
take_error(&self) -> io::Result<Option<io::Error>>139     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
140         self.inner.take_error()
141     }
142 
143     /// Shuts down this connection.
144     ///
145     /// # Examples
146     /// ```no_run
147     /// use std::net::Shutdown;
148     ///
149     /// use ylong_io::UnixStream;
150     ///
151     /// fn test() -> std::io::Result<()> {
152     ///     let socket = UnixStream::connect("/tmp/sock")?;
153     ///     socket.shutdown(Shutdown::Both).expect("shutdown() failed");
154     ///     Ok(())
155     /// }
156     /// ```
shutdown(&self, how: std::net::Shutdown) -> io::Result<()>157     pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
158         self.inner.shutdown(how)
159     }
160 }
161 
162 impl Read for UnixStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>163     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
164         self.inner.read(buf)
165     }
166 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>167     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
168         self.inner.read_vectored(bufs)
169     }
170 }
171 
172 impl Write for UnixStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>173     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
174         self.inner.write(buf)
175     }
176 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>177     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
178         self.inner.write_vectored(bufs)
179     }
180 
flush(&mut self) -> io::Result<()>181     fn flush(&mut self) -> io::Result<()> {
182         self.inner.flush()
183     }
184 }
185 impl Read for &UnixStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>186     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
187         let mut inner = &self.inner;
188         inner.read(buf)
189     }
190 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>191     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
192         let mut inner = &self.inner;
193         inner.read_vectored(bufs)
194     }
195 }
196 
197 impl Write for &UnixStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>198     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
199         let mut inner = &self.inner;
200         inner.write(buf)
201     }
202 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>203     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
204         let mut inner = &self.inner;
205         inner.write_vectored(bufs)
206     }
207 
flush(&mut self) -> io::Result<()>208     fn flush(&mut self) -> io::Result<()> {
209         let mut inner = &self.inner;
210         inner.flush()
211     }
212 }
213 
214 impl Debug for UnixStream {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result215     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216         self.inner.fmt(f)
217     }
218 }
219 
220 impl Source for UnixStream {
register( &mut self, selector: &Selector, token: Token, interests: Interest, ) -> io::Result<()>221     fn register(
222         &mut self,
223         selector: &Selector,
224         token: Token,
225         interests: Interest,
226     ) -> io::Result<()> {
227         selector.register(self.inner.as_raw_fd(), token, interests)
228     }
229 
deregister(&mut self, selector: &Selector) -> io::Result<()>230     fn deregister(&mut self, selector: &Selector) -> io::Result<()> {
231         selector.deregister(self.inner.as_raw_fd())
232     }
233 
get_fd(&self) -> Fd234     fn get_fd(&self) -> Fd {
235         self.inner.as_raw_fd()
236     }
237 }
238 
239 impl IntoRawFd for UnixStream {
into_raw_fd(self) -> RawFd240     fn into_raw_fd(self) -> RawFd {
241         self.inner.into_raw_fd()
242     }
243 }
244 
245 impl FromRawFd for UnixStream {
from_raw_fd(fd: RawFd) -> UnixStream246     unsafe fn from_raw_fd(fd: RawFd) -> UnixStream {
247         UnixStream::from_std(FromRawFd::from_raw_fd(fd))
248     }
249 }
250 
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::Shutdown;
    use std::os::fd::{FromRawFd, IntoRawFd};

    use crate::UnixStream;

    /// UT for `UnixStream::pair`
    ///
    /// # Brief
    /// 1. Create a pair of UnixStream and clone the sending end
    /// 2. Check that the local address of the clone is unnamed
    /// 3. Check that the peer address of the receiving end is unnamed
    /// 4. Rebuild the receiving end from its raw fd and check its local address
    /// 5. Shutdown both UnixStream
    #[test]
    fn ut_uds_stream_pair() {
        let (sender, receiver) = UnixStream::pair().unwrap();
        let sender2 = sender.try_clone().unwrap();

        let addr = sender2.local_addr().unwrap();
        let fmt = format!("{addr:?}");
        assert_eq!(&fmt, "(unnamed)");

        let addr = receiver.peer_addr().unwrap();
        let fmt = format!("{addr:?}");
        assert_eq!(&fmt, "(unnamed)");

        let fd = receiver.into_raw_fd();
        // SAFETY: `fd` was just released by `into_raw_fd`, so it is open and
        // has no other owner that could close it.
        let receiver2 = unsafe { UnixStream::from_raw_fd(fd) };
        let addr = receiver2.local_addr().unwrap();
        let fmt = format!("{addr:?}");
        assert_eq!(&fmt, "(unnamed)");

        receiver2.shutdown(Shutdown::Both).unwrap();
        sender.shutdown(Shutdown::Both).unwrap();
    }

    /// UT for `Read`/`Write` on `UnixStream` and `&UnixStream`
    ///
    /// # Brief
    /// 1. Create a pair of UnixStream
    /// 2. Write data through the owned sending end
    /// 3. Read it back through a shared reference to the receiving end
    /// 4. Write through a shared reference and read through the owned end
    #[test]
    fn ut_uds_stream_read_write() {
        let (mut sender, receiver) = UnixStream::pair().unwrap();

        sender.write_all(b"hello").unwrap();
        sender.flush().unwrap();
        let mut buf = [0u8; 5];
        (&receiver).read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"hello");

        (&receiver).write_all(b"world").unwrap();
        sender.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"world");
    }
}
288