1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/parallel_task_queue.h"
17
18 #include <algorithm>
19 #include <condition_variable>
20 #include <mutex>
21
22 #include <base/containers/array_view.h>
23 #include <base/containers/iterator.h>
24 #include <base/containers/refcnt_ptr.h>
25 #include <base/containers/type_traits.h>
26 #include <base/containers/unique_ptr.h>
27 #include <base/containers/unordered_map.h>
28 #include <base/containers/vector.h>
29 #include <core/log.h>
30 #include <core/namespace.h>
31 #include <core/threading/intf_thread_pool.h>
32
33 CORE_BEGIN_NAMESPACE()
34 using BASE_NS::array_view;
35 using BASE_NS::unordered_map;
36 using BASE_NS::vector;
37
38 struct ParallelTaskQueue::TaskState {
39 unordered_map<uint64_t, bool> finished;
40 std::condition_variable cv;
41 std::mutex mutex;
42 };
43
44 class ParallelTaskQueue::Task final : public IThreadPool::ITask {
45 public:
46 explicit Task(TaskState& state, IThreadPool::ITask& task, uint64_t id);
47
48 void operator()() override;
49
50 protected:
51 void Destroy() override;
52
53 private:
54 TaskState& state_;
55 IThreadPool::ITask& task_;
56 uint64_t id_;
57 };
58
Task(TaskState & state,IThreadPool::ITask & task,uint64_t id)59 ParallelTaskQueue::Task::Task(TaskState& state, IThreadPool::ITask& task, uint64_t id)
60 : state_(state), task_(task), id_(id)
61 {}
62
operator ()()63 void ParallelTaskQueue::Task::operator()()
64 {
65 // Run task.
66 task_();
67
68 // Mark task as completed.
69 std::unique_lock lock(state_.mutex);
70 state_.finished[id_] = true;
71
72 // Notify that there is completed task.
73 state_.cv.notify_one();
74 }
75
Destroy()76 void ParallelTaskQueue::Task::Destroy()
77 {
78 delete this;
79 }
80
81 // -- Parallel task queue.
ParallelTaskQueue(const IThreadPool::Ptr & threadPool)82 ParallelTaskQueue::ParallelTaskQueue(const IThreadPool::Ptr& threadPool) : TaskQueue(threadPool) {}
83
~ParallelTaskQueue()84 ParallelTaskQueue::~ParallelTaskQueue()
85 {
86 Wait();
87 }
88
Submit(uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)89 void ParallelTaskQueue::Submit(uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
90 {
91 CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
92
93 tasks_.emplace_back(taskIdentifier, std::move(task));
94 }
95
SubmitAfter(uint64_t afterIdentifier,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)96 void ParallelTaskQueue::SubmitAfter(uint64_t afterIdentifier, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
97 {
98 CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
99
100 auto it = std::find(tasks_.begin(), tasks_.end(), afterIdentifier);
101 if (it != tasks_.end()) {
102 Entry entry(taskIdentifier, std::move(task));
103 entry.dependencies.push_back(afterIdentifier);
104
105 tasks_.push_back(std::move(entry));
106 } else {
107 tasks_.emplace_back(taskIdentifier, std::move(task));
108 }
109 }
110
SubmitAfter(array_view<const uint64_t> afterIdentifiers,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)111 void ParallelTaskQueue::SubmitAfter(
112 array_view<const uint64_t> afterIdentifiers, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
113 {
114 if (std::all_of(
115 afterIdentifiers.cbegin(), afterIdentifiers.cend(), [&tasks = tasks_](const uint64_t afterIdentifier) {
116 return std::any_of(tasks.cbegin(), tasks.cend(),
117 [afterIdentifier](const TaskQueue::Entry& entry) { return entry.identifier == afterIdentifier; });
118 })) {
119 Entry entry(taskIdentifier, std::move(task));
120 entry.dependencies.insert(entry.dependencies.cend(), afterIdentifiers.begin(), afterIdentifiers.end());
121
122 tasks_.push_back(std::move(entry));
123 } else {
124 tasks_.emplace_back(taskIdentifier, std::move(task));
125 }
126 }
127
Remove(uint64_t taskIdentifier)128 void ParallelTaskQueue::Remove(uint64_t taskIdentifier)
129 {
130 auto it = std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier);
131 if (it != tasks_.cend()) {
132 tasks_.erase(it);
133 }
134 }
135
Clear()136 void ParallelTaskQueue::Clear()
137 {
138 Wait();
139 tasks_.clear();
140 }
141
QueueTasks(vector<size_t> & waiting,TaskState & state)142 void ParallelTaskQueue::QueueTasks(vector<size_t>& waiting, TaskState& state)
143 {
144 if (waiting.empty()) {
145 // No more tasks to proecss.
146 return;
147 }
148
149 for (vector<size_t>::const_iterator it = waiting.cbegin(); it != waiting.cend();) {
150 // Entry to handle.
151 Entry& entry = tasks_[*it];
152
153 // Can run this task?
154 bool canRun = true;
155 for (const auto& dep : entry.dependencies) {
156 if (!state.finished.contains(dep)) {
157 // Task that is marked as dependency is not executed yet.
158 canRun = false;
159 break;
160 }
161 }
162
163 if (canRun) {
164 // This task can be executed.
165 // Remove task from waiting list.
166 it = waiting.erase(it);
167
168 // Push to execution queue.
169 threadPool_->PushNoWait(IThreadPool::ITask::Ptr { new Task(state, *entry.task, entry.identifier) });
170 } else {
171 ++it;
172 }
173 }
174 }
175
Execute()176 void ParallelTaskQueue::Execute()
177 {
178 #if (CORE_VALIDATION_ENABLED == 1)
179 // NOTE: Check the integrity of the task queue (no circular deps etc.)
180 #endif
181 vector<size_t> waiting;
182 waiting.resize(tasks_.size());
183 for (size_t i = 0; i < tasks_.size(); ++i) {
184 waiting[i] = i;
185 }
186
187 TaskState state;
188 state.finished.reserve(tasks_.size());
189
190 {
191 // Keep on pushing tasks to queue until all done.
192 std::unique_lock lock(state.mutex);
193 state.cv.wait(lock, [this, &waiting, &state]() {
194 // Push new tasks to queue.
195 QueueTasks(waiting, state);
196 return state.finished.size() == tasks_.size();
197 });
198 }
199 }
200 CORE_END_NAMESPACE()
201