1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io;
15 use std::mem::{size_of, MaybeUninit};
16 use std::net::{self, SocketAddr};
17 use std::os::unix::io::{AsRawFd, FromRawFd};
18 
19 use libc::{c_int, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t};
20 
21 use super::{TcpSocket, TcpStream};
22 use crate::source::Fd;
23 #[cfg(target_os = "macos")]
24 use crate::sys::socket::set_non_block;
25 use crate::{Interest, Selector, Source, Token};
26 
/// A TCP socket server that accepts incoming connections.
///
/// Instances are created with [`TcpListener::bind`]; each accepted connection
/// is handed back as a [`TcpStream`] by [`TcpListener::accept`].
pub struct TcpListener {
    /// The wrapped standard-library listener. It owns the listening socket's
    /// file descriptor and closes it when dropped.
    pub(crate) inner: net::TcpListener,
}
31 
32 impl TcpListener {
33     /// Binds a new tcp Listener to the specific address to receive connection
34     /// requests.
35     ///
36     /// The socket will be set to `SO_REUSEADDR`.
bind(addr: SocketAddr) -> io::Result<TcpListener>37     pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
38         let socket = TcpSocket::new_socket(addr)?;
39         let listener = TcpListener {
40             inner: unsafe { net::TcpListener::from_raw_fd(socket.as_raw_fd()) },
41         };
42         socket.set_reuse(true)?;
43         socket.bind(addr)?;
44         socket.listen(1024)?;
45         Ok(listener)
46     }
47 
48     /// Accepts connections and returns the `TcpStream` and the remote peer
49     /// address.
50     ///
51     /// # Error
52     /// This may return an `Err(e)` where `e.kind()` is
53     /// `io::ErrorKind::WouldBlock`. This means a stream may be ready at a
54     /// later point and one should wait for an event before calling `accept`
55     /// again.
accept(&self) -> io::Result<(TcpStream, SocketAddr)>56     pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
57         let mut addr: MaybeUninit<sockaddr_storage> = MaybeUninit::uninit();
58         let mut length = size_of::<sockaddr_storage>() as socklen_t;
59 
60         #[cfg(target_os = "linux")]
61         let stream = match syscall!(accept4(
62             self.inner.as_raw_fd(),
63             addr.as_mut_ptr().cast::<_>(),
64             &mut length,
65             libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
66         )) {
67             Ok(socket) => unsafe { net::TcpStream::from_raw_fd(socket) },
68             Err(err) => {
69                 return Err(err);
70             }
71         };
72 
73         #[cfg(target_os = "macos")]
74         let stream = match syscall!(accept(
75             self.inner.as_raw_fd(),
76             addr.as_mut_ptr() as *mut _,
77             &mut length
78         )) {
79             Ok(socket) => {
80                 set_non_block(socket)?;
81                 unsafe { net::TcpStream::from_raw_fd(socket) }
82             }
83             Err(e) => return Err(e),
84         };
85 
86         let ret = unsafe { trans_addr_2_socket(addr.as_ptr()) };
87 
88         match ret {
89             Ok(sockaddr) => Ok((TcpStream::from_std(stream), sockaddr)),
90             Err(err) => Err(err),
91         }
92     }
93 
94     /// Returns the local socket address of this listener.
95     ///
96     /// # Examples
97     ///
98     /// ```no_run
99     /// use ylong_io::TcpListener;
100     ///
101     /// let addr = "127.0.0.1:1234".parse().unwrap();
102     /// let mut server = TcpListener::bind(addr).unwrap();
103     /// let ret = server.local_addr().unwrap();
104     /// ```
local_addr(&self) -> io::Result<SocketAddr>105     pub fn local_addr(&self) -> io::Result<SocketAddr> {
106         self.inner.local_addr()
107     }
108 
109     /// Gets the value of the IP_TTL option for this socket.
110     ///
111     /// # Examples
112     ///
113     /// ```no_run
114     /// use ylong_io::TcpListener;
115     ///
116     /// let addr = "127.0.0.1:1234".parse().unwrap();
117     /// let mut server = TcpListener::bind(addr).unwrap();
118     /// let ret = server.ttl().unwrap();
119     /// ```
ttl(&self) -> io::Result<u32>120     pub fn ttl(&self) -> io::Result<u32> {
121         self.inner.ttl()
122     }
123 
124     /// Sets the value for the IP_TTL option on this socket.
125     /// This value sets the time-to-live field that is used in every packet sent
126     /// from this socket.
127     ///
128     /// # Examples
129     ///
130     /// ```no_run
131     /// use ylong_io::TcpListener;
132     ///
133     /// let addr = "127.0.0.1:1234".parse().unwrap();
134     /// let mut server = TcpListener::bind(addr).unwrap();
135     /// let ret = server.set_ttl(100).unwrap();
136     /// ```
set_ttl(&self, ttl: u32) -> io::Result<()>137     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
138         self.inner.set_ttl(ttl)
139     }
140 
141     /// Gets the value of the SO_ERROR option on this socket.
142     /// This will retrieve the stored error in the underlying socket, clearing
143     /// the field in the process. This can be useful for checking errors between
144     /// calls.
145     ///
146     /// # Examples
147     ///
148     /// ```no_run
149     /// use ylong_io::TcpListener;
150     ///
151     /// let addr = "127.0.0.1:1234".parse().unwrap();
152     /// let mut server = TcpListener::bind(addr).unwrap();
153     /// let ret = server.take_error().unwrap();
154     /// ```
take_error(&self) -> io::Result<Option<io::Error>>155     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
156         self.inner.take_error()
157     }
158 }
159 
trans_addr_2_socket( storage: *const libc::sockaddr_storage, ) -> io::Result<SocketAddr>160 pub(crate) unsafe fn trans_addr_2_socket(
161     storage: *const libc::sockaddr_storage,
162 ) -> io::Result<SocketAddr> {
163     match (*storage).ss_family as c_int {
164         libc::AF_INET => Ok(SocketAddr::V4(*(storage.cast::<sockaddr_in>().cast::<_>()))),
165         libc::AF_INET6 => Ok(SocketAddr::V6(
166             *(storage.cast::<sockaddr_in6>().cast::<_>()),
167         )),
168         _ => {
169             let err = io::Error::new(io::ErrorKind::Other, "Cannot transfer address into socket.");
170             Err(err)
171         }
172     }
173 }
174 
175 impl Source for TcpListener {
register( &mut self, selector: &Selector, token: Token, interests: Interest, ) -> io::Result<()>176     fn register(
177         &mut self,
178         selector: &Selector,
179         token: Token,
180         interests: Interest,
181     ) -> io::Result<()> {
182         selector.register(self.get_fd(), token, interests)
183     }
184 
deregister(&mut self, selector: &Selector) -> io::Result<()>185     fn deregister(&mut self, selector: &Selector) -> io::Result<()> {
186         selector.deregister(self.get_fd())
187     }
188 
get_fd(&self) -> Fd189     fn get_fd(&self) -> Fd {
190         self.inner.as_raw_fd()
191     }
192 }
193 
194 impl std::fmt::Debug for TcpListener {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result195     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196         self.inner.fmt(f)
197     }
198 }
199