// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::ffi::c_void;
use std::io;
use std::marker::PhantomPinned;
use std::mem::size_of;
use std::os::windows::io::RawSocket;
use std::pin::Pin;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use crate::sys::winapi::{
    WSAGetLastError, WSAIoctl, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, OVERLAPPED,
    SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE_SELECT, SOCKET_ERROR,
    STATUS_CANCELLED, WAIT_TIMEOUT,
};
use crate::sys::windows::afd;
use crate::sys::windows::afd::{Afd, AfdGroup, AfdPollInfo};
use crate::sys::windows::events::{
    Events, ERROR_FLAGS, READABLE_FLAGS, READ_CLOSED_FLAGS, WRITABLE_FLAGS, WRITE_CLOSED_FLAGS,
};
use crate::sys::windows::io_status_block::IoStatusBlock;
use crate::sys::windows::iocp::{CompletionPort, CompletionStatus};
use crate::sys::NetInner;
use crate::{Event, Interest, Token};

/// A wrapper for the OS-specific polling system:
/// Linux: epoll
/// Windows: IOCP
/// macOS: kqueue
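///
/// A minimal usage sketch (illustrative only: `Selector` is crate-private,
/// and the `Token`/`Interest`/`Events` constructors shown are assumed from
/// this crate's public API):
///
/// ```ignore
/// let selector = Selector::new()?;
/// // Keep the returned `NetInner` alive while the socket is polled.
/// let net_inner = selector.register(raw_socket, Token(0), Interest::READABLE)?;
///
/// let mut events = Events::with_capacity(256);
/// // With a timeout, `select` returns after at most that duration even if
/// // no events arrived; with `None` it blocks until at least one does.
/// selector.select(&mut events, Some(Duration::from_millis(100)))?;
/// ```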
#[derive(Debug)]
pub struct Selector {
    inner: Arc<SelectorInner>,
}

impl Selector {
    pub(crate) fn new() -> io::Result<Selector> {
        SelectorInner::new().map(|inner| Selector {
            inner: Arc::new(inner),
        })
    }

    pub(crate) fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        self.inner.select(events, timeout)
    }

    pub(crate) fn register(
        &self,
        socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        SelectorInner::register(&self.inner, socket, token, interests)
    }

    pub(crate) fn clone_cp(&self) -> Arc<CompletionPort> {
        self.inner.completion_port.clone()
    }
}

#[derive(Debug)]
pub(crate) struct SelectorInner {
    /// IOCP Handle.
    completion_port: Arc<CompletionPort>,
    /// Registered/re-registered IO events are placed in this queue.
    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
    /// Afd Group.
    afd_group: AfdGroup,
    /// Whether the Selector is polling.
    polling: AtomicBool,
}

impl SelectorInner {
    /// Creates a new SelectorInner.
    fn new() -> io::Result<SelectorInner> {
        CompletionPort::new().map(|cp| {
            let arc_cp = Arc::new(cp);
            let cp_afd = Arc::clone(&arc_cp);

            SelectorInner {
                completion_port: arc_cp,
                update_queue: Mutex::new(VecDeque::new()),
                afd_group: AfdGroup::new(cp_afd),
                polling: AtomicBool::new(false),
            }
        })
    }

    /// Starts polling. When `timeout` is `None`, blocks until at least one
    /// event arrives; with a timeout it returns after at most that duration.
    fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        events.clear();

        match timeout {
            None => loop {
                let len = self.select_inner(events, timeout)?;
                if len != 0 {
                    return Ok(());
                }
            },
            Some(_) => {
                self.select_inner(events, timeout)?;
                Ok(())
            }
        }
    }

    fn select_inner(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize> {
        // Only one poll may run at a time.
        assert!(
            !self.polling.swap(true, Ordering::AcqRel),
            "Can't be polling twice at the same time"
        );

        unsafe { self.update_sockets_events() }?;

        let results = self
            .completion_port
            .get_results(&mut events.status, timeout);

        self.polling.store(false, Ordering::Relaxed);

        match results {
            Ok(iocp_events) => Ok(unsafe { self.feed_events(&mut events.events, iocp_events) }),
            Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT as i32) => Ok(0),
            Err(e) => Err(e),
        }
    }

    /// Processes completed operations and pushes the resulting events into
    /// `events`; sockets with regular AFD events are pushed back into the
    /// update queue.
    unsafe fn feed_events(
        &self,
        events: &mut Vec<Event>,
        iocp_events: &[CompletionStatus],
    ) -> usize {
        let mut epoll_event_count = 0;
        let mut update_queue = self.update_queue.lock().unwrap();
        for iocp_event in iocp_events.iter() {
            if iocp_event.overlapped().is_null() {
                events.push(Event::from_completion_status(iocp_event));
                epoll_event_count += 1;
                continue;
            } else if iocp_event.token() % 2 == 1 {
                // Non-AFD event, including pipe.
                let callback = (*(iocp_event.overlapped().cast::<super::Overlapped>())).callback;

                let len = events.len();
                callback(iocp_event.entry(), Some(events));
                epoll_event_count += events.len() - len;
                continue;
            }

            // General asynchronous IO event.
            let sock_state = from_overlapped(iocp_event.overlapped());
            let mut sock_guard = sock_state.lock().unwrap();
            if let Some(event) = sock_guard.sock_feed_event() {
                events.push(event);
                epoll_event_count += 1;
            }

            // Reregister the socket.
            if !sock_guard.is_delete_pending() {
                update_queue.push_back(sock_state.clone());
            }
        }

        self.afd_group.release_unused_afd();
        epoll_event_count
    }

    /// Updates each SockState in the queue; called only while `Poll::poll()`
    /// is running.
    unsafe fn update_sockets_events(&self) -> io::Result<()> {
        let mut update_queue = self.update_queue.lock().unwrap();
        for sock in update_queue.iter_mut() {
            let mut sock_internal = sock.lock().unwrap();
            if !sock_internal.delete_pending {
                sock_internal.update(sock)?;
            }
        }
        // Removes entries that were updated successfully; entries that
        // errored stay queued so the next poll retries them.
        update_queue.retain(|sock| sock.lock().unwrap().has_error());

        self.afd_group.release_unused_afd();
        Ok(())
    }

    /// No actual system call is made at registration; the AFD poll is first
    /// submitted when `Poll::poll()` runs. Returns a `NetInner` that holds
    /// the pinned socket state for the asynchronous IO structure.
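    ///
    /// A sketch of the deferred behavior (names as used in this file):
    ///
    /// ```ignore
    /// // This only queues the state; the AFD poll system call is issued by
    /// // the next `select` (or right away if a poll is in progress).
    /// let net_inner = SelectorInner::register(&inner, raw_socket, token, interests)?;
    /// ```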
    pub(crate) fn register(
        this: &Arc<Self>,
        raw_socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        // Creates Afd
        let afd = this.afd_group.acquire()?;
        let mut sock_state = SockState::new(raw_socket, afd)?;

        let flags = interests_to_afd_flags(interests);
        sock_state.set_event(flags, token.0 as u64);

        let pin_sock_state = Arc::pin(Mutex::new(sock_state));

        let net_internal = NetInner::new(this.clone(), token, interests, pin_sock_state.clone());

        // Adds SockState to VecDeque
        this.queue_state(pin_sock_state);

        if this.polling.load(Ordering::Acquire) {
            unsafe { this.update_sockets_events()? }
        }

        Ok(net_internal)
    }

    /// Re-registers: updates the event mask and puts the SockState back into
    /// the update queue.
    pub(crate) fn reregister(
        &self,
        state: Pin<Arc<Mutex<SockState>>>,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        let flags = interests_to_afd_flags(interests);
        state.lock().unwrap().set_event(flags, token.0 as u64);

        // Put back into the update queue.
        self.queue_state(state);

        if self.polling.load(Ordering::Acquire) {
            unsafe { self.update_sockets_events() }
        } else {
            Ok(())
        }
    }

    /// Appends the SockState to the back of the update queue.
    fn queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>) {
        let mut update_queue = self.update_queue.lock().unwrap();
        update_queue.push_back(sock_state);
    }
}

impl Drop for SelectorInner {
    fn drop(&mut self) {
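        // Drain completion packets still queued on the port so the
        // `Arc<Mutex<SockState>>` references they carry are released before
        // the port is closed.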
        loop {
            let complete_num: usize;
            let mut status: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

            let result = self
                .completion_port
                .get_results(&mut status, Some(Duration::from_millis(0)));

            match result {
                Ok(iocp_events) => {
                    complete_num = iocp_events.iter().len();
                    release_events(iocp_events);
                }

                Err(_) => {
                    break;
                }
            }

            if complete_num == 0 {
                // All completion statuses have been drained.
                break;
            }
        }
        self.afd_group.release_unused_afd();
    }
}

fn release_events(iocp_events: &mut [CompletionStatus]) {
    for iocp_event in iocp_events.iter() {
        if iocp_event.overlapped().is_null() {
            // User event
        } else if iocp_event.token() % 2 == 1 {
            // For pipe, dispatch the event so it can release resources
            let callback =
                unsafe { (*(iocp_event.overlapped().cast::<super::Overlapped>())).callback };

            callback(iocp_event.entry(), None);
        } else {
            // Release the Arc reference reclaimed from the overlapped pointer.
            let _ = from_overlapped(iocp_event.overlapped());
        }
    }
}

#[derive(Debug, PartialEq)]
enum SockPollStatus {
    /// Initial value.
    Idle,
    /// Set from Idle to Pending when the poll system call is submitted while
    /// updating socket events; updates happen only while polling. Only a
    /// Pending socket can be cancelled.
    Pending,
    /// Set to Cancelled after the system API is called to cancel the socket.
    Cancelled,
}
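
// State transitions, as implemented below:
//   Idle    -> Pending   : `update_while_idle` submits an AFD poll.
//   Pending -> Cancelled : `cancel` issues the cancellation system call.
//   Pending/Cancelled -> Idle : `sock_feed_event` consumes the completion
//                               packet and resets the status.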

/// Saves all information of the socket during polling.
#[derive(Debug)]
pub struct SockState {
    iosb: IoStatusBlock,
    poll_info: AfdPollInfo,
    /// The AFD file handle to which the poll request is bound.
    afd: Arc<Afd>,
    /// The base SOCKET of the request.
    base_socket: RawSocket,
    /// User token.
    user_token: u64,
    /// User interests.
    user_interests_flags: u32,
    /// When this socket is polled, save user_interests_flags in
    /// polling_interests_flags. Used for comparison during re-registration.
    polling_interests_flags: u32,
    /// Current status. While this is Pending, a poll system call is
    /// outstanding and cancellation requires a system call.
    poll_status: SockPollStatus,
    /// Marks whether deletion is pending.
    delete_pending: bool,
    /// Error during updating.
    error: Option<i32>,

    _pinned: PhantomPinned,
}

impl SockState {
    /// Creates a new SockState with RawSocket and Afd.
    fn new(socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
        Ok(SockState {
            iosb: IoStatusBlock::zeroed(),
            poll_info: AfdPollInfo::zeroed(),
            afd,
            base_socket: get_base_socket(socket)?,
            user_interests_flags: 0,
            polling_interests_flags: 0,
            user_token: 0,
            poll_status: SockPollStatus::Idle,
            delete_pending: false,

            error: None,
            _pinned: PhantomPinned,
        })
    }

    fn update_while_idle(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // Init AfdPollInfo
        self.poll_info.exclusive = 0;
        self.poll_info.number_of_handles = 1;
        self.poll_info.timeout = i64::MAX;
        self.poll_info.handles[0].handle = self.base_socket as HANDLE;
        self.poll_info.handles[0].status = 0;
        self.poll_info.handles[0].events = self.user_interests_flags | afd::POLL_LOCAL_CLOSE;

        let overlapped_ptr = into_overlapped(self_arc.clone());

        // System call to submit the poll request.
        let result = unsafe {
            self.afd
                .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
        };

        if let Err(e) = result {
            // If an error happened, there must be an OS error code.
            let code = e.raw_os_error().unwrap();
            if code != ERROR_IO_PENDING as i32 {
                // Reclaim the Arc reference so it is not leaked.
                drop(from_overlapped(overlapped_ptr.cast::<_>()));

                return if code == ERROR_INVALID_HANDLE as i32 {
                    // Socket closed; it'll be dropped.
                    self.start_drop();
                    Ok(())
                } else {
                    self.error = e.raw_os_error();
                    Err(e)
                };
            }
        };

        // The poll request was successfully submitted.
        self.poll_status = SockPollStatus::Pending;
        self.polling_interests_flags = self.user_interests_flags;
        Ok(())
    }

    fn update_while_pending(&mut self) -> io::Result<()> {
        if (self.user_interests_flags & afd::ALL_EVENTS & !self.polling_interests_flags) == 0 {
            // All the events the user is interested in are already being
            // monitored by the pending poll operation. It might spuriously
            // complete because of an event that we're no longer interested
            // in; when that happens we'll submit a new poll operation with
            // the updated event mask.
        } else {
            // A poll operation is already pending, but it's not monitoring
            // for all the events that the user is interested in. Therefore,
            // cancel the pending poll operation; when we receive its
            // completion packet, a new poll operation will be submitted with
            // the correct event mask.
            if let Err(e) = self.cancel() {
                self.error = e.raw_os_error();
                return Err(e);
            }
        }
        Ok(())
    }

    /// Updates a SockState from the queue, polling through its Afd.
    fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // delete_pending must be false.
        assert!(
            !self.delete_pending,
            "SockState updated while delete_pending is true, {:#?}",
            self
        );

        // Make sure to reset previous error before a new update
        self.error = None;

        match self.poll_status {
            // Starts poll
            SockPollStatus::Idle => self.update_while_idle(self_arc),
            SockPollStatus::Pending => self.update_while_pending(),
            // Do nothing
            SockPollStatus::Cancelled => Ok(()),
        }
    }

    /// Returns true if user_interests_flags is inconsistent with
    /// polling_interests_flags.
    fn set_event(&mut self, flags: u32, token_data: u64) -> bool {
        self.user_interests_flags = flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;
        self.user_token = token_data;

        (self.user_interests_flags & !self.polling_interests_flags) != 0
    }

    /// Processes a completed IO operation.
    fn sock_feed_event(&mut self) -> Option<Event> {
        self.poll_status = SockPollStatus::Idle;
        self.polling_interests_flags = 0;

        let mut afd_events = 0;
        // Uses the status in the IO_STATUS_BLOCK to determine the poll
        // result; reading it is unsafe because `Anonymous.Status` is a
        // union field that the kernel writes to.
        unsafe {
            if self.delete_pending {
                return None;
            } else if self.iosb.Anonymous.Status == STATUS_CANCELLED {
                // The poll request was cancelled by CancelIoEx.
            } else if self.iosb.Anonymous.Status < 0 {
                // The overlapped request itself failed in an unexpected way.
                afd_events = afd::POLL_CONNECT_FAIL;
            } else if self.poll_info.number_of_handles < 1 {
                // This poll operation succeeded but didn't report any socket
                // events.
            } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
                // The poll operation reported that the socket was closed.
                self.start_drop();
                return None;
            } else {
                afd_events = self.poll_info.handles[0].events;
            }
        }
        // Filter out events that the user didn't ask for.
        afd_events &= self.user_interests_flags;

        if afd_events == 0 {
            return None;
        }

        // Simulates edge-triggered behavior: the delivered flags are removed
        // from the interest set, so reads/writes that may hit WouldBlock are
        // surfaced to the user, who reregisters the socket to reset the
        // interests.
        self.user_interests_flags &= !afd_events;

        Some(Event {
            data: self.user_token,
            flags: afd_events,
        })
    }

    /// Starts dropping the SockState.
    pub(crate) fn start_drop(&mut self) {
        if !self.delete_pending {
            // If the status is Pending, the SockState is registered with
            // IOCP and a system call is required to cancel the socket.
            // Otherwise, setting delete_pending = true is enough.
            if let SockPollStatus::Pending = self.poll_status {
                drop(self.cancel());
            }
            self.delete_pending = true;
        }
    }

    /// Only a SockState in SockPollStatus::Pending can be cancelled; on
    /// success it is set to SockPollStatus::Cancelled.
    fn cancel(&mut self) -> io::Result<()> {
        // Checks poll_status again.
        if self.poll_status != SockPollStatus::Pending {
            unreachable!("Invalid poll status during cancel, {:#?}", self);
        }

        unsafe {
            self.afd.cancel(&mut *self.iosb)?;
        }

        // This is the only place that sets SockPollStatus::Cancelled; the
        // cancel system call has been issued for this SockState.
        self.poll_status = SockPollStatus::Cancelled;
        self.polling_interests_flags = 0;

        Ok(())
    }

    fn is_delete_pending(&self) -> bool {
        self.delete_pending
    }

    fn has_error(&self) -> bool {
        self.error.is_some()
    }
}

impl Drop for SockState {
    fn drop(&mut self) {
        self.start_drop();
    }
}

fn get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket> {
    let res = base_socket_inner(raw_socket, SIO_BASE_HANDLE);
    if let Ok(base_socket) = res {
        return Ok(base_socket);
    }

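    // `SIO_BASE_HANDLE` may fail when a layered service provider (LSP) is
    // installed on the socket. Fall back to the BSP handle IOCTLs, accepting
    // a result only when it differs from the original socket.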
    for &ioctl in &[SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE] {
        if let Ok(base_socket) = base_socket_inner(raw_socket, ioctl) {
            if base_socket != raw_socket {
                return Ok(base_socket);
            }
        }
    }

    // `res` is an error here, so it must contain an OS error code.
    Err(io::Error::from_raw_os_error(res.unwrap_err()))
}

fn base_socket_inner(raw_socket: RawSocket, control_code: u32) -> Result<RawSocket, i32> {
    let mut base_socket: RawSocket = 0;
    let mut bytes_returned: u32 = 0;
    unsafe {
        if WSAIoctl(
            raw_socket as usize,
            control_code,
            null_mut(),
            0,
            (&mut base_socket as *mut RawSocket).cast::<c_void>(),
            size_of::<RawSocket>() as u32,
            &mut bytes_returned,
            null_mut(),
            None,
        ) != SOCKET_ERROR
        {
            Ok(base_socket)
        } else {
            // Returns the error status for the last Windows Sockets operation that failed.
            Err(WSAGetLastError())
        }
    }
}

/// Converts Interest into AFD event flags.
fn interests_to_afd_flags(interests: Interest) -> u32 {
    let mut flags = 0;

    // Sets readable flags.
    if interests.is_readable() {
        flags |= READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS;
    }

    // Sets writable flags.
    if interests.is_writable() {
        flags |= WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS;
    }

    flags
}

/// Converts a pinned `SockState` into a raw pointer for use as an
/// `OVERLAPPED` context, transferring ownership of one Arc reference.
/// Reclaim it with `from_overlapped` to avoid a leak.
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> *mut c_void {
    let overlapped_ptr: *const Mutex<SockState> =
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
    overlapped_ptr as *mut _
}

/// Converts a raw overlapped pointer back into the pinned `SockState`,
/// reclaiming the Arc reference produced by `into_overlapped`.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
    let sock_ptr: *const Mutex<SockState> = ptr as *const _;
    unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}
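
// A minimal sanity check for the flag conversion above. This is a sketch:
// it assumes `Interest::READABLE` exists in this crate with mio-like
// semantics; adjust the constructor if the real API differs.
#[cfg(test)]
mod tests {
    use super::interests_to_afd_flags;
    use crate::sys::windows::events::{ERROR_FLAGS, READABLE_FLAGS, READ_CLOSED_FLAGS};
    use crate::Interest;

    #[test]
    fn readable_interest_maps_to_readable_flags() {
        let flags = interests_to_afd_flags(Interest::READABLE);
        assert_eq!(flags, READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS);
    }
}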