1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::future::Future;
15 use std::marker::PhantomData;
16 use std::mem::ManuallyDrop;
17 use std::ops::Deref;
18 use std::ptr::NonNull;
19 use std::task::{RawWaker, RawWakerVTable, Waker};
20 
21 use crate::task::raw::Header;
22 
get_header_by_raw_ptr(ptr: *const ()) -> NonNull<Header>23 unsafe fn get_header_by_raw_ptr(ptr: *const ()) -> NonNull<Header> {
24     let header = ptr as *mut Header;
25     let non_header = NonNull::new(header);
26     if let Some(non_header) = non_header {
27         non_header
28     } else {
29         panic!("task header is null");
30     }
31 }
32 
clone<T>(ptr: *const ()) -> RawWaker where T: Future,33 unsafe fn clone<T>(ptr: *const ()) -> RawWaker
34 where
35     T: Future,
36 {
37     let header = ptr.cast::<Header>();
38     (*header).state.inc_ref();
39     raw_waker::<T>(header)
40 }
41 
wake(ptr: *const ())42 unsafe fn wake(ptr: *const ()) {
43     let header = get_header_by_raw_ptr(ptr);
44     let vir_tble = header.as_ref().vtable;
45     (vir_tble.schedule)(header, true);
46 }
47 
wake_by_ref(ptr: *const ())48 unsafe fn wake_by_ref(ptr: *const ()) {
49     let header = get_header_by_raw_ptr(ptr);
50     let vir_tble = header.as_ref().vtable;
51     (vir_tble.schedule)(header, false);
52 }
53 
drop(ptr: *const ())54 unsafe fn drop(ptr: *const ()) {
55     let header = get_header_by_raw_ptr(ptr);
56     let vir_tble = header.as_ref().vtable;
57     (vir_tble.drop_ref)(header);
58 }
59 
raw_waker<T>(header: *const Header) -> RawWaker where T: Future,60 fn raw_waker<T>(header: *const Header) -> RawWaker
61 where
62     T: Future,
63 {
64     let ptr = header.cast::<()>();
65     let raw_waker_ref = &RawWakerVTable::new(clone::<T>, wake, wake_by_ref, drop);
66     RawWaker::new(ptr, raw_waker_ref)
67 }
68 
/// A [`Waker`] borrowed from a task's `Header`, usable for the lifetime `'a`.
///
/// The waker is stored in `ManuallyDrop`, so dropping a `WakerRefHeader` never
/// invokes the waker's `drop` vtable entry. Correspondingly, its constructor
/// builds the waker without calling `inc_ref`.
pub(crate) struct WakerRefHeader<'a> {
    // Wrapped so it is never dropped and never releases a reference it did not take.
    waker: ManuallyDrop<Waker>,
    // Ties this value to the borrow of the header the waker was built from.
    _field: PhantomData<&'a Header>,
}
73 
74 impl WakerRefHeader<'_> {
new<T>(header: &Header) -> WakerRefHeader<'_> where T: Future,75     pub(crate) fn new<T>(header: &Header) -> WakerRefHeader<'_>
76     where
77         T: Future,
78     {
79         let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T>(header))) };
80 
81         WakerRefHeader {
82             waker,
83             _field: PhantomData,
84         }
85     }
86 }
87 
88 impl Deref for WakerRefHeader<'_> {
89     type Target = Waker;
90 
deref(&self) -> &Self::Target91     fn deref(&self) -> &Self::Target {
92         &self.waker
93     }
94 }
95