1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/task_queue_factory.h"
17
18 #include <atomic>
19 #include <condition_variable>
20 #include <cstddef>
21 #include <memory>
22 #include <mutex>
23 #include <queue>
24 #include <thread>
25
26 #include <base/containers/array_view.h>
27 #include <base/containers/iterator.h>
28 #include <base/containers/type_traits.h>
29 #include <base/containers/unique_ptr.h>
30 #include <base/util/uid.h>
31 #include <core/log.h>
32 #include <core/threading/intf_thread_pool.h>
33
34 #include "threading/dispatcher_impl.h"
35 #include "threading/parallel_impl.h"
36 #include "threading/sequential_impl.h"
37
38 #ifdef PLATFORM_HAS_JAVA
39 #include <os/java/java_internal.h>
40 #endif
41
42 CORE_BEGIN_NAMESPACE()
43 using BASE_NS::array_view;
44 using BASE_NS::make_unique;
45 using BASE_NS::move;
46 using BASE_NS::unique_ptr;
47
48 namespace {
#ifdef PLATFORM_HAS_JAVA
/**
 * RAII class for handling thread setup/release.
 * The constructor attaches the *calling* thread to the Java VM and the destructor detaches the *calling*
 * thread, so an instance must be created and destroyed on the same thread. Copying or moving is disabled
 * because a second instance would detach the thread a second time.
 */
class JavaThreadContext final {
public:
    /** Attaches the current thread to the VM obtained from java_internal::GetJavaVM(). */
    JavaThreadContext()
    {
        JNIEnv* env = nullptr;
        javaVm_ = java_internal::GetJavaVM();

#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        javaVm_->AttachCurrentThread(&env, nullptr);
    }

    /** Detaches the current thread from the VM. */
    ~JavaThreadContext()
    {
        javaVm_->DetachCurrentThread();
    }

    // Non-copyable and non-movable: every instance detaches in its destructor.
    JavaThreadContext(const JavaThreadContext&) = delete;
    JavaThreadContext(JavaThreadContext&&) = delete;
    JavaThreadContext& operator=(const JavaThreadContext&) = delete;
    JavaThreadContext& operator=(JavaThreadContext&&) = delete;

private:
    JavaVM* javaVm_ { nullptr };
};
#endif // PLATFORM_HAS_JAVA
76
77 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
78 class TaskResult final : public IThreadPool::IResult {
79 public:
80 // Task state which can be waited and marked as done.
81 class State {
82 public:
Done()83 void Done()
84 {
85 {
86 auto lock = std::lock_guard(mutex_);
87 done_ = true;
88 }
89 cv_.notify_all();
90 }
91
Wait()92 void Wait()
93 {
94 auto lock = std::unique_lock(mutex_);
95 cv_.wait(lock, [this]() { return done_; });
96 }
97
IsDone() const98 bool IsDone() const
99 {
100 auto lock = std::lock_guard(mutex_);
101 return done_;
102 }
103
104 private:
105 mutable std::mutex mutex_;
106 std::condition_variable cv_;
107 bool done_ { false };
108 };
109
TaskResult(std::shared_ptr<State> && future)110 explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
111
Wait()112 void Wait() final
113 {
114 if (future_) {
115 future_->Wait();
116 }
117 }
IsDone() const118 bool IsDone() const final
119 {
120 if (future_) {
121 return future_->IsDone();
122 }
123 return true;
124 }
125
126 protected:
Destroy()127 void Destroy() final
128 {
129 delete this;
130 }
131
132 private:
133 std::shared_ptr<State> future_;
134 };
135
136 // -- ThreadPool
137 class ThreadPool final : public IThreadPool {
138 public:
ThreadPool(size_t threadCount)139 explicit ThreadPool(size_t threadCount)
140 : threadCount_(threadCount), threads_(make_unique<ThreadContext[]>(threadCount))
141 {
142 CORE_ASSERT(threads_);
143
144 // Create thread containers.
145 auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
146 for (ThreadContext& context : threads) {
147 // Set-up thread function.
148 context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
149 }
150 }
151
152 ThreadPool(const ThreadPool&) = delete;
153 ThreadPool(ThreadPool&&) = delete;
154 ThreadPool& operator=(const ThreadPool&) = delete;
155 ThreadPool& operator=(ThreadPool&&) = delete;
156
Push(ITask::Ptr function)157 IResult::Ptr Push(ITask::Ptr function) override
158 {
159 auto taskState = std::make_shared<TaskResult::State>();
160 if (taskState) {
161 if (function) {
162 {
163 std::lock_guard lock(mutex_);
164 q_.Push(Task(move(function), taskState));
165 }
166 cv_.notify_one();
167 } else {
168 // mark as done if the there was no function.
169 taskState->Done();
170 }
171 }
172 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
173 }
174
PushNoWait(ITask::Ptr function)175 void PushNoWait(ITask::Ptr function) override
176 {
177 if (function) {
178 {
179 std::lock_guard lock(mutex_);
180 q_.Push(Task(move(function)));
181 }
182 cv_.notify_one();
183 }
184 }
185
GetNumberOfThreads() const186 uint32_t GetNumberOfThreads() const override
187 {
188 return static_cast<uint32_t>(threadCount_);
189 }
190
191 // IInterface
GetInterface(const BASE_NS::Uid & uid) const192 const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
193 {
194 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
195 return this;
196 }
197 return nullptr;
198 }
199
GetInterface(const BASE_NS::Uid & uid)200 IInterface* GetInterface(const BASE_NS::Uid& uid) override
201 {
202 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
203 return this;
204 }
205 return nullptr;
206 }
207
Ref()208 void Ref() override
209 {
210 refcnt_.fetch_add(1, std::memory_order_relaxed);
211 }
212
Unref()213 void Unref() override
214 {
215 if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
216 std::atomic_thread_fence(std::memory_order_acquire);
217 delete this;
218 }
219 }
220
221 protected:
~ThreadPool()222 ~ThreadPool() final
223 {
224 Stop(true);
225 }
226
227 private:
228 // Helper which holds a pointer to a queued task function and the result state.
229 struct Task {
230 ITask::Ptr function_;
231 std::shared_ptr<TaskResult::State> state_;
232
233 ~Task() = default;
234 Task() = default;
Task__anondcf869ec0110::ThreadPool::Task235 explicit Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
236 : function_(move(function)), state_(CORE_NS::move(state))
237 {
238 CORE_ASSERT(this->function_ && this->state_);
239 }
Task__anondcf869ec0110::ThreadPool::Task240 explicit Task(ITask::Ptr&& function) : function_(move(function))
241 {
242 CORE_ASSERT(this->function_);
243 }
244 Task(Task&&) = default;
245 Task& operator=(Task&&) = default;
246 Task(const Task&) = delete;
247 Task& operator=(const Task&) = delete;
248
operator ()__anondcf869ec0110::ThreadPool::Task249 void operator()() const
250 {
251 (*function_)();
252 if (state_) {
253 state_->Done();
254 }
255 }
256 };
257
258 template<typename T>
259 class Queue {
260 public:
Push(T && value)261 bool Push(T&& value)
262 {
263 q_.push(move(value));
264 return true;
265 }
266
Pop(T & v)267 bool Pop(T& v)
268 {
269 if (q_.empty()) {
270 v = {};
271 return false;
272 }
273 v = CORE_NS::move(q_.front());
274 q_.pop();
275 return true;
276 }
277
278 private:
279 std::queue<T> q_;
280 };
281
282 struct ThreadContext {
283 std::thread thread;
284 bool exit { false };
285 };
286
Clear()287 void Clear()
288 {
289 Task f;
290 std::lock_guard lock(mutex_);
291 while (q_.Pop(f)) {
292 // Intentionally empty.
293 }
294 }
295
296 // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
297 // If this doesn't change the class can be simplified a bit.
Stop(bool waitForAllTasksToComplete)298 void Stop(bool waitForAllTasksToComplete)
299 {
300 if (isStop_) {
301 return;
302 }
303 if (waitForAllTasksToComplete) {
304 // Wait all tasks to complete before returning.
305 if (isDone_) {
306 return;
307 }
308 std::lock_guard lock(mutex_);
309 isDone_ = true;
310 } else {
311 isStop_ = true;
312
313 // Ask all the threads to stop and not process any more tasks.
314 auto threads = array_view(threads_.get(), threadCount_);
315 {
316 auto lock = std::lock_guard(mutex_);
317 for (auto& context : threads) {
318 context.exit = true;
319 }
320 }
321 Clear();
322 }
323
324 // Trigger all waiting threads.
325 cv_.notify_all();
326
327 // Wait for all threads to finish.
328 auto threads = array_view(threads_.get(), threadCount_);
329 for (auto& context : threads) {
330 if (context.thread.joinable()) {
331 context.thread.join();
332 }
333 }
334
335 Clear();
336 }
337
ThreadProc(ThreadContext & context)338 void ThreadProc(ThreadContext& context)
339 {
340 #ifdef PLATFORM_HAS_JAVA
341 // RAII class for handling thread setup/release.
342 JavaThreadContext javaContext;
343 #endif
344
345 // Get function to process.
346 Task func;
347 bool isPop = [this, &func]() {
348 std::lock_guard lock(mutex_);
349 return q_.Pop(func);
350 }();
351
352 while (true) {
353 while (isPop) {
354 // Run task function.
355 func();
356
357 // If the thread is wanted to stop, return even if the queue is not empty yet.
358 std::lock_guard lock(mutex_);
359 if (context.exit) {
360 return;
361 }
362
363 // Get next function.
364 isPop = q_.Pop(func);
365 }
366
367 // The queue is empty here, wait for the next task.
368 std::unique_lock lock(mutex_);
369
370 // Try to wait for next task to process.
371 cv_.wait(lock, [this, &func, &isPop, &context]() {
372 isPop = q_.Pop(func);
373 return isPop || isDone_ || context.exit;
374 });
375
376 if (!isPop) {
377 return;
378 }
379 }
380 }
381
382 size_t threadCount_ { 0 };
383 unique_ptr<ThreadContext[]> threads_;
384
385 Queue<Task> q_;
386 bool isDone_ { false };
387 bool isStop_ { false };
388
389 std::mutex mutex_;
390 std::condition_variable cv_;
391 std::atomic<int32_t> refcnt_ { 0 };
392 };
393 } // namespace
394
GetNumberOfCores() const395 uint32_t TaskQueueFactory::GetNumberOfCores() const
396 {
397 uint32_t result = std::thread::hardware_concurrency();
398 if (result == 0) {
399 // If not detectable, default to 4.
400 result = 4;
401 }
402
403 return result;
404 }
405
CreateThreadPool(const uint32_t threadCount) const406 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
407 {
408 return IThreadPool::Ptr { new ThreadPool(threadCount) };
409 }
410
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const411 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
412 {
413 return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
414 }
415
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const416 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
417 {
418 return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
419 }
420
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const421 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
422 {
423 return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
424 }
425
426 // IInterface
GetInterface(const BASE_NS::Uid & uid) const427 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
428 {
429 if (uid == ITaskQueueFactory::UID) {
430 return static_cast<const ITaskQueueFactory*>(this);
431 }
432 return nullptr;
433 }
434
GetInterface(const BASE_NS::Uid & uid)435 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
436 {
437 if (uid == ITaskQueueFactory::UID) {
438 return static_cast<ITaskQueueFactory*>(this);
439 }
440 return nullptr;
441 }
442
Ref()443 void TaskQueueFactory::Ref() {}
444
Unref()445 void TaskQueueFactory::Unref() {}
446 CORE_END_NAMESPACE()
447