1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::ffi::c_void;
15 use std::fs::File;
16 use std::mem::{size_of, zeroed};
17 use std::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle};
18 use std::ptr::null_mut;
19 use std::sync::atomic::{AtomicUsize, Ordering};
20 use std::sync::{Arc, Mutex};
21 use std::{fmt, io};
22 
23 use crate::sys::winapi::{
24     NtCreateFile, NtDeviceIoControlFile, RtlNtStatusToDosError, SetFileCompletionNotificationModes,
25     FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, FILE_SKIP_SET_EVENT_ON_HANDLE, HANDLE,
26     INVALID_HANDLE_VALUE, IO_STATUS_BLOCK, IO_STATUS_BLOCK_0, NTSTATUS, OBJECT_ATTRIBUTES,
27     STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, SYNCHRONIZE, UNICODE_STRING,
28 };
29 use crate::sys::windows::iocp::CompletionPort;
30 
// AFD poll event flags. Each constant is a distinct bit so they can be OR-ed
// together into a single `u32` mask.
// NOTE(review): presumably these are matched against
// `AfdPollHandleInfo::events` — confirm with the callers.

/// Bit `0x0001`: receive event.
pub const POLL_RECEIVE: u32 = 0x0001;
/// Bit `0x0002`: expedited receive event.
pub const POLL_RECEIVE_EXPEDITED: u32 = 0x0002;
/// Bit `0x0004`: send event.
pub const POLL_SEND: u32 = 0x0004;
/// Bit `0x0008`: disconnect event.
pub const POLL_DISCONNECT: u32 = 0x0008;
/// Bit `0x0010`: abort event.
pub const POLL_ABORT: u32 = 0x0010;
/// Bit `0x0020`: local close event.
pub const POLL_LOCAL_CLOSE: u32 = 0x0020;
/// Bit `0x0080`: accept event (bit `0x0040` is intentionally not defined
/// in this file).
pub const POLL_ACCEPT: u32 = 0x0080;
/// Bit `0x0100`: connect failure event.
pub const POLL_CONNECT_FAIL: u32 = 0x0100;

/// Mask with every poll flag declared in this file set.
pub const ALL_EVENTS: u32 = POLL_RECEIVE
    | POLL_RECEIVE_EXPEDITED
    | POLL_SEND
    | POLL_DISCONNECT
    | POLL_ABORT
    | POLL_LOCAL_CLOSE
    | POLL_ACCEPT
    | POLL_CONNECT_FAIL;
