1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H
17 #define OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H
#include <chrono>
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <shared_mutex>
#include <utility>
26 namespace OHOS {
/**
 * @brief Thread-safe queue of tasks ordered by their scheduled time.
 *
 * Every public member function locks the internal mutex, so one instance can be shared by producers
 * (Push/Update/Remove/Clean) and consumers (Pop/Finish).
 *
 * A task is tracked in one of two states:
 *  - queued:  stored in tasks_ (ordered by time) and reachable by id through indexes_;
 *  - running: handed out by Pop() and kept in running_ until Finish() is called with its id.
 *
 * @tparam _Tsk Task type. Must be copyable and provide Valid(); Push() rejects tasks for which Valid() is false.
 * @tparam _Tme Scheduled time. Pop() compares it with std::chrono::steady_clock::now() and waits on it, so it is
 *              expected to be a steady_clock time point.
 * @tparam _Tid Task identifier, used as an ordered map key.
 *
 * NOTE: indexes_ and running_ are filled with emplace, which never overwrites an existing key. Re-using an id
 * that is still queued or running therefore leaves an entry that can no longer be found by id; callers
 * presumably keep ids unique.
 */
template<typename _Tsk, typename _Tme, typename _Tid>
class PriorityQueue {
public:
    /** A task, its id, and a flag set by Remove() while the task is running so Finish() will not reschedule it. */
    struct PQMatrix {
        _Tsk task_;
        _Tid id_;
        bool removed = false;
        PQMatrix(_Tsk task, _Tid id) : task_(std::move(task)), id_(std::move(id)) {}
    };
    // Position of a queued task inside tasks_. tasks_ is a std::multimap, so this must be its iterator type.
    using TskIndex = typename std::multimap<_Tme, PQMatrix>::iterator;
    // Given a task, returns {repeat, nextTime}.
    using TskUpdater = typename std::function<std::pair<bool, _Tme>(_Tsk &element)>;

    /**
     * @param task    Sentinel returned by Pop() and Find() when no task is available.
     * @param updater Rescheduling policy applied by Finish() to a finished task. When empty, finished tasks are
     *                never rescheduled.
     */
    PriorityQueue(const _Tsk &task, TskUpdater updater = nullptr)
        : INVALID_TSK(task), updater_(std::move(updater))
    {
        if (!updater_) {
            updater_ = [](_Tsk &) { return std::pair{false, _Tme()}; };
        }
    }

    /**
     * @brief Takes the earliest queued task once its scheduled time has been reached.
     *
     * Sleeps until the earliest task is due. It re-evaluates after every wake-up, so an earlier task queued
     * meanwhile, or a Remove()/Clean(), is taken into account. The returned task moves to the running state
     * until Finish() is called with its id.
     *
     * @return The due task, or INVALID_TSK if the queue is empty or becomes empty. It never waits for new tasks.
     */
    _Tsk Pop()
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        while (!tasks_.empty()) {
            auto waitTme = tasks_.begin()->first;
            if (waitTme > std::chrono::steady_clock::now()) {
                popCv_.wait_until(lock, waitTme);
                continue;
            }
            auto temp = tasks_.begin();
            auto id = temp->second.id_;
            // Keep a copy for Update()/Finish(); the queued entry is erased below, so its task can be moved out.
            running_.emplace(id, temp->second);
            auto res = std::move(temp->second.task_);
            tasks_.erase(temp);
            indexes_.erase(id);
            return res;
        }
        return INVALID_TSK;
    }

    /**
     * @brief Queues @p tsk under @p id, due at @p tme, and wakes any waiting Pop().
     * @return false (nothing queued) if tsk.Valid() is false, true otherwise.
     */
    bool Push(_Tsk tsk, _Tid id, _Tme tme)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        if (!tsk.Valid()) {
            return false;
        }
        auto temp = tasks_.emplace(tme, PQMatrix(std::move(tsk), id));
        indexes_.emplace(id, temp);
        popCv_.notify_all();
        return true;
    }

    /** @return Number of queued tasks. Tasks that were popped but not yet finished are not counted. */
    size_t Size()
    {
        std::lock_guard<decltype(pqMtx_)> lock(pqMtx_);
        return tasks_.size();
    }

    /**
     * @return A copy of the queued task with @p id, or INVALID_TSK if no task with that id is queued.
     *         Running tasks are not searched.
     */
    _Tsk Find(_Tid id)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto index = indexes_.find(id);
        if (index == indexes_.end()) {
            return INVALID_TSK;
        }
        return index->second->second.task_;
    }

    /**
     * @brief Applies @p updater to the task with @p id.
     *
     * - Queued task: the updater may modify it, and the task is rescheduled at the returned time. The returned
     *   repeat flag is ignored. Returns true.
     * - Running task: the updater modifies the running copy, which Finish() later hands to the queue's own
     *   updater. Returns the updater's repeat flag.
     * - Otherwise returns false.
     */
    bool Update(_Tid id, TskUpdater updater)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto index = indexes_.find(id);
        if (index != indexes_.end()) {
            auto time = updater(index->second->second.task_).second;
            // Multimap keys are immutable: move the entry out and insert it again under the new time.
            auto matrix = std::move(index->second->second);
            tasks_.erase(index->second);
            index->second = tasks_.emplace(time, std::move(matrix));
            popCv_.notify_all();
            return true;
        }

        auto running = running_.find(id);
        if (running != running_.end()) {
            return updater(running->second.task_).first;
        }

        return false;
    }

    /**
     * @brief Cancels the task with @p id.
     *
     * If the task is running, it is flagged so that Finish() will not reschedule it. When @p wait is true, the
     * call then blocks until Finish() has been called for that id. Any queued instance is dropped afterwards.
     *
     * @return true if a queued task was removed. Returns false otherwise, even when a running task was flagged.
     * NOTE: with wait == true this never returns if the calling thread is the one expected to call Finish(id).
     */
    bool Remove(_Tid id, bool wait)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto it = running_.find(id);
        if (it != running_.end()) {
            it->second.removed = true;
        }
        removeCv_.wait(lock, [this, id, wait] {
            return !wait || running_.find(id) == running_.end();
        });
        auto index = indexes_.find(id);
        if (index == indexes_.end()) {
            return false;
        }
        tasks_.erase(index->second);
        indexes_.erase(index);
        popCv_.notify_all();
        return true;
    }

    /** @brief Drops every queued task. Running tasks are untouched and may still be rescheduled by Finish(). */
    void Clean()
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        indexes_.clear();
        tasks_.clear();
        popCv_.notify_all();
    }

    /**
     * @brief Marks the running task with @p id as done.
     *
     * Unless Remove() flagged it, the queue's updater decides whether to queue it again and at what time. A
     * re-queued task wakes Pop() waiters so that its deadline is not overslept. Remove() callers waiting on this
     * id are woken. Does nothing if no task with @p id is running.
     */
    void Finish(_Tid id)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto it = running_.find(id);
        if (it == running_.end()) {
            return;
        }
        if (!it->second.removed) {
            auto [repeat, time] = updater_(it->second.task_);
            if (repeat) {
                indexes_.emplace(id, tasks_.emplace(time, std::move(it->second)));
                // As in Push()/Update(): a Pop() may be sleeping until a later deadline than this task's.
                popCv_.notify_all();
            }
        }
        running_.erase(it);
        removeCv_.notify_all();
    }

private:
    const _Tsk INVALID_TSK;                 // sentinel returned when no task is available
    std::mutex pqMtx_;                      // serializes every public member function
    std::condition_variable popCv_;         // signalled when queued tasks change; Pop() waits on it
    std::condition_variable removeCv_;      // signalled when a running task finishes; Remove() waits on it
    std::multimap<_Tme, PQMatrix> tasks_;   // queued tasks ordered by scheduled time
    std::map<_Tid, PQMatrix> running_;      // tasks handed out by Pop() and not yet finished
    std::map<_Tid, TskIndex> indexes_;      // id -> position of the queued task in tasks_
    TskUpdater updater_;                    // rescheduling policy applied by Finish()
};
170 } // namespace OHOS
171 #endif //OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H
172