1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::future::Future;
15 use std::marker::PhantomData;
16 use std::mem::ManuallyDrop;
17 use std::ops::Deref;
18 use std::ptr::NonNull;
19 use std::task::{RawWaker, RawWakerVTable, Waker};
20
21 use crate::task::raw::Header;
22
get_header_by_raw_ptr(ptr: *const ()) -> NonNull<Header>23 unsafe fn get_header_by_raw_ptr(ptr: *const ()) -> NonNull<Header> {
24 let header = ptr as *mut Header;
25 let non_header = NonNull::new(header);
26 if let Some(non_header) = non_header {
27 non_header
28 } else {
29 panic!("task header is null");
30 }
31 }
32
clone<T>(ptr: *const ()) -> RawWaker where T: Future,33 unsafe fn clone<T>(ptr: *const ()) -> RawWaker
34 where
35 T: Future,
36 {
37 let header = ptr.cast::<Header>();
38 (*header).state.inc_ref();
39 raw_waker::<T>(header)
40 }
41
wake(ptr: *const ())42 unsafe fn wake(ptr: *const ()) {
43 let header = get_header_by_raw_ptr(ptr);
44 let vir_tble = header.as_ref().vtable;
45 (vir_tble.schedule)(header, true);
46 }
47
wake_by_ref(ptr: *const ())48 unsafe fn wake_by_ref(ptr: *const ()) {
49 let header = get_header_by_raw_ptr(ptr);
50 let vir_tble = header.as_ref().vtable;
51 (vir_tble.schedule)(header, false);
52 }
53
drop(ptr: *const ())54 unsafe fn drop(ptr: *const ()) {
55 let header = get_header_by_raw_ptr(ptr);
56 let vir_tble = header.as_ref().vtable;
57 (vir_tble.drop_ref)(header);
58 }
59
raw_waker<T>(header: *const Header) -> RawWaker where T: Future,60 fn raw_waker<T>(header: *const Header) -> RawWaker
61 where
62 T: Future,
63 {
64 let ptr = header.cast::<()>();
65 let raw_waker_ref = &RawWakerVTable::new(clone::<T>, wake, wake_by_ref, drop);
66 RawWaker::new(ptr, raw_waker_ref)
67 }
68
/// A borrowed [`Waker`] that refers to a task [`Header`].
///
/// The waker is held in [`ManuallyDrop`], so the waker's `Drop` (and hence
/// the vtable's `drop` entry) never runs for it. The `'a` lifetime ties the
/// waker to the header borrow it was created from.
pub(crate) struct WakerRefHeader<'a> {
    // Never dropped: releasing it would run the vtable `drop` entry.
    waker: ManuallyDrop<Waker>,
    // Carries the `&'a Header` lifetime without storing the reference.
    _field: PhantomData<&'a Header>,
}
73
74 impl WakerRefHeader<'_> {
new<T>(header: &Header) -> WakerRefHeader<'_> where T: Future,75 pub(crate) fn new<T>(header: &Header) -> WakerRefHeader<'_>
76 where
77 T: Future,
78 {
79 let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T>(header))) };
80
81 WakerRefHeader {
82 waker,
83 _field: PhantomData,
84 }
85 }
86 }
87
88 impl Deref for WakerRefHeader<'_> {
89 type Target = Waker;
90
deref(&self) -> &Self::Target91 fn deref(&self) -> &Self::Target {
92 &self.waker
93 }
94 }
95