// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::ffi::c_void;
use std::io;
use std::marker::PhantomPinned;
use std::mem::size_of;
use std::os::windows::io::RawSocket;
use std::pin::Pin;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use crate::sys::winapi::{
    WSAGetLastError, WSAIoctl, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, OVERLAPPED,
    SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE_SELECT, SOCKET_ERROR,
    STATUS_CANCELLED, WAIT_TIMEOUT,
};
use crate::sys::windows::afd;
use crate::sys::windows::afd::{Afd, AfdGroup, AfdPollInfo};
use crate::sys::windows::events::{
    Events, ERROR_FLAGS, READABLE_FLAGS, READ_CLOSED_FLAGS, WRITABLE_FLAGS, WRITE_CLOSED_FLAGS,
};
use crate::sys::windows::io_status_block::IoStatusBlock;
use crate::sys::windows::iocp::{CompletionPort, CompletionStatus};
use crate::sys::NetInner;
use crate::{Event, Interest, Token};

/// A wrapper over the OS-specific polling system:
/// Linux: epoll
/// Windows: IOCP
/// macOS: kqueue
#[derive(Debug)]
pub struct Selector {
    inner: Arc<SelectorInner>,
}

impl Selector {
    pub(crate) fn new() -> io::Result<Selector> {
        SelectorInner::new().map(|inner| Selector {
            inner: Arc::new(inner),
        })
    }

    pub(crate) fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        self.inner.select(events, timeout)
    }

    pub(crate) fn register(
        &self,
        socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        SelectorInner::register(&self.inner, socket, token, interests)
    }

    pub(crate) fn clone_cp(&self) -> Arc<CompletionPort> {
        self.inner.completion_port.clone()
    }
}
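
// A minimal usage sketch (illustrative addition, not from the original
// source; `Events::with_capacity` and the `Interest` constants are assumed
// from the crate's public API). Callers normally reach this selector
// through the crate's `Poll` type rather than directly:
//
//     let selector = Selector::new()?;
//     let _net = selector.register(raw_socket, Token(0), Interest::READABLE)?;
//     let mut events = Events::with_capacity(256);
//     // With a `None` timeout this blocks until at least one event arrives.
//     selector.select(&mut events, None)?;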

#[derive(Debug)]
pub(crate) struct SelectorInner {
    /// IOCP handle.
    completion_port: Arc<CompletionPort>,
    /// Registered/re-registered IO events are placed in this queue.
    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
    /// Afd group.
    afd_group: AfdGroup,
    /// Whether the Selector is currently polling.
    polling: AtomicBool,
}

impl SelectorInner {
    /// Creates a new SelectorInner.
    fn new() -> io::Result<SelectorInner> {
        CompletionPort::new().map(|cp| {
            let arc_cp = Arc::new(cp);
            let cp_afd = Arc::clone(&arc_cp);

            SelectorInner {
                completion_port: arc_cp,
                update_queue: Mutex::new(VecDeque::new()),
                afd_group: AfdGroup::new(cp_afd),
                polling: AtomicBool::new(false),
            }
        })
    }

    /// Starts polling. With no timeout, loops until at least one event
    /// arrives; with a timeout, returns after a single poll even if no
    /// event occurred.
    fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        events.clear();

        match timeout {
            None => loop {
                let len = self.select_inner(events, timeout)?;
                if len != 0 {
                    return Ok(());
                }
            },
            Some(_) => {
                self.select_inner(events, timeout)?;
                Ok(())
            }
        }
    }

    fn select_inner(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<usize> {
        // Only one thread can poll at a time.
        assert!(
            !self.polling.swap(true, Ordering::AcqRel),
            "Can't be polling twice at the same time"
        );

        unsafe { self.update_sockets_events() }?;

        let results = self
            .completion_port
            .get_results(&mut events.status, timeout);

        self.polling.store(false, Ordering::Relaxed);

        match results {
            Ok(iocp_events) => Ok(unsafe { self.feed_events(&mut events.events, iocp_events) }),
            Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT as i32) => Ok(0),
            Err(e) => Err(e),
        }
    }
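
    // Illustrative note (not in the original source): `polling` enforces a
    // single poller. A second `select` that overlaps an in-flight one trips
    // the `swap(true)` assertion above:
    //
    //     // thread A
    //     selector.select(&mut events_a, None)?;
    //     // thread B, while A is still inside select_inner:
    //     selector.select(&mut events_b, None)?; // panics: polling twice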

    /// Processes completed operations and puts them into user events;
    /// regular AFD events are put back into the update queue.
    unsafe fn feed_events(
        &self,
        events: &mut Vec<Event>,
        iocp_events: &[CompletionStatus],
    ) -> usize {
        let mut epoll_event_count = 0;
        let mut update_queue = self.update_queue.lock().unwrap();
        for iocp_event in iocp_events.iter() {
            if iocp_event.overlapped().is_null() {
                // User event posted directly to the port (no overlapped
                // attached).
                events.push(Event::from_completion_status(iocp_event));
                epoll_event_count += 1;
                continue;
            } else if iocp_event.token() % 2 == 1 {
                // Non-AFD event, including pipe.
                let callback = (*(iocp_event.overlapped().cast::<super::Overlapped>())).callback;

                let len = events.len();
                callback(iocp_event.entry(), Some(events));
                epoll_event_count += events.len() - len;
                continue;
            }

            // General asynchronous IO event.
            let sock_state = from_overlapped(iocp_event.overlapped());
            let mut sock_guard = sock_state.lock().unwrap();
            if let Some(event) = sock_guard.sock_feed_event() {
                events.push(event);
                epoll_event_count += 1;
            }

            // Reregister the socket.
            if !sock_guard.is_delete_pending() {
                update_queue.push_back(sock_state.clone());
            }
        }

        self.afd_group.release_unused_afd();
        epoll_event_count
    }
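
    // Illustrative note (inferred from this file, not original docs): AFD
    // socket completions carry an even completion key, so an odd token marks
    // a non-AFD handle such as a pipe. A hypothetical pipe-side post would
    // look roughly like:
    //
    //     // `odd_token` and `overlapped_ptr` are placeholders; the packet
    //     // lands in the `token() % 2 == 1` branch above.
    //     completion_port.post(CompletionStatus::new(0, odd_token, overlapped_ptr))?;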

    /// Updates each SockState in the update queue. Called only when
    /// Poll::poll() is invoked externally.
    unsafe fn update_sockets_events(&self) -> io::Result<()> {
        let mut update_queue = self.update_queue.lock().unwrap();
        for sock in update_queue.iter_mut() {
            let mut sock_internal = sock.lock().unwrap();
            if !sock_internal.delete_pending {
                sock_internal.update(sock)?;
            }
        }
        // Removes sockets that were updated successfully; those with an
        // error stay in the queue.
        update_queue.retain(|sock| sock.lock().unwrap().has_error());

        self.afd_group.release_unused_afd();
        Ok(())
    }

    /// No actual system call is made during register; it only happens once
    /// Poll::poll() is called. Returns a NetInner to be stored in the
    /// asynchronous IO structure.
    pub(crate) fn register(
        this: &Arc<Self>,
        raw_socket: RawSocket,
        token: Token,
        interests: Interest,
    ) -> io::Result<NetInner> {
        // Creates the Afd.
        let afd = this.afd_group.acquire()?;
        let mut sock_state = SockState::new(raw_socket, afd)?;

        let flags = interests_to_afd_flags(interests);
        sock_state.set_event(flags, token.0 as u64);

        let pin_sock_state = Arc::pin(Mutex::new(sock_state));

        let net_internal = NetInner::new(this.clone(), token, interests, pin_sock_state.clone());

        // Adds the SockState to the update queue.
        this.queue_state(pin_sock_state);

        if this.polling.load(Ordering::Acquire) {
            unsafe { this.update_sockets_events()? }
        }

        Ok(net_internal)
    }
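
    // Illustrative sketch (not in the original source): registration is
    // lazy. Unless a poll is already in flight, the AFD syscall is deferred
    // to the next `select`:
    //
    //     let inner = Arc::new(SelectorInner::new()?);
    //     let _net = SelectorInner::register(&inner, raw_socket, Token(2), Interest::READABLE)?;
    //     // Nothing has been submitted to AFD yet; the socket is only
    //     // queued. The next `select` drains the queue through
    //     // `update_sockets_events`.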

    /// Re-registers the socket, putting its SockState back into the queue.
    pub(crate) fn reregister(
        &self,
        state: Pin<Arc<Mutex<SockState>>>,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        let flags = interests_to_afd_flags(interests);
        state.lock().unwrap().set_event(flags, token.0 as u64);

        // Puts the state back into the update queue.
        self.queue_state(state);

        if self.polling.load(Ordering::Acquire) {
            unsafe { self.update_sockets_events() }
        } else {
            Ok(())
        }
    }

    /// Appends the SockState to the end of the update queue.
    fn queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>) {
        let mut update_queue = self.update_queue.lock().unwrap();
        update_queue.push_back(sock_state);
    }
}

impl Drop for SelectorInner {
    fn drop(&mut self) {
        loop {
            let complete_num: usize;
            let mut status: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

            let result = self
                .completion_port
                .get_results(&mut status, Some(Duration::from_millis(0)));

            match result {
                Ok(iocp_events) => {
                    complete_num = iocp_events.iter().len();
                    release_events(iocp_events);
                }

                Err(_) => {
                    break;
                }
            }

            if complete_num == 0 {
                // All completion statuses have been drained; stop looping.
                break;
            }
        }
        self.afd_group.release_unused_afd();
    }
}

fn release_events(iocp_events: &mut [CompletionStatus]) {
    for iocp_event in iocp_events.iter() {
        if iocp_event.overlapped().is_null() {
            // User event, nothing to release.
        } else if iocp_event.token() % 2 == 1 {
            // For a pipe, dispatch the event so it can release its
            // resources.
            let callback =
                unsafe { (*(iocp_event.overlapped().cast::<super::Overlapped>())).callback };

            callback(iocp_event.entry(), None);
        } else {
            // Releases the Arc reference that was leaked into the kernel.
            let _ = from_overlapped(iocp_event.overlapped());
        }
    }
}

#[derive(Debug, PartialEq)]
enum SockPollStatus {
    /// Initial value.
    Idle,
    /// Set from Idle to Pending when the system call is made while updating
    /// socket events, which happens only during polling. Only a Pending
    /// socket can be cancelled.
    Pending,
    /// Set to Cancelled after the system API has been called to cancel the
    /// socket.
    Cancelled,
}
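
// Illustrative state diagram (derived from the code below, not original
// documentation):
//
//     Idle --update_while_idle()---------------> Pending
//     Pending --completion / sock_feed_event()--> Idle
//     Pending --cancel()------------------------> Cancelled
//     Cancelled --completion packet arrives-----> Idle
//
// The Cancelled -> Idle edge also runs through `sock_feed_event`, which
// unconditionally resets `poll_status` to Idle when a packet is dequeued.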

/// Saves all information about the socket during polling.
#[derive(Debug)]
pub struct SockState {
    iosb: IoStatusBlock,
    poll_info: AfdPollInfo,
    /// The AFD file handle to which the poll request is bound.
    afd: Arc<Afd>,
    /// Base SOCKET of the request.
    base_socket: RawSocket,
    /// User token.
    user_token: u64,
    /// User interests.
    user_interests_flags: u32,
    /// When this socket is polled, user_interests_flags is saved in
    /// polling_interests_flags. Used for comparison during re-registration.
    polling_interests_flags: u32,
    /// Current status. When this is Pending, system API calls must be made.
    poll_status: SockPollStatus,
    /// Marks whether this state is scheduled for deletion.
    delete_pending: bool,
    /// Error that occurred during the update.
    error: Option<i32>,

    _pinned: PhantomPinned,
}

impl SockState {
    /// Creates a new SockState with a RawSocket and an Afd.
    fn new(socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
        Ok(SockState {
            iosb: IoStatusBlock::zeroed(),
            poll_info: AfdPollInfo::zeroed(),
            afd,
            base_socket: get_base_socket(socket)?,
            user_interests_flags: 0,
            polling_interests_flags: 0,
            user_token: 0,
            poll_status: SockPollStatus::Idle,
            delete_pending: false,

            error: None,
            _pinned: PhantomPinned,
        })
    }

    fn update_while_idle(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // Initializes the AfdPollInfo.
        self.poll_info.exclusive = 0;
        self.poll_info.number_of_handles = 1;
        self.poll_info.timeout = i64::MAX;
        self.poll_info.handles[0].handle = self.base_socket as HANDLE;
        self.poll_info.handles[0].status = 0;
        self.poll_info.handles[0].events = self.user_interests_flags | afd::POLL_LOCAL_CLOSE;

        let overlapped_ptr = into_overlapped(self_arc.clone());

        // System call to submit the poll request.
        let result = unsafe {
            self.afd
                .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
        };

        if let Err(e) = result {
            // If an error happened, there must be an OS error.
            let code = e.raw_os_error().unwrap();
            if code != ERROR_IO_PENDING as i32 {
                // The request was not submitted; reclaim the leaked Arc.
                drop(from_overlapped(overlapped_ptr.cast::<_>()));

                return if code == ERROR_INVALID_HANDLE as i32 {
                    // Socket closed; it'll be dropped.
                    self.start_drop();
                    Ok(())
                } else {
                    self.error = e.raw_os_error();
                    Err(e)
                };
            }
        };

        // The poll request was successfully submitted.
        self.poll_status = SockPollStatus::Pending;
        self.polling_interests_flags = self.user_interests_flags;
        Ok(())
    }

    fn update_while_pending(&mut self) -> io::Result<()> {
        if (self.user_interests_flags & afd::ALL_EVENTS & !self.polling_interests_flags) == 0 {
            // All the events the user is interested in are already being
            // monitored by the pending poll operation. It might spuriously
            // complete because of an event we're no longer interested in;
            // when that happens we'll submit a new poll operation with the
            // updated event mask.
        } else {
            // A poll operation is already pending, but it's not monitoring
            // all the events the user is interested in. Therefore, cancel
            // the pending poll operation; when its completion packet
            // arrives, a new poll operation will be submitted with the
            // correct event mask.
            if let Err(e) = self.cancel() {
                self.error = e.raw_os_error();
                return Err(e);
            }
        }
        Ok(())
    }

    /// Updates the SockState in the queue; polls on its Afd.
    fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        // delete_pending must be false here.
        assert!(
            !self.delete_pending,
            "SockState update when delete_pending is true, {:#?}",
            self
        );

        // Make sure to reset any previous error before a new update.
        self.error = None;

        match self.poll_status {
            // Starts the poll.
            SockPollStatus::Idle => self.update_while_idle(self_arc),
            SockPollStatus::Pending => self.update_while_pending(),
            // Do nothing.
            SockPollStatus::Cancelled => Ok(()),
        }
    }

    /// Returns true if user_interests_flags is inconsistent with
    /// polling_interests_flags.
    fn set_event(&mut self, flags: u32, token_data: u64) -> bool {
        self.user_interests_flags = flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;
        self.user_token = token_data;

        (self.user_interests_flags & !self.polling_interests_flags) != 0
    }
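
    // Worked example (illustrative, with made-up flag values): if the
    // pending poll was armed with polling_interests_flags = 0b0011 and the
    // new user flags resolve to 0b0111, then 0b0111 & !0b0011 == 0b0100,
    // which is non-zero, so `set_event` returns true and the caller knows
    // the pending operation no longer covers every requested event.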

    /// Processes a completed IO operation.
    fn sock_feed_event(&mut self) -> Option<Event> {
        self.poll_status = SockPollStatus::Idle;
        self.polling_interests_flags = 0;

        let mut afd_events = 0;
        // Uses the status in the IO_STATUS_BLOCK to determine the socket
        // poll status. It is unsafe to read through the IO_STATUS_BLOCK
        // pointer.
        unsafe {
            if self.delete_pending {
                return None;
            } else if self.iosb.Anonymous.Status == STATUS_CANCELLED {
                // The poll request was cancelled by CancelIoEx.
            } else if self.iosb.Anonymous.Status < 0 {
                // The overlapped request itself failed in an unexpected way.
                afd_events = afd::POLL_CONNECT_FAIL;
            } else if self.poll_info.number_of_handles < 1 {
                // This poll operation succeeded but didn't report any socket
                // events.
            } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
                // The poll operation reported that the socket was closed.
                self.start_drop();
                return None;
            } else {
                afd_events = self.poll_info.handles[0].events;
            }
        }
        // Filter out events that the user didn't ask for.
        afd_events &= self.user_interests_flags;

        if afd_events == 0 {
            return None;
        }

        // Simulates edge-triggered behavior to match the API usage:
        // intercept all reads/writes from the user that may cause
        // WouldBlock, and reregister the socket to reset the interests.
        self.user_interests_flags &= !afd_events;

        Some(Event {
            data: self.user_token,
            flags: afd_events,
        })
    }
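
    // Illustrative note (not original docs): the edge-trigger simulation
    // above clears the fired bits from `user_interests_flags`, so the same
    // readiness is not reported again until the owner re-arms the interest,
    // typically after hitting WouldBlock:
    //
    //     // hypothetical caller, e.g. inside the crate's TcpStream:
    //     selector_inner.reregister(state, token, Interest::READABLE)?;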

    /// Starts dropping the SockState.
    pub(crate) fn start_drop(&mut self) {
        if !self.delete_pending {
            // If the status is Pending, the SockState has been registered
            // in IOCP and a system call is needed to cancel the socket.
            // Otherwise, setting delete_pending = true is enough.
            if let SockPollStatus::Pending = self.poll_status {
                drop(self.cancel());
            }
            self.delete_pending = true;
        }
    }

    /// Only a SockState in SockPollStatus::Pending can be cancelled; on
    /// success it is set to SockPollStatus::Cancelled.
    fn cancel(&mut self) -> io::Result<()> {
        // Checks poll_status again.
        if self.poll_status != SockPollStatus::Pending {
            unreachable!("Invalid poll status during cancel, {:#?}", self);
        }

        unsafe {
            self.afd.cancel(&mut *self.iosb)?;
        }

        // Only here is the status set to SockPollStatus::Cancelled: the
        // system call to cancel the SockState has been made.
        self.poll_status = SockPollStatus::Cancelled;
        self.polling_interests_flags = 0;

        Ok(())
    }

    fn is_delete_pending(&self) -> bool {
        self.delete_pending
    }

    fn has_error(&self) -> bool {
        self.error.is_some()
    }
}

impl Drop for SockState {
    fn drop(&mut self) {
        self.start_drop();
    }
}

fn get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket> {
    let res = base_socket_inner(raw_socket, SIO_BASE_HANDLE);
    if let Ok(base_socket) = res {
        return Ok(base_socket);
    }

    // SIO_BASE_HANDLE failed; fall back to the BSP handle ioctls.
    for &ioctl in &[SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE] {
        if let Ok(base_socket) = base_socket_inner(raw_socket, ioctl) {
            if base_socket != raw_socket {
                return Ok(base_socket);
            }
        }
    }

    // `res` is an error here, so there must be an OS error to return.
    Err(io::Error::from_raw_os_error(res.unwrap_err()))
}
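
// Background note (illustrative, not from the original source): a layered
// service provider (LSP) may wrap the socket handed to us, while AFD polling
// only works on the base provider's socket. SIO_BASE_HANDLE asks for that
// base handle directly; the SIO_BSP_HANDLE_* fallbacks query the BSP chain
// and are only trusted when they return a handle different from the one
// passed in, hence the `base_socket != raw_socket` check above.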

fn base_socket_inner(raw_socket: RawSocket, control_code: u32) -> Result<RawSocket, i32> {
    let mut base_socket: RawSocket = 0;
    let mut bytes_returned: u32 = 0;
    unsafe {
        if WSAIoctl(
            raw_socket as usize,
            control_code,
            null_mut(),
            0,
            (&mut base_socket as *mut RawSocket).cast::<c_void>(),
            size_of::<RawSocket>() as u32,
            &mut bytes_returned,
            null_mut(),
            None,
        ) != SOCKET_ERROR
        {
            Ok(base_socket)
        } else {
            // Returns the error status for the last Windows Sockets
            // operation that failed.
            Err(WSAGetLastError())
        }
    }
}

/// Converts interests to AFD flags.
fn interests_to_afd_flags(interests: Interest) -> u32 {
    let mut flags = 0;

    // Sets readable flags.
    if interests.is_readable() {
        flags |= READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS;
    }

    // Sets writable flags.
    if interests.is_writable() {
        flags |= WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS;
    }

    flags
}

/// Converts a pinned `SockState` into a raw pointer, leaking one Arc
/// reference so the kernel can hold it for the duration of the request.
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> *mut c_void {
    let overlapped_ptr: *const Mutex<SockState> =
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
    overlapped_ptr as *mut _
}
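
// Illustrative note (not original docs): every pointer produced by
// `into_overlapped` must be consumed exactly once by `from_overlapped`,
// either when its completion packet is dequeued (`feed_events`,
// `release_events`) or when submission fails (`update_while_idle`);
// otherwise one strong Arc reference to the SockState leaks:
//
//     let ptr = into_overlapped(state.clone()); // strong count +1, leaked
//     let state2 = from_overlapped(ptr.cast()); // strong count reclaimed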

/// Converts a raw overlapped pointer back into a reference to `SockState`,
/// reclaiming the Arc reference leaked by `into_overlapped`.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
    let sock_ptr: *const Mutex<SockState> = ptr as *const _;
    unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}
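
// A minimal sanity check (illustrative addition; assumes the crate exposes
// `Interest::READABLE` / `Interest::WRITABLE` constants, as used by the
// public API).
#[cfg(test)]
mod tests {
    use super::*;

    /// Readable interest must map to exactly the readable, read-closed and
    /// error AFD flags; writable interest to the writable counterparts.
    #[test]
    fn interests_to_afd_flags_mapping() {
        let read_flags = interests_to_afd_flags(Interest::READABLE);
        assert_eq!(read_flags, READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS);

        let write_flags = interests_to_afd_flags(Interest::WRITABLE);
        assert_eq!(write_flags, WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS);
    }
}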
618