1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::mem::MaybeUninit;
15 use std::net::SocketAddr;
16 use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
17 use std::time::Duration;
18 use std::{io, mem};
19 
20 use libc::{c_int, getsockopt};
21 
22 use crate::sys::winapi::{
23     closesocket, ioctlsocket, setsockopt, socket, ADDRESS_FAMILY, AF_INET, AF_INET6, FIONBIO,
24     INVALID_SOCKET, LINGER, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_LINGER,
25 };
26 use crate::sys::windows::net::init;
27 use crate::sys::windows::socket_addr::socket_addr_trans;
28 
/// Thin wrapper around a raw Winsock `SOCKET` handle, used to create a TCP
/// socket and then bind/listen on it or connect it.
///
/// NOTE(review): no `Drop` impl in this file closes `socket`; ownership of the
/// handle appears to be handed off to the caller (e.g. via `as_raw_socket`)
/// rather than released when this value is dropped — confirm before relying
/// on drop to free it.
pub(crate) struct TcpSocket {
    /// The underlying Winsock socket handle.
    socket: SOCKET,
}
32 
33 impl TcpSocket {
34     /// Gets new socket
new_socket(addr: SocketAddr) -> io::Result<TcpSocket>35     pub(crate) fn new_socket(addr: SocketAddr) -> io::Result<TcpSocket> {
36         if addr.is_ipv4() {
37             Self::create_socket(AF_INET, SOCK_STREAM)
38         } else {
39             Self::create_socket(AF_INET6, SOCK_STREAM)
40         }
41     }
42 
create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket>43     fn create_socket(domain: ADDRESS_FAMILY, socket_type: u16) -> io::Result<TcpSocket> {
44         init();
45 
46         let socket = socket_syscall!(
47             socket(domain as i32, socket_type as i32, 0),
48             PartialEq::eq,
49             INVALID_SOCKET
50         )?;
51 
52         match socket_syscall!(ioctlsocket(socket, FIONBIO, &mut 1), PartialEq::ne, 0) {
53             Err(err) => {
54                 let _ = unsafe { closesocket(socket) };
55                 Err(err)
56             }
57             Ok(_) => Ok(TcpSocket {
58                 socket: socket as SOCKET,
59             }),
60         }
61     }
62 
63     /// System call to bind Socket.
bind(&self, addr: SocketAddr) -> io::Result<()>64     pub(crate) fn bind(&self, addr: SocketAddr) -> io::Result<()> {
65         use crate::sys::winapi::bind;
66 
67         let (raw_addr, raw_addr_length) = socket_addr_trans(&addr);
68         socket_syscall!(
69             bind(self.socket as _, raw_addr.as_ptr(), raw_addr_length),
70             PartialEq::eq,
71             SOCKET_ERROR
72         )?;
73         Ok(())
74     }
75 
76     /// System call to listen.
listen(self, backlog: u32) -> io::Result<()>77     pub(crate) fn listen(self, backlog: u32) -> io::Result<()> {
78         use std::convert::TryInto;
79 
80         use crate::sys::winapi::listen;
81 
82         let backlog = backlog.try_into().unwrap_or(i32::MAX);
83         socket_syscall!(
84             listen(self.socket as _, backlog),
85             PartialEq::eq,
86             SOCKET_ERROR
87         )?;
88         Ok(())
89     }
90 
91     /// System call to connect.
connect(self, addr: SocketAddr) -> io::Result<()>92     pub(crate) fn connect(self, addr: SocketAddr) -> io::Result<()> {
93         use crate::sys::winapi::connect;
94 
95         let (socket_addr, socket_addr_length) = socket_addr_trans(&addr);
96         let res = socket_syscall!(
97             connect(self.socket as _, socket_addr.as_ptr(), socket_addr_length),
98             PartialEq::eq,
99             SOCKET_ERROR
100         );
101 
102         match res {
103             Err(e) if e.kind() != io::ErrorKind::WouldBlock => Err(e),
104             _ => Ok(()),
105         }
106     }
107 }
108 
109 impl AsRawSocket for TcpSocket {
as_raw_socket(&self) -> RawSocket110     fn as_raw_socket(&self) -> RawSocket {
111         self.socket as RawSocket
112     }
113 }
114 
115 impl FromRawSocket for TcpSocket {
from_raw_socket(sock: RawSocket) -> Self116     unsafe fn from_raw_socket(sock: RawSocket) -> Self {
117         TcpSocket {
118             socket: sock as SOCKET,
119         }
120     }
121 }
122 
get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>>123 pub(crate) fn get_sock_linger(socket: RawSocket) -> io::Result<Option<Duration>> {
124     let mut optval: MaybeUninit<LINGER> = MaybeUninit::uninit();
125     let mut optlen = mem::size_of::<LINGER>() as c_int;
126 
127     socket_syscall!(
128         getsockopt(
129             socket as SOCKET,
130             SOL_SOCKET as c_int,
131             SO_LINGER as c_int,
132             optval.as_mut_ptr().cast(),
133             &mut optlen,
134         ),
135         PartialEq::eq,
136         SOCKET_ERROR
137     )
138     .map(|_| {
139         let linger = unsafe { optval.assume_init() };
140         from_linger(linger)
141     })
142 }
143 
set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()>144 pub(crate) fn set_sock_linger(socket: RawSocket, linger: Option<Duration>) -> io::Result<()> {
145     let optval = into_linger(linger);
146     socket_syscall!(
147         setsockopt(
148             socket as SOCKET,
149             SOL_SOCKET as c_int,
150             SO_LINGER as c_int,
151             (&optval as *const LINGER).cast(),
152             mem::size_of::<LINGER>() as c_int,
153         ),
154         PartialEq::eq,
155         SOCKET_ERROR
156     )
157     .map(|_| ())
158 }
159 
from_linger(linger: LINGER) -> Option<Duration>160 fn from_linger(linger: LINGER) -> Option<Duration> {
161     if linger.l_onoff == 0 {
162         None
163     } else {
164         Some(Duration::from_secs(linger.l_linger as u64))
165     }
166 }
167 
into_linger(linger: Option<Duration>) -> LINGER168 fn into_linger(linger: Option<Duration>) -> LINGER {
169     match linger {
170         None => LINGER {
171             l_onoff: 0,
172             l_linger: 0,
173         },
174         Some(dur) => LINGER {
175             l_onoff: 1,
176             l_linger: dur.as_secs() as _,
177         },
178     }
179 }
180