48 
49 const AFD_ATTRIBUTES: OBJECT_ATTRIBUTES = OBJECT_ATTRIBUTES {
50     Length: size_of::<OBJECT_ATTRIBUTES>() as u32,
51     RootDirectory: 0,
52     ObjectName: &OBJ_NAME as *const _ as *mut _,
53     Attributes: 0,
54     SecurityDescriptor: null_mut(),
55     SecurityQualityOfService: null_mut(),
56 };
57 const OBJ_NAME: UNICODE_STRING = UNICODE_STRING {
58     Length: (AFD_HELPER_NAME.len() * size_of::<u16>()) as u16,
59     MaximumLength: (AFD_HELPER_NAME.len() * size_of::<u16>()) as u16,
60     Buffer: AFD_HELPER_NAME.as_ptr() as *mut _,
61 };
62 const AFD_HELPER_NAME: &[u16] = &[
63     '\\' as _, 'D' as _, 'e' as _, 'v' as _, 'i' as _, 'c' as _, 'e' as _, '\\' as _, 'A' as _,
64     'f' as _, 'd' as _, '\\' as _, 'Y' as _, 'l' as _, 'o' as _, 'n' as _, 'g' as _,
65 ];
66 
/// Process-wide counter from which `Afd::new` derives the completion-port
/// token of each newly opened handle (it advances by 2 per handle).
static NEXT_TOKEN: AtomicUsize = AtomicUsize::new(0);
/// Control code passed to `NtDeviceIoControlFile` by `Afd::poll`.
const IOCTL_AFD_POLL: u32 = 0x0001_2024;
69 
70 #[link(name = "ntdll")]
71 extern "system" {
NtCancelIoFileEx( FileHandle: HANDLE, IoRequestToCancel: *mut IO_STATUS_BLOCK, IoStatusBlock: *mut IO_STATUS_BLOCK, ) -> NTSTATUS72     fn NtCancelIoFileEx(
73         FileHandle: HANDLE,
74         IoRequestToCancel: *mut IO_STATUS_BLOCK,
75         IoStatusBlock: *mut IO_STATUS_BLOCK,
76     ) -> NTSTATUS;
77 }
78 
79 /// Asynchronous file descriptor
80 /// Implementing a single file handle to monitor multiple Io operations using
81 /// the IO multiplexing model.
82 #[derive(Debug)]
83 pub struct Afd {
84     fd: File,
85 }
86 
87 impl Afd {
88     /// Creates a new Afd and add it to CompletionPort
new(cp: &CompletionPort) -> io::Result<Afd>89     fn new(cp: &CompletionPort) -> io::Result<Afd> {
90         let mut afd_device_handle: HANDLE = INVALID_HANDLE_VALUE;
91         let mut io_status_block = IO_STATUS_BLOCK {
92             Anonymous: IO_STATUS_BLOCK_0 { Status: 0 },
93             Information: 0,
94         };
95 
96         let fd = unsafe {
97             let status = NtCreateFile(
98                 &mut afd_device_handle as *mut _,
99                 SYNCHRONIZE,
100                 &AFD_ATTRIBUTES as *const _ as *mut _,
101                 &mut io_status_block,
102                 null_mut(),
103                 0,
104                 FILE_SHARE_READ | FILE_SHARE_WRITE,
105                 FILE_OPEN,
106                 0,
107                 null_mut(),
108                 0,
109             );
110 
111             if status != STATUS_SUCCESS {
112                 let raw_error = io::Error::from_raw_os_error(RtlNtStatusToDosError(status) as i32);
113 
114                 let msg = format!("Failed to open \\Device\\Afd\\Ylong: {raw_error}");
115                 return Err(io::Error::new(raw_error.kind(), msg));
116             }
117 
118             File::from_raw_handle(afd_device_handle as RawHandle)
119         };
120 
121         let token = NEXT_TOKEN.fetch_add(2, Ordering::Relaxed) + 2;
122         let afd = Afd { fd };
123         // Add Afd to CompletionPort
124         cp.add_handle(token, &afd.fd)?;
125 
126         syscall!(
127             SetFileCompletionNotificationModes(
128                 afd_device_handle,
129                 FILE_SKIP_SET_EVENT_ON_HANDLE as u8,
130             ),
131             afd
132         )
133     }
134 
135     /// System call
poll( &self, info: &mut AfdPollInfo, iosb: *mut IO_STATUS_BLOCK, overlapped: *mut c_void, ) -> io::Result<bool>136     pub(crate) unsafe fn poll(
137         &self,
138         info: &mut AfdPollInfo,
139         iosb: *mut IO_STATUS_BLOCK,
140         overlapped: *mut c_void,
141     ) -> io::Result<bool> {
142         let afd_info = (info as *mut AfdPollInfo).cast::<c_void>();
143         (*iosb).Anonymous.Status = STATUS_PENDING;
144 
145         let status = NtDeviceIoControlFile(
146             self.fd.as_raw_handle() as HANDLE,
147             0,
148             None,
149             overlapped,
150             iosb,
151             IOCTL_AFD_POLL,
152             afd_info,
153             size_of::<AfdPollInfo>() as u32,
154             afd_info,
155             size_of::<AfdPollInfo>() as u32,
156         );
157 
158         match status {
159             STATUS_SUCCESS => Ok(true),
160             // this is expected.
161             STATUS_PENDING => Ok(false),
162             _ => Err(io::Error::from_raw_os_error(
163                 RtlNtStatusToDosError(status) as i32
164             )),
165         }
166     }
167 
168     /// System call to cancel File HANDLE.
cancel(&self, iosb: *mut IO_STATUS_BLOCK) -> io::Result<()>169     pub(crate) unsafe fn cancel(&self, iosb: *mut IO_STATUS_BLOCK) -> io::Result<()> {
170         if (*iosb).Anonymous.Status != STATUS_PENDING {
171             return Ok(());
172         }
173 
174         let mut cancel_iosb = IO_STATUS_BLOCK {
175             Anonymous: IO_STATUS_BLOCK_0 { Status: 0 },
176             Information: 0,
177         };
178         let status = NtCancelIoFileEx(self.fd.as_raw_handle() as HANDLE, iosb, &mut cancel_iosb);
179         match status {
180             STATUS_SUCCESS | STATUS_NOT_FOUND => Ok(()),
181             _ => Err(io::Error::from_raw_os_error(
182                 RtlNtStatusToDosError(status) as i32
183             )),
184         }
185     }
186 }
187 
188 /// A group which contains Afds.
189 #[derive(Debug)]
190 pub(crate) struct AfdGroup {
191     cp: Arc<CompletionPort>,
192     afd_group: Mutex<Vec<Arc<Afd>>>,
193 }
194 
195 /// Up to 32 Arc points per Afd.
196 const POLL_GROUP__MAX_GROUP_SIZE: usize = 32;
197 
198 impl AfdGroup {
199     /// Creates a new AfdGroup.
new(cp: Arc<CompletionPort>) -> AfdGroup200     pub(crate) fn new(cp: Arc<CompletionPort>) -> AfdGroup {
201         AfdGroup {
202             afd_group: Mutex::new(Vec::new()),
203             cp,
204         }
205     }
206 
207     /// Gets a new point to File.
acquire(&self) -> io::Result<Arc<Afd>>208     pub(crate) fn acquire(&self) -> io::Result<Arc<Afd>> {
209         let mut afd_group = self.afd_group.lock().unwrap();
210 
211         // When the last File has more than 32 Arc Points, creates a new File.
212         // If the vec len is not zero, then last always returns some
213         if afd_group.len() == 0
214             || Arc::strong_count(afd_group.last().unwrap()) > POLL_GROUP__MAX_GROUP_SIZE
215         {
216             let arc = Arc::new(Afd::new(&self.cp)?);
217             afd_group.push(arc);
218         }
219 
220         match afd_group.last() {
221             Some(arc) => Ok(arc.clone()),
222             None => unreachable!(
223                 "Cannot acquire afd, {:#?}, afd_group: {:#?}",
224                 self, afd_group
225             ),
226         }
227     }
228 
229     /// Delete Afd that is no longer in use from AfdGroup.
release_unused_afd(&self)230     pub(crate) fn release_unused_afd(&self) {
231         let mut afd_group = self.afd_group.lock().unwrap();
232         afd_group.retain(|g| Arc::strong_count(g) > 1);
233     }
234 }
235 
236 #[repr(C)]
237 pub struct AfdPollInfo {
238     pub timeout: i64,
239     pub number_of_handles: u32,
240     pub exclusive: u32,
241     pub handles: [AfdPollHandleInfo; 1],
242 }
243 
244 impl AfdPollInfo {
zeroed() -> AfdPollInfo245     pub(crate) fn zeroed() -> AfdPollInfo {
246         unsafe { zeroed() }
247     }
248 }
249 
250 impl fmt::Debug for AfdPollInfo {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result251     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252         f.debug_struct("AfdPollInfo").finish()
253     }
254 }
255 
256 #[repr(C)]
257 #[derive(Debug)]
258 pub struct AfdPollHandleInfo {
259     /// SockState base_socket
260     pub handle: HANDLE,
261     pub events: u32,
262     pub status: NTSTATUS,
263 }
264 
265 unsafe impl Send for AfdPollHandleInfo {}
266