1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "threading/parallel_task_queue.h"
17 
18 #include <algorithm>
19 #include <condition_variable>
20 #include <mutex>
21 
22 #include <base/containers/array_view.h>
23 #include <base/containers/iterator.h>
24 #include <base/containers/refcnt_ptr.h>
25 #include <base/containers/type_traits.h>
26 #include <base/containers/unique_ptr.h>
27 #include <base/containers/unordered_map.h>
28 #include <base/containers/vector.h>
29 #include <core/log.h>
30 #include <core/namespace.h>
31 #include <core/threading/intf_thread_pool.h>
32 
33 CORE_BEGIN_NAMESPACE()
34 using BASE_NS::array_view;
35 using BASE_NS::unordered_map;
36 using BASE_NS::vector;
37 
// Completion bookkeeping shared between Execute() and the Task wrappers it pushes to the thread pool.
struct ParallelTaskQueue::TaskState {
    // Identifiers of the tasks that have finished running. Only the presence of a key is checked by the
    // scheduler; the mapped value is always written as true.
    unordered_map<uint64_t, bool> finished;
    // Notified by Task::operator()() each time a task completes; waited on by Execute().
    std::condition_variable cv;
    // Held while `finished` is written by a Task and while Execute() waits on `cv` / queues more work.
    std::mutex mutex;
};
43 
// Thread pool task that runs a wrapped task and then records its completion in the shared TaskState.
class ParallelTaskQueue::Task final : public IThreadPool::ITask {
public:
    // state: completion bookkeeping to update once the wrapped task has run.
    // task:  the task to run.
    // id:    identifier recorded in state.finished when the wrapped task has run.
    explicit Task(TaskState& state, IThreadPool::ITask& task, uint64_t id);

    // Runs the wrapped task, marks `id_` as finished and notifies the state's condition variable.
    void operator()() override;

protected:
    // Deletes this wrapper object.
    void Destroy() override;

private:
    // Both references are non-owning; the referenced objects must outlive this wrapper.
    TaskState& state_;
    IThreadPool::ITask& task_;
    uint64_t id_;
};
58 
Task(TaskState & state,IThreadPool::ITask & task,uint64_t id)59 ParallelTaskQueue::Task::Task(TaskState& state, IThreadPool::ITask& task, uint64_t id)
60     : state_(state), task_(task), id_(id)
61 {}
62 
operator ()()63 void ParallelTaskQueue::Task::operator()()
64 {
65     // Run task.
66     task_();
67 
68     // Mark task as completed.
69     std::unique_lock lock(state_.mutex);
70     state_.finished[id_] = true;
71 
72     // Notify that there is completed task.
73     state_.cv.notify_one();
74 }
75 
Destroy()76 void ParallelTaskQueue::Task::Destroy()
77 {
78     delete this;
79 }
80 
81 // -- Parallel task queue.
ParallelTaskQueue(const IThreadPool::Ptr & threadPool)82 ParallelTaskQueue::ParallelTaskQueue(const IThreadPool::Ptr& threadPool) : TaskQueue(threadPool) {}
83 
~ParallelTaskQueue()84 ParallelTaskQueue::~ParallelTaskQueue()
85 {
86     Wait();
87 }
88 
Submit(uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)89 void ParallelTaskQueue::Submit(uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
90 {
91     CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
92 
93     tasks_.emplace_back(taskIdentifier, std::move(task));
94 }
95 
SubmitAfter(uint64_t afterIdentifier,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)96 void ParallelTaskQueue::SubmitAfter(uint64_t afterIdentifier, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
97 {
98     CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
99 
100     auto it = std::find(tasks_.begin(), tasks_.end(), afterIdentifier);
101     if (it != tasks_.end()) {
102         Entry entry(taskIdentifier, std::move(task));
103         entry.dependencies.push_back(afterIdentifier);
104 
105         tasks_.push_back(std::move(entry));
106     } else {
107         tasks_.emplace_back(taskIdentifier, std::move(task));
108     }
109 }
110 
SubmitAfter(array_view<const uint64_t> afterIdentifiers,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)111 void ParallelTaskQueue::SubmitAfter(
112     array_view<const uint64_t> afterIdentifiers, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
113 {
114     if (std::all_of(
115             afterIdentifiers.cbegin(), afterIdentifiers.cend(), [&tasks = tasks_](const uint64_t afterIdentifier) {
116                 return std::any_of(tasks.cbegin(), tasks.cend(),
117                     [afterIdentifier](const TaskQueue::Entry& entry) { return entry.identifier == afterIdentifier; });
118             })) {
119         Entry entry(taskIdentifier, std::move(task));
120         entry.dependencies.insert(entry.dependencies.cend(), afterIdentifiers.begin(), afterIdentifiers.end());
121 
122         tasks_.push_back(std::move(entry));
123     } else {
124         tasks_.emplace_back(taskIdentifier, std::move(task));
125     }
126 }
127 
Remove(uint64_t taskIdentifier)128 void ParallelTaskQueue::Remove(uint64_t taskIdentifier)
129 {
130     auto it = std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier);
131     if (it != tasks_.cend()) {
132         tasks_.erase(it);
133     }
134 }
135 
Clear()136 void ParallelTaskQueue::Clear()
137 {
138     Wait();
139     tasks_.clear();
140 }
141 
QueueTasks(vector<size_t> & waiting,TaskState & state)142 void ParallelTaskQueue::QueueTasks(vector<size_t>& waiting, TaskState& state)
143 {
144     if (waiting.empty()) {
145         // No more tasks to proecss.
146         return;
147     }
148 
149     for (vector<size_t>::const_iterator it = waiting.cbegin(); it != waiting.cend();) {
150         // Entry to handle.
151         Entry& entry = tasks_[*it];
152 
153         // Can run this task?
154         bool canRun = true;
155         for (const auto& dep : entry.dependencies) {
156             if (!state.finished.contains(dep)) {
157                 // Task that is marked as dependency is not executed yet.
158                 canRun = false;
159                 break;
160             }
161         }
162 
163         if (canRun) {
164             // This task can be executed.
165             // Remove task from waiting list.
166             it = waiting.erase(it);
167 
168             // Push to execution queue.
169             threadPool_->PushNoWait(IThreadPool::ITask::Ptr { new Task(state, *entry.task, entry.identifier) });
170         } else {
171             ++it;
172         }
173     }
174 }
175 
Execute()176 void ParallelTaskQueue::Execute()
177 {
178 #if (CORE_VALIDATION_ENABLED == 1)
179     // NOTE: Check the integrity of the task queue (no circular deps etc.)
180 #endif
181     vector<size_t> waiting;
182     waiting.resize(tasks_.size());
183     for (size_t i = 0; i < tasks_.size(); ++i) {
184         waiting[i] = i;
185     }
186 
187     TaskState state;
188     state.finished.reserve(tasks_.size());
189 
190     {
191         // Keep on pushing tasks to queue until all done.
192         std::unique_lock lock(state.mutex);
193         state.cv.wait(lock, [this, &waiting, &state]() {
194             // Push new tasks to queue.
195             QueueTasks(waiting, state);
196             return state.finished.size() == tasks_.size();
197         });
198     }
199 }
200 CORE_END_NAMESPACE()
201