1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::future::Future;
15 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
16 use std::pin::Pin;
17 use std::task::{Context, Poll};
18 use std::{io, mem, option, vec};
19 
20 use crate::spawn_blocking;
21 use crate::task::JoinHandle;
22 
each_addr<A: ToSocketAddrs, F, T>(addr: A, mut f: F) -> io::Result<T> where F: FnMut(SocketAddr) -> io::Result<T>,23 pub(crate) async fn each_addr<A: ToSocketAddrs, F, T>(addr: A, mut f: F) -> io::Result<T>
24 where
25     F: FnMut(SocketAddr) -> io::Result<T>,
26 {
27     let addrs = addr.to_socket_addrs().await?;
28 
29     let mut last_e = None;
30 
31     for addr in addrs {
32         match f(addr) {
33             Ok(res) => return Ok(res),
34             Err(e) => last_e = Some(e),
35         }
36     }
37 
38     Err(last_e.unwrap_or(io::Error::new(
39         io::ErrorKind::InvalidInput,
40         "addr could not resolve to any address",
41     )))
42 }
43 
/// Converts a value into an iterator of resolved [`SocketAddr`]s.
///
/// Unlike `std::net::ToSocketAddrs`, the conversion is asynchronous: the
/// method returns a [`State`] future that must be awaited to get the result.
/// The implementations in this module that need a host-name lookup (`str`,
/// `String`, `(&str, u16)`) run it via `spawn_blocking`; all others resolve
/// immediately.
pub trait ToSocketAddrs {
    /// Iterator over the resolved socket addresses.
    type Iter: Iterator<Item = SocketAddr>;

    /// Starts converting `self` into resolved `SocketAddr`s.
    ///
    /// Awaiting the returned [`State`] yields `io::Result<Self::Iter>`.
    fn to_socket_addrs(&self) -> State<Self::Iter>;
}
52 
/// Progress of an address conversion started by
/// [`ToSocketAddrs::to_socket_addrs`]; awaiting it yields the result.
///
/// Inputs that are already addresses, or IP literals, produce `Ready`. Inputs
/// that need a host-name lookup (`str` and `(&str, u16)` that are not IP
/// literals) produce `Block`.
pub enum State<I> {
    /// A lookup running on a task created with `spawn_blocking`.
    Block(JoinHandle<io::Result<I>>),
    /// A result that is available immediately, without any lookup.
    Ready(io::Result<I>),
    /// The result has already been returned by `poll`; polling again panics.
    Done,
}
59 
60 impl<I: Iterator<Item = SocketAddr>> Future for State<I> {
61     type Output = io::Result<I>;
62 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>63     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64         let this = self.get_mut();
65 
66         match mem::replace(this, State::Done) {
67             State::Block(mut task) => {
68                 let poll = Pin::new(&mut task).poll(cx)?;
69                 if poll.is_pending() {
70                     *this = State::Block(task);
71                 }
72                 poll
73             }
74             State::Ready(res) => Poll::Ready(res),
75             State::Done => unreachable!("cannot poll a completed future"),
76         }
77     }
78 }
79 
// `State` is never structurally pinned. `poll` obtains `&mut Self` through
// `Pin::get_mut`, which requires `Unpin`, and moves the variants out with
// `mem::replace`. Implementing `Unpin` for every `I` keeps that working even
// when the iterator type itself is not `Unpin`.
impl<I> Unpin for State<I> {}
81 
82 impl ToSocketAddrs for SocketAddr {
83     type Iter = option::IntoIter<SocketAddr>;
84 
to_socket_addrs(&self) -> State<Self::Iter>85     fn to_socket_addrs(&self) -> State<Self::Iter> {
86         State::Ready(Ok(Some(*self).into_iter()))
87     }
88 }
89 
90 impl ToSocketAddrs for SocketAddrV4 {
91     type Iter = option::IntoIter<SocketAddr>;
92 
to_socket_addrs(&self) -> State<Self::Iter>93     fn to_socket_addrs(&self) -> State<Self::Iter> {
94         SocketAddr::V4(*self).to_socket_addrs()
95     }
96 }
97 
98 impl ToSocketAddrs for SocketAddrV6 {
99     type Iter = option::IntoIter<SocketAddr>;
100 
to_socket_addrs(&self) -> State<Self::Iter>101     fn to_socket_addrs(&self) -> State<Self::Iter> {
102         SocketAddr::V6(*self).to_socket_addrs()
103     }
104 }
105 
106 impl ToSocketAddrs for (IpAddr, u16) {
107     type Iter = option::IntoIter<SocketAddr>;
108 
to_socket_addrs(&self) -> State<Self::Iter>109     fn to_socket_addrs(&self) -> State<Self::Iter> {
110         let (ip, port) = *self;
111         match ip {
112             IpAddr::V4(ip_type) => (ip_type, port).to_socket_addrs(),
113             IpAddr::V6(ip_type) => (ip_type, port).to_socket_addrs(),
114         }
115     }
116 }
117 
118 impl ToSocketAddrs for (Ipv4Addr, u16) {
119     type Iter = option::IntoIter<SocketAddr>;
120 
to_socket_addrs(&self) -> State<Self::Iter>121     fn to_socket_addrs(&self) -> State<Self::Iter> {
122         let (ip, port) = *self;
123         SocketAddrV4::new(ip, port).to_socket_addrs()
124     }
125 }
126 
127 impl ToSocketAddrs for (Ipv6Addr, u16) {
128     type Iter = option::IntoIter<SocketAddr>;
129 
to_socket_addrs(&self) -> State<Self::Iter>130     fn to_socket_addrs(&self) -> State<Self::Iter> {
131         let (ip, port) = *self;
132         SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
133     }
134 }
135 
136 impl ToSocketAddrs for (&str, u16) {
137     type Iter = vec::IntoIter<SocketAddr>;
138 
to_socket_addrs(&self) -> State<Self::Iter>139     fn to_socket_addrs(&self) -> State<Self::Iter> {
140         let (host, port) = *self;
141 
142         if let Ok(addr) = host.parse::<Ipv4Addr>() {
143             let addr = SocketAddrV4::new(addr, port);
144             return State::Ready(Ok(vec![SocketAddr::V4(addr)].into_iter()));
145         }
146 
147         if let Ok(addr) = host.parse::<Ipv6Addr>() {
148             let addr = SocketAddrV6::new(addr, port, 0, 0);
149             return State::Ready(Ok(vec![SocketAddr::V6(addr)].into_iter()));
150         }
151 
152         let host = host.to_string();
153         let task = spawn_blocking(move || {
154             let addr = (host.as_str(), port);
155             std::net::ToSocketAddrs::to_socket_addrs(&addr)
156         });
157         State::Block(task)
158     }
159 }
160 
161 impl ToSocketAddrs for str {
162     type Iter = vec::IntoIter<SocketAddr>;
163 
to_socket_addrs(&self) -> State<Self::Iter>164     fn to_socket_addrs(&self) -> State<Self::Iter> {
165         if let Ok(addr) = self.parse() {
166             return State::Ready(Ok(vec![addr].into_iter()));
167         }
168 
169         let addr = self.to_string();
170         let task = spawn_blocking(move || {
171             let addr = addr.as_str();
172             std::net::ToSocketAddrs::to_socket_addrs(addr)
173         });
174         State::Block(task)
175     }
176 }
177 
178 impl<'a> ToSocketAddrs for &'a [SocketAddr] {
179     type Iter = std::iter::Cloned<std::slice::Iter<'a, SocketAddr>>;
180 
to_socket_addrs(&self) -> State<Self::Iter>181     fn to_socket_addrs(&self) -> State<Self::Iter> {
182         State::Ready(Ok(self.iter().cloned()))
183     }
184 }
185 
186 impl ToSocketAddrs for String {
187     type Iter = vec::IntoIter<SocketAddr>;
188 
to_socket_addrs(&self) -> State<Self::Iter>189     fn to_socket_addrs(&self) -> State<Self::Iter> {
190         (**self).to_socket_addrs()
191     }
192 }
193 
194 impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {
195     type Iter = T::Iter;
196 
to_socket_addrs(&self) -> State<Self::Iter>197     fn to_socket_addrs(&self) -> State<Self::Iter> {
198         (**self).to_socket_addrs()
199     }
200 }
201 
#[cfg(test)]
mod test {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use crate::net::ToSocketAddrs;

    /// UT test cases for `ToSocketAddrs` str.
    ///
    /// # Brief
    /// 1. Create an address with str.
    /// 2. Call `to_socket_addrs()` to convert str to `SocketAddr`.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_str() {
        let addr_str = "127.0.0.1:8080";
        crate::block_on(async {
            let mut addrs_iter = addr_str.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });
    }

    /// UT test cases for `ToSocketAddrs` blocking.
    ///
    /// # Brief
    /// 1. Create an address with "localhost".
    /// 2. Call `to_socket_addrs()` to resolve the host name to `SocketAddr`s.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_blocking() {
        let addr_str = "localhost:8080";
        crate::block_on(async {
            let addrs_vec = addr_str
                .to_socket_addrs()
                .await
                .unwrap()
                .collect::<Vec<SocketAddr>>();

            let expected_addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
            let expected_addr2 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
            assert!(addrs_vec.contains(&expected_addr1) || addrs_vec.contains(&expected_addr2));
        });
    }

    /// UT test cases for `ToSocketAddrs` (&str, u16).
    ///
    /// # Brief
    /// 1. Create an address with (&str, u16).
    /// 2. Call `to_socket_addrs()` to convert (&str, u16) to `SocketAddr`.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_str_u16() {
        let addr_str = ("127.0.0.1", 8080);
        crate::block_on(async {
            let mut addrs_iter = addr_str.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });

        let addr_str = ("localhost", 8080);
        crate::block_on(async {
            let addrs_vec = addr_str
                .to_socket_addrs()
                .await
                .unwrap()
                .collect::<Vec<SocketAddr>>();

            let expected_addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
            let expected_addr2 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
            assert!(addrs_vec.contains(&expected_addr1) || addrs_vec.contains(&expected_addr2));
        });

        let addr_str = ("::1", 8080);
        crate::block_on(async {
            let mut addrs_iter = addr_str.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });
    }

    /// UT test cases for `ToSocketAddrs` (ipaddr, u16).
    ///
    /// # Brief
    /// 1. Create an address with (ipaddr, u16).
    /// 2. Call `to_socket_addrs()` to convert (ipaddr, u16) to `SocketAddr`.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_ipaddr_u16() {
        let addr = (IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        crate::block_on(async {
            let mut addrs_iter = addr.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });

        let addr = (IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080);
        crate::block_on(async {
            let mut addrs_iter = addr.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });
    }

    /// UT test cases for `ToSocketAddrs` (ipv4addr, u16).
    ///
    /// # Brief
    /// 1. Create an address with (ipv4addr, u16).
    /// 2. Call `to_socket_addrs()` to convert (ipv4addr, u16) to `SocketAddr`.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_ipv4addr_u16() {
        let addr = (Ipv4Addr::new(127, 0, 0, 1), 8080);
        crate::block_on(async {
            let mut addrs_iter = addr.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });
    }

    /// UT test cases for `ToSocketAddrs` (ipv6addr, u16).
    ///
    /// # Brief
    /// 1. Create an address with (ipv6addr, u16).
    /// 2. Call `to_socket_addrs()` to convert (ipv6addr, u16) to `SocketAddr`.
    /// 3. Check if the test results are correct.
    #[test]
    fn ut_to_socket_addrs_ipv6addr_u16() {
        let addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 8080);
        crate::block_on(async {
            let mut addrs_iter = addr.to_socket_addrs().await.unwrap();

            let expected_addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());
        });
    }

    /// UT test cases for `ToSocketAddrs` on already-resolved inputs.
    ///
    /// # Brief
    /// 1. Create a `SocketAddr`, a `String` and a `&[SocketAddr]`.
    /// 2. Call `to_socket_addrs()` on each of them.
    /// 3. Check that the original addresses come back unchanged and in order.
    #[test]
    fn ut_to_socket_addrs_resolved() {
        let expected_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        crate::block_on(async {
            let mut addrs_iter = expected_addr.to_socket_addrs().await.unwrap();
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());

            let addr_string = String::from("127.0.0.1:8080");
            let mut addrs_iter = addr_string.to_socket_addrs().await.unwrap();
            assert_eq!(Some(expected_addr), addrs_iter.next());
            assert!(addrs_iter.next().is_none());

            let list = [
                expected_addr,
                SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080)),
            ];
            let slice: &[SocketAddr] = &list;
            let addrs_vec = slice
                .to_socket_addrs()
                .await
                .unwrap()
                .collect::<Vec<SocketAddr>>();
            assert_eq!(addrs_vec, list);
        });
    }
}
352