1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::io;
15 use std::mem::{size_of, MaybeUninit};
16 use std::net::{self, SocketAddr};
17 use std::os::unix::io::{AsRawFd, FromRawFd};
18
19 use libc::{c_int, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t};
20
21 use super::{TcpSocket, TcpStream};
22 use crate::source::Fd;
23 #[cfg(target_os = "macos")]
24 use crate::sys::socket::set_non_block;
25 use crate::{Interest, Selector, Source, Token};
26
/// A TCP socket server that listens for and accepts incoming connections.
///
/// This is a thin wrapper around [`std::net::TcpListener`]; the wrapped
/// listener is reachable crate-internally through the `inner` field.
pub struct TcpListener {
    /// The wrapped std listener, which owns the socket file descriptor.
    pub(crate) inner: net::TcpListener,
}
31
32 impl TcpListener {
33 /// Binds a new tcp Listener to the specific address to receive connection
34 /// requests.
35 ///
36 /// The socket will be set to `SO_REUSEADDR`.
bind(addr: SocketAddr) -> io::Result<TcpListener>37 pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
38 let socket = TcpSocket::new_socket(addr)?;
39 let listener = TcpListener {
40 inner: unsafe { net::TcpListener::from_raw_fd(socket.as_raw_fd()) },
41 };
42 socket.set_reuse(true)?;
43 socket.bind(addr)?;
44 socket.listen(1024)?;
45 Ok(listener)
46 }
47
48 /// Accepts connections and returns the `TcpStream` and the remote peer
49 /// address.
50 ///
51 /// # Error
52 /// This may return an `Err(e)` where `e.kind()` is
53 /// `io::ErrorKind::WouldBlock`. This means a stream may be ready at a
54 /// later point and one should wait for an event before calling `accept`
55 /// again.
accept(&self) -> io::Result<(TcpStream, SocketAddr)>56 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
57 let mut addr: MaybeUninit<sockaddr_storage> = MaybeUninit::uninit();
58 let mut length = size_of::<sockaddr_storage>() as socklen_t;
59
60 #[cfg(target_os = "linux")]
61 let stream = match syscall!(accept4(
62 self.inner.as_raw_fd(),
63 addr.as_mut_ptr().cast::<_>(),
64 &mut length,
65 libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
66 )) {
67 Ok(socket) => unsafe { net::TcpStream::from_raw_fd(socket) },
68 Err(err) => {
69 return Err(err);
70 }
71 };
72
73 #[cfg(target_os = "macos")]
74 let stream = match syscall!(accept(
75 self.inner.as_raw_fd(),
76 addr.as_mut_ptr() as *mut _,
77 &mut length
78 )) {
79 Ok(socket) => {
80 set_non_block(socket)?;
81 unsafe { net::TcpStream::from_raw_fd(socket) }
82 }
83 Err(e) => return Err(e),
84 };
85
86 let ret = unsafe { trans_addr_2_socket(addr.as_ptr()) };
87
88 match ret {
89 Ok(sockaddr) => Ok((TcpStream::from_std(stream), sockaddr)),
90 Err(err) => Err(err),
91 }
92 }
93
94 /// Returns the local socket address of this listener.
95 ///
96 /// # Examples
97 ///
98 /// ```no_run
99 /// use ylong_io::TcpListener;
100 ///
101 /// let addr = "127.0.0.1:1234".parse().unwrap();
102 /// let mut server = TcpListener::bind(addr).unwrap();
103 /// let ret = server.local_addr().unwrap();
104 /// ```
local_addr(&self) -> io::Result<SocketAddr>105 pub fn local_addr(&self) -> io::Result<SocketAddr> {
106 self.inner.local_addr()
107 }
108
109 /// Gets the value of the IP_TTL option for this socket.
110 ///
111 /// # Examples
112 ///
113 /// ```no_run
114 /// use ylong_io::TcpListener;
115 ///
116 /// let addr = "127.0.0.1:1234".parse().unwrap();
117 /// let mut server = TcpListener::bind(addr).unwrap();
118 /// let ret = server.ttl().unwrap();
119 /// ```
ttl(&self) -> io::Result<u32>120 pub fn ttl(&self) -> io::Result<u32> {
121 self.inner.ttl()
122 }
123
124 /// Sets the value for the IP_TTL option on this socket.
125 /// This value sets the time-to-live field that is used in every packet sent
126 /// from this socket.
127 ///
128 /// # Examples
129 ///
130 /// ```no_run
131 /// use ylong_io::TcpListener;
132 ///
133 /// let addr = "127.0.0.1:1234".parse().unwrap();
134 /// let mut server = TcpListener::bind(addr).unwrap();
135 /// let ret = server.set_ttl(100).unwrap();
136 /// ```
set_ttl(&self, ttl: u32) -> io::Result<()>137 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
138 self.inner.set_ttl(ttl)
139 }
140
141 /// Gets the value of the SO_ERROR option on this socket.
142 /// This will retrieve the stored error in the underlying socket, clearing
143 /// the field in the process. This can be useful for checking errors between
144 /// calls.
145 ///
146 /// # Examples
147 ///
148 /// ```no_run
149 /// use ylong_io::TcpListener;
150 ///
151 /// let addr = "127.0.0.1:1234".parse().unwrap();
152 /// let mut server = TcpListener::bind(addr).unwrap();
153 /// let ret = server.take_error().unwrap();
154 /// ```
take_error(&self) -> io::Result<Option<io::Error>>155 pub fn take_error(&self) -> io::Result<Option<io::Error>> {
156 self.inner.take_error()
157 }
158 }
159
trans_addr_2_socket( storage: *const libc::sockaddr_storage, ) -> io::Result<SocketAddr>160 pub(crate) unsafe fn trans_addr_2_socket(
161 storage: *const libc::sockaddr_storage,
162 ) -> io::Result<SocketAddr> {
163 match (*storage).ss_family as c_int {
164 libc::AF_INET => Ok(SocketAddr::V4(*(storage.cast::<sockaddr_in>().cast::<_>()))),
165 libc::AF_INET6 => Ok(SocketAddr::V6(
166 *(storage.cast::<sockaddr_in6>().cast::<_>()),
167 )),
168 _ => {
169 let err = io::Error::new(io::ErrorKind::Other, "Cannot transfer address into socket.");
170 Err(err)
171 }
172 }
173 }
174
175 impl Source for TcpListener {
register( &mut self, selector: &Selector, token: Token, interests: Interest, ) -> io::Result<()>176 fn register(
177 &mut self,
178 selector: &Selector,
179 token: Token,
180 interests: Interest,
181 ) -> io::Result<()> {
182 selector.register(self.get_fd(), token, interests)
183 }
184
deregister(&mut self, selector: &Selector) -> io::Result<()>185 fn deregister(&mut self, selector: &Selector) -> io::Result<()> {
186 selector.deregister(self.get_fd())
187 }
188
get_fd(&self) -> Fd189 fn get_fd(&self) -> Fd {
190 self.inner.as_raw_fd()
191 }
192 }
193
194 impl std::fmt::Debug for TcpListener {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 self.inner.fmt(f)
197 }
198 }
199