1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::fmt::Formatter;
15 use std::io::{IoSlice, IoSliceMut, Read, Write};
16 use std::net::{Shutdown, SocketAddr};
17 use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
18 use std::time::Duration;
19 use std::{fmt, io, net};
20 
21 use crate::source::Fd;
22 use crate::sys::windows::tcp::socket::{get_sock_linger, set_sock_linger};
23 use crate::sys::windows::tcp::TcpSocket;
24 use crate::sys::NetState;
25 use crate::{Interest, Selector, Source, Token};
26 
27 /// A non-blocking TCP Stream between a local socket and a remote socket.
28 pub struct TcpStream {
29     /// Raw TCP socket
30     pub(crate) inner: net::TcpStream,
31     /// State is None if the socket has not been Registered.
32     pub(crate) state: NetState,
33 }
34 
35 impl TcpStream {
36     /// Connects address to form TcpStream
37     ///
38     /// # Examples
39     ///
40     /// ```no_run
41     /// let addr = "127.0.0.1:1234".parse().unwrap();
42     /// let stream = ylong_io::TcpStream::connect(addr).unwrap();
43     /// ```
connect(addr: SocketAddr) -> io::Result<TcpStream>44     pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
45         let socket = TcpSocket::new_socket(addr)?;
46         let stream = unsafe { TcpStream::from_raw_socket(socket.as_raw_socket() as _) };
47 
48         socket.connect(addr)?;
49         Ok(stream)
50     }
51 
52     /// Creates `TcpStream` from raw TcpStream.
from_std(stream: net::TcpStream) -> TcpStream53     pub fn from_std(stream: net::TcpStream) -> TcpStream {
54         TcpStream {
55             inner: stream,
56             state: NetState::new(),
57         }
58     }
59 
60     /// Clones the TcpStream.
try_clone(&self) -> io::Result<Self>61     pub fn try_clone(&self) -> io::Result<Self> {
62         Ok(TcpStream {
63             inner: self.inner.try_clone()?,
64             state: self.state.clone(),
65         })
66     }
67 
68     /// Returns the socket address of the local half of this TCP connection.
69     ///
70     /// # Examples
71     ///
72     /// ```no_run
73     /// use std::net::{IpAddr, Ipv4Addr};
74     ///
75     /// use ylong_io::TcpStream;
76     ///
77     /// let addr = "127.0.0.1:1234".parse().unwrap();
78     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
79     /// assert_eq!(
80     ///     stream.local_addr().unwrap().ip(),
81     ///     IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
82     /// );
83     /// ```
local_addr(&self) -> io::Result<SocketAddr>84     pub fn local_addr(&self) -> io::Result<SocketAddr> {
85         self.inner.local_addr()
86     }
87 
88     /// Returns the socket address of the remote half of this TCP connection.
89     ///
90     /// # Examples
91     ///
92     /// ```no_run
93     /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
94     ///
95     /// use ylong_io::TcpStream;
96     ///
97     /// let addr = "127.0.0.1:1234".parse().unwrap();
98     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
99     /// assert_eq!(
100     ///     stream.peer_addr().unwrap(),
101     ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234))
102     /// );
103     /// ```
peer_addr(&self) -> io::Result<SocketAddr>104     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
105         self.inner.peer_addr()
106     }
107 
108     /// Shuts down the read, write, or both halves of this connection.
109     ///
110     /// # Examples
111     ///
112     /// ```no_run
113     /// use std::net::Shutdown;
114     ///
115     /// use ylong_io::TcpStream;
116     ///
117     /// let addr = "127.0.0.1:1234".parse().unwrap();
118     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
119     /// stream
120     ///     .shutdown(Shutdown::Both)
121     ///     .expect("shutdown call failed");
122     /// ```
shutdown(&self, how: Shutdown) -> io::Result<()>123     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
124         self.inner.shutdown(how)
125     }
126 
127     /// Sets the value of the `TCP_NODELAY`.
128     ///
129     /// # Examples
130     ///
131     /// ```no_run
132     /// use ylong_io::TcpStream;
133     ///
134     /// let addr = "127.0.0.1:1234".parse().unwrap();
135     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
136     /// stream.set_nodelay(true).expect("set_nodelay call failed");
137     /// ```
set_nodelay(&self, nodelay: bool) -> io::Result<()>138     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
139         self.inner.set_nodelay(nodelay)
140     }
141 
142     /// Gets the value of the `TCP_NODELAY`.
143     ///
144     /// # Examples
145     ///
146     /// ```no_run
147     /// use ylong_io::TcpStream;
148     ///
149     /// let addr = "127.0.0.1:1234".parse().unwrap();
150     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
151     /// stream.set_nodelay(true).expect("set_nodelay call failed");
152     /// assert_eq!(stream.nodelay().unwrap_or(false), true);
153     /// ```
nodelay(&self) -> io::Result<bool>154     pub fn nodelay(&self) -> io::Result<bool> {
155         self.inner.nodelay()
156     }
157 
158     /// Gets the value of the linger on this socket by getting `SO_LINGER`
159     /// option.
160     ///
161     /// # Examples
162     ///
163     /// ```no_run
164     /// use ylong_io::TcpStream;
165     ///
166     /// let addr = "127.0.0.1:1234".parse().unwrap();
167     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
168     /// println!("{:?}", stream.linger());
169     /// ```
linger(&self) -> io::Result<Option<Duration>>170     pub fn linger(&self) -> io::Result<Option<Duration>> {
171         get_sock_linger(self.as_raw_socket())
172     }
173 
174     /// Sets the value of the linger on this socket by setting `SO_LINGER`
175     /// option.
176     ///
177     /// This value controls how the socket close when a stream has unsent data.
178     /// If SO_LINGER is set, the socket will still open for the duration as
179     /// the system attempts to send pending data. Otherwise, the system may
180     /// close the socket immediately, or wait for a default timeout.
181     ///
182     /// # Examples
183     ///
184     /// ```no_run
185     /// use ylong_io::TcpStream;
186     ///
187     /// let addr = "127.0.0.1:1234".parse().unwrap();
188     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
189     /// stream.set_linger(None).expect("Sets linger fail.");
190     /// ```
set_linger(&self, linger: Option<Duration>) -> io::Result<()>191     pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
192         set_sock_linger(self.as_raw_socket(), linger)
193     }
194 
195     /// Sets the value for the `IP_TTL`.
196     ///
197     /// # Examples
198     ///
199     /// ```no_run
200     /// use ylong_io::TcpStream;
201     ///
202     /// let addr = "127.0.0.1:1234".parse().unwrap();
203     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
204     /// stream.set_ttl(100).expect("set_ttl call failed");
205     /// ```
set_ttl(&self, ttl: u32) -> io::Result<()>206     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
207         self.inner.set_ttl(ttl)
208     }
209 
210     /// Gets the value of the `IP_TTL`.
211     ///
212     /// # Examples
213     ///
214     /// ```no_run
215     /// use ylong_io::TcpStream;
216     ///
217     /// let addr = "127.0.0.1:1234".parse().unwrap();
218     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
219     /// stream.set_ttl(100).expect("set_ttl call failed");
220     /// assert_eq!(stream.ttl().unwrap_or(0), 100);
221     /// ```
ttl(&self) -> io::Result<u32>222     pub fn ttl(&self) -> io::Result<u32> {
223         self.inner.ttl()
224     }
225 
226     /// Get the value of the `SO_ERROR`.
227     ///
228     /// # Examples
229     ///
230     /// ```no_run
231     /// use ylong_io::TcpStream;
232     ///
233     /// let addr = "127.0.0.1:1234".parse().unwrap();
234     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
235     /// stream.take_error().expect("No error was expected...");
236     /// ```
take_error(&self) -> io::Result<Option<io::Error>>237     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
238         self.inner.take_error()
239     }
240 
241     /// Same as std::net::TcpStream::peek().
242     ///
243     /// Receives data on the socket from the remote address to which it is
244     /// connected, without removing that data from the queue. On success,
245     /// returns the number of bytes peeked.
246     ///
247     /// # Examples
248     ///
249     /// ```no_run
250     /// use ylong_io::TcpStream;
251     ///
252     /// let addr = "127.0.0.1:1234".parse().unwrap();
253     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
254     /// let mut buf = [0; 10];
255     /// let len = stream.peek(&mut buf).expect("peek failed");
256     /// ```
peek(&self, buf: &mut [u8]) -> io::Result<usize>257     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
258         self.inner.peek(buf)
259     }
260 }
261 
262 impl fmt::Debug for TcpStream {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result263     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
264         self.inner.fmt(f)
265     }
266 }
267 
268 macro_rules! read_write {
269     ($($identifier:tt)*) => {
270         impl Read for $($identifier)* {
271             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
272                 self.state.try_io(|mut inner| inner.read(buf), &self.inner)
273             }
274 
275             fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
276                 self.state.try_io(|mut inner| inner.read_vectored(bufs), &self.inner)
277             }
278         }
279 
280         impl Write for $($identifier)* {
281             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
282                 self.state.try_io(|mut inner| inner.write(buf), &self.inner)
283             }
284 
285             fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
286                 self.state.try_io(|mut inner| inner.write_vectored(bufs), &self.inner)
287             }
288 
289             fn flush(&mut self) -> io::Result<()> {
290                 self.state.try_io(|mut inner| inner.flush(), &self.inner)
291             }
292         }
293     };
294 }
295 
296 read_write!(TcpStream);
297 
298 read_write!(&TcpStream);
299 
300 impl IntoRawSocket for TcpStream {
into_raw_socket(self) -> RawSocket301     fn into_raw_socket(self) -> RawSocket {
302         self.inner.into_raw_socket()
303     }
304 }
305 
306 impl AsRawSocket for TcpStream {
as_raw_socket(&self) -> RawSocket307     fn as_raw_socket(&self) -> RawSocket {
308         self.inner.as_raw_socket()
309     }
310 }
311 
312 impl FromRawSocket for TcpStream {
313     /// Converts a `RawSocket` to a `TcpStream`.
from_raw_socket(socket: RawSocket) -> Self314     unsafe fn from_raw_socket(socket: RawSocket) -> Self {
315         TcpStream::from_std(FromRawSocket::from_raw_socket(socket))
316     }
317 }
318 
319 impl Source for TcpStream {
register( &mut self, selector: &Selector, token: Token, interests: Interest, ) -> io::Result<()>320     fn register(
321         &mut self,
322         selector: &Selector,
323         token: Token,
324         interests: Interest,
325     ) -> io::Result<()> {
326         self.state
327             .register(selector, token, interests, self.as_raw_socket())
328     }
329 
deregister(&mut self, _selector: &Selector) -> io::Result<()>330     fn deregister(&mut self, _selector: &Selector) -> io::Result<()> {
331         self.state.deregister()
332     }
333 
get_fd(&self) -> Fd334     fn get_fd(&self) -> Fd {
335         self.inner.as_raw_socket()
336     }
337 }
338