1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io::{self, IoSlice, IoSliceMut, Read, Write};
15 use std::net::{self, Shutdown, SocketAddr};
16 use std::os::unix::io::AsRawFd;
17 use std::time::Duration;
18 
19 use super::TcpSocket;
20 use crate::source::Fd;
21 use crate::sys::unix::tcp::socket::{get_sock_linger, set_sock_linger};
22 use crate::{Interest, Selector, Source, Token};
23 
/// Non-blocking TCP stream connecting a local socket to a remote peer.
///
/// This type is a thin wrapper around the standard library's
/// `std::net::TcpStream`: every operation it offers is forwarded to the
/// wrapped stream or to that stream's raw file descriptor.
pub struct TcpStream {
    /// The wrapped standard-library stream that owns the socket.
    pub inner: net::TcpStream,
}
29 
30 impl TcpStream {
31     /// Create a new TCP stream and issue a non-blocking connect to the
32     /// specified address.
connect(addr: SocketAddr) -> io::Result<TcpStream>33     pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
34         let socket = TcpSocket::new_socket(addr)?;
35         socket.connect(addr)
36     }
37 
38     /// Creates a new `TcpStream` from a standard `net::TcpStream`.
from_std(stream: net::TcpStream) -> TcpStream39     pub fn from_std(stream: net::TcpStream) -> TcpStream {
40         TcpStream { inner: stream }
41     }
42 
43     /// Clones the TcpStream.
try_clone(&self) -> io::Result<Self>44     pub fn try_clone(&self) -> io::Result<Self> {
45         Ok(TcpStream {
46             inner: self.inner.try_clone()?,
47         })
48     }
49 
50     /// Returns the socket address of the local half of this TCP connection.
51     ///
52     /// # Examples
53     ///
54     /// ```no_run
55     /// use std::net::{IpAddr, Ipv4Addr};
56     ///
57     /// use ylong_io::TcpStream;
58     ///
59     /// let addr = "127.0.0.1:1234".parse().unwrap();
60     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
61     /// assert_eq!(
62     ///     stream.local_addr().unwrap().ip(),
63     ///     IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
64     /// );
65     /// ```
local_addr(&self) -> io::Result<SocketAddr>66     pub fn local_addr(&self) -> io::Result<SocketAddr> {
67         self.inner.local_addr()
68     }
69 
70     /// Returns the socket address of the remote half of this TCP connection.
71     ///
72     /// # Examples
73     ///
74     /// ```no_run
75     /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
76     ///
77     /// use ylong_io::TcpStream;
78     ///
79     /// let addr = "127.0.0.1:1234".parse().unwrap();
80     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
81     /// assert_eq!(
82     ///     stream.peer_addr().unwrap(),
83     ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1234))
84     /// );
85     /// ```
peer_addr(&self) -> io::Result<SocketAddr>86     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
87         self.inner.peer_addr()
88     }
89 
90     /// Sets the value of the `TCP_NODELAY`.
91     ///
92     /// # Examples
93     ///
94     /// ```no_run
95     /// use ylong_io::TcpStream;
96     ///
97     /// let addr = "127.0.0.1:1234".parse().unwrap();
98     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
99     /// stream.set_nodelay(true).expect("set_nodelay call failed");
100     /// ```
set_nodelay(&self, nodelay: bool) -> io::Result<()>101     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
102         self.inner.set_nodelay(nodelay)
103     }
104 
105     /// Gets the value of the `TCP_NODELAY`.
106     ///
107     /// # Examples
108     ///
109     /// ```no_run
110     /// use ylong_io::TcpStream;
111     ///
112     /// let addr = "127.0.0.1:1234".parse().unwrap();
113     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
114     /// stream.set_nodelay(true).expect("set_nodelay call failed");
115     /// assert_eq!(stream.nodelay().unwrap_or(false), true);
116     /// ```
nodelay(&self) -> io::Result<bool>117     pub fn nodelay(&self) -> io::Result<bool> {
118         self.inner.nodelay()
119     }
120 
121     /// Gets the value of the linger on this socket by getting `SO_LINGER`
122     /// option.
123     ///
124     /// # Examples
125     ///
126     /// ```no_run
127     /// use ylong_io::TcpStream;
128     ///
129     /// let addr = "127.0.0.1:1234".parse().unwrap();
130     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
131     /// println!("{:?}", stream.linger());
132     /// ```
linger(&self) -> io::Result<Option<Duration>>133     pub fn linger(&self) -> io::Result<Option<Duration>> {
134         get_sock_linger(self.get_fd())
135     }
136 
137     /// Sets the value of the linger on this socket by setting `SO_LINGER`
138     /// option.
139     ///
140     /// This value controls how the socket close when a stream has unsent data.
141     /// If SO_LINGER is set, the socket will still open for the duration as
142     /// the system attempts to send pending data. Otherwise, the system may
143     /// close the socket immediately, or wait for a default timeout.
144     ///
145     /// # Examples
146     ///
147     /// ```no_run
148     /// use ylong_io::TcpStream;
149     ///
150     /// let addr = "127.0.0.1:1234".parse().unwrap();
151     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
152     /// stream.set_linger(None).expect("Sets linger fail.");
153     /// ```
set_linger(&self, linger: Option<Duration>) -> io::Result<()>154     pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
155         set_sock_linger(self.get_fd(), linger)
156     }
157 
158     /// Sets the value for the `IP_TTL`.
159     ///
160     /// # Examples
161     ///
162     /// ```no_run
163     /// use ylong_io::TcpStream;
164     ///
165     /// let addr = "127.0.0.1:1234".parse().unwrap();
166     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
167     /// stream.set_ttl(100).expect("set_ttl call failed");
168     /// ```
set_ttl(&self, ttl: u32) -> io::Result<()>169     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
170         self.inner.set_ttl(ttl)
171     }
172 
173     /// Gets the value of the `IP_TTL`.
174     ///
175     /// # Examples
176     ///
177     /// ```no_run
178     /// use ylong_io::TcpStream;
179     ///
180     /// let addr = "127.0.0.1:1234".parse().unwrap();
181     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
182     /// stream.set_ttl(100).expect("set_ttl call failed");
183     /// assert_eq!(stream.ttl().unwrap_or(0), 100);
184     /// ```
ttl(&self) -> io::Result<u32>185     pub fn ttl(&self) -> io::Result<u32> {
186         self.inner.ttl()
187     }
188 
189     /// Gets the value of the `SO_ERROR` option on this socket.
take_error(&self) -> io::Result<Option<io::Error>>190     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
191         self.inner.take_error()
192     }
193 
194     /// Shuts down the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> io::Result<()>195     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
196         self.inner.shutdown(how)
197     }
198 
199     /// Same as std::net::TcpStream::peek().
200     ///
201     /// Receives data on the socket from the remote address to which it is
202     /// connected, without removing that data from the queue. On success,
203     /// returns the number of bytes peeked.
204     ///
205     /// # Examples
206     ///
207     /// ```no_run
208     /// use ylong_io::TcpStream;
209     ///
210     /// let addr = "127.0.0.1:1234".parse().unwrap();
211     /// let stream = TcpStream::connect(addr).expect("Couldn't connect to the server...");
212     /// let mut buf = [0; 10];
213     /// let len = stream.peek(&mut buf).expect("peek failed");
214     /// ```
peek(&self, buf: &mut [u8]) -> io::Result<usize>215     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
216         self.inner.peek(buf)
217     }
218 }
219 
220 impl Read for TcpStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>221     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
222         self.inner.read(buf)
223     }
224 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>225     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
226         self.inner.read_vectored(bufs)
227     }
228 }
229 
230 impl Write for TcpStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>231     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
232         self.inner.write(buf)
233     }
234 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>235     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
236         self.inner.write_vectored(bufs)
237     }
238 
flush(&mut self) -> io::Result<()>239     fn flush(&mut self) -> io::Result<()> {
240         self.inner.flush()
241     }
242 }
243 
244 impl Read for &TcpStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>245     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
246         let mut inner = &self.inner;
247         inner.read(buf)
248     }
249 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>250     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
251         let mut inner = &self.inner;
252         inner.read_vectored(bufs)
253     }
254 }
255 
256 impl Write for &TcpStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>257     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
258         let mut inner = &self.inner;
259         inner.write(buf)
260     }
261 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>262     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
263         let mut inner = &self.inner;
264         inner.write_vectored(bufs)
265     }
266 
flush(&mut self) -> io::Result<()>267     fn flush(&mut self) -> io::Result<()> {
268         let mut inner = &self.inner;
269         inner.flush()
270     }
271 }
272 
273 impl std::fmt::Debug for TcpStream {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result274     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275         self.inner.fmt(f)
276     }
277 }
278 
279 impl Source for TcpStream {
register( &mut self, selector: &Selector, token: Token, interests: Interest, ) -> io::Result<()>280     fn register(
281         &mut self,
282         selector: &Selector,
283         token: Token,
284         interests: Interest,
285     ) -> io::Result<()> {
286         selector.register(self.get_fd(), token, interests)
287     }
288 
deregister(&mut self, selector: &Selector) -> io::Result<()>289     fn deregister(&mut self, selector: &Selector) -> io::Result<()> {
290         selector.deregister(self.get_fd())
291     }
292 
get_fd(&self) -> Fd293     fn get_fd(&self) -> Fd {
294         self.inner.as_raw_fd()
295     }
296 }
297