1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io;
15 use std::mem::{self, size_of, MaybeUninit};
16 use std::net::{self, SocketAddr};
17 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
18 use std::time::Duration;
19 
20 use libc::{
21     c_int, c_void, linger, socklen_t, AF_INET, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_LINGER,
22     SO_REUSEADDR,
23 };
24 
25 use super::super::socket_addr::socket_addr_trans;
26 use super::TcpStream;
27 use crate::source::Fd;
28 use crate::sys::unix::socket::socket_new;
29 
/// Owner-less wrapper around a raw socket descriptor.
///
/// A `TcpSocket` is configured through its methods (`set_reuse`, `bind`, ...)
/// and can be handed off via `connect`, which yields a `TcpStream` built on the
/// same descriptor.
pub(crate) struct TcpSocket {
    /// The raw descriptor obtained from `socket_new` or `from_raw_fd`.
    socket: c_int,
}
33 
34 impl TcpSocket {
new_socket(addr: SocketAddr) -> io::Result<TcpSocket>35     pub(crate) fn new_socket(addr: SocketAddr) -> io::Result<TcpSocket> {
36         if addr.is_ipv4() {
37             TcpSocket::create_socket(AF_INET, SOCK_STREAM)
38         } else {
39             TcpSocket::create_socket(AF_INET6, SOCK_STREAM)
40         }
41     }
42 
create_socket(domain: c_int, socket_type: c_int) -> io::Result<TcpSocket>43     pub(crate) fn create_socket(domain: c_int, socket_type: c_int) -> io::Result<TcpSocket> {
44         let socket = socket_new(domain, socket_type)?;
45         Ok(TcpSocket {
46             socket: socket as c_int,
47         })
48     }
49 
set_reuse(&self, is_reuse: bool) -> io::Result<()>50     pub(crate) fn set_reuse(&self, is_reuse: bool) -> io::Result<()> {
51         let set_value: c_int = i32::from(is_reuse);
52 
53         match syscall!(setsockopt(
54             self.socket,
55             SOL_SOCKET,
56             SO_REUSEADDR,
57             (&set_value as *const c_int).cast::<c_void>(),
58             size_of::<c_int>() as socklen_t
59         )) {
60             Err(err) => Err(err),
61             Ok(_) => Ok(()),
62         }
63     }
64 
bind(&self, addr: SocketAddr) -> io::Result<()>65     pub(crate) fn bind(&self, addr: SocketAddr) -> io::Result<()> {
66         let (raw_addr, addr_length) = socket_addr_trans(&addr);
67         match syscall!(bind(self.socket, raw_addr.as_ptr(), addr_length)) {
68             Err(err) => Err(err),
69             Ok(_) => Ok(()),
70         }
71     }
72 
listen(self, max_connect: c_int) -> io::Result<()>73     pub(crate) fn listen(self, max_connect: c_int) -> io::Result<()> {
74         syscall!(listen(self.socket, max_connect))?;
75         Ok(())
76     }
77 
connect(self, addr: SocketAddr) -> io::Result<TcpStream>78     pub(crate) fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
79         let stream = TcpStream {
80             inner: unsafe { net::TcpStream::from_raw_fd(self.socket) },
81         };
82         let (raw_addr, addr_length) = socket_addr_trans(&addr);
83         match syscall!(connect(self.socket, raw_addr.as_ptr(), addr_length)) {
84             Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => Err(err),
85             _ => Ok(stream),
86         }
87     }
88 }
89 
90 impl AsRawFd for TcpSocket {
as_raw_fd(&self) -> RawFd91     fn as_raw_fd(&self) -> RawFd {
92         self.socket
93     }
94 }
95 
96 impl FromRawFd for TcpSocket {
from_raw_fd(fd: RawFd) -> TcpSocket97     unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
98         TcpSocket { socket: fd }
99     }
100 }
101 
get_sock_linger(fd: Fd) -> io::Result<Option<Duration>>102 pub(crate) fn get_sock_linger(fd: Fd) -> io::Result<Option<Duration>> {
103     let mut payload: MaybeUninit<linger> = MaybeUninit::uninit();
104     let mut len = mem::size_of::<linger>() as libc::socklen_t;
105 
106     syscall!(getsockopt(
107         fd as c_int,
108         SOL_SOCKET,
109         SO_LINGER,
110         payload.as_mut_ptr().cast(),
111         &mut len,
112     ))
113     .map(|_| {
114         let linger = unsafe { payload.assume_init() };
115         from_linger(linger)
116     })
117 }
118 
set_sock_linger(fd: Fd, duration: Option<Duration>) -> io::Result<()>119 pub(crate) fn set_sock_linger(fd: Fd, duration: Option<Duration>) -> io::Result<()> {
120     let payload = into_linger(duration);
121     syscall!(setsockopt(
122         fd as c_int,
123         SOL_SOCKET,
124         SO_LINGER,
125         (&payload as *const linger).cast::<c_void>(),
126         mem::size_of::<linger>() as libc::socklen_t,
127     ))
128     .map(|_| ())
129 }
130 
from_linger(linger: linger) -> Option<Duration>131 fn from_linger(linger: linger) -> Option<Duration> {
132     if linger.l_onoff == 0 {
133         None
134     } else {
135         Some(Duration::from_secs(linger.l_linger as u64))
136     }
137 }
138 
into_linger(duration: Option<Duration>) -> linger139 fn into_linger(duration: Option<Duration>) -> linger {
140     match duration {
141         None => linger {
142             l_onoff: 0,
143             l_linger: 0,
144         },
145         Some(dur) => linger {
146             l_onoff: 1,
147             l_linger: dur.as_secs() as _,
148         },
149     }
150 }
151 
#[cfg(test)]
mod test {
    use std::ffi::c_int;
    use std::os::fd::{AsRawFd, FromRawFd};
    use std::time::Duration;

    use crate::sys::unix::tcp::socket::{from_linger, into_linger};
    use crate::sys::unix::tcp::TcpSocket;

    /// UT for `into_linger`
    ///
    /// # Brief
    /// 1. Call `into_linger` with parameter None
    /// 2. Check that lingering is disabled and the timeout is zero
    #[test]
    fn ut_into_linger_none() {
        let linger = into_linger(None);
        assert_eq!(linger.l_onoff, 0);
        assert_eq!(linger.l_linger, 0);
    }

    /// UT for `into_linger` and `from_linger`
    ///
    /// # Brief
    /// 1. Call `into_linger` with a whole-second duration
    /// 2. Check that lingering is enabled with that many seconds
    /// 3. Convert back with `from_linger` and check the duration round-trips
    /// 4. Check that a disabled linger converts back to None
    #[test]
    fn ut_linger_round_trip() {
        let duration = Duration::from_secs(3);
        let linger = into_linger(Some(duration));
        assert_eq!(linger.l_onoff, 1);
        assert_eq!(linger.l_linger, 3);
        assert_eq!(from_linger(linger), Some(duration));
        assert_eq!(from_linger(into_linger(None)), None);
    }

    /// UT for `TcpSocket`'s `AsRawFd` and `FromRawFd` impls
    ///
    /// # Brief
    /// 1. Create a socket with `TcpSocket::new_socket`
    /// 2. Rebuild a `TcpSocket` from its raw fd with `from_raw_fd`
    /// 3. Check that `as_raw_fd` returns the same descriptor
    #[test]
    fn ut_tcp_socket_raw_fd_round_trip() {
        let socket = TcpSocket::new_socket("127.0.0.1:0".parse().unwrap()).unwrap();
        let fd: c_int = socket.as_raw_fd();
        assert!(fd >= 0);

        let rebuilt = unsafe { TcpSocket::from_raw_fd(fd) };
        assert_eq!(rebuilt.as_raw_fd(), fd);

        // `TcpSocket` has no `Drop` impl in this file; close the descriptor via a
        // std type so the test does not leak it.
        drop(unsafe { std::net::TcpStream::from_raw_fd(fd) });
    }
}
172