1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "threading/task_queue_factory.h"
17 
18 #include <atomic>
19 #include <condition_variable>
20 #include <cstddef>
21 #include <memory>
22 #include <mutex>
23 #include <queue>
24 #include <thread>
25 
26 #include <base/containers/array_view.h>
27 #include <base/containers/iterator.h>
28 #include <base/containers/type_traits.h>
29 #include <base/containers/unique_ptr.h>
30 #include <base/util/uid.h>
31 #include <core/log.h>
32 #include <core/threading/intf_thread_pool.h>
33 
34 #include "threading/dispatcher_impl.h"
35 #include "threading/parallel_impl.h"
36 #include "threading/sequential_impl.h"
37 
38 #ifdef PLATFORM_HAS_JAVA
39 #include <os/java/java_internal.h>
40 #endif
41 
42 CORE_BEGIN_NAMESPACE()
43 using BASE_NS::array_view;
44 using BASE_NS::make_unique;
45 using BASE_NS::move;
46 using BASE_NS::unique_ptr;
47 
48 namespace {
49 #ifdef PLATFORM_HAS_JAVA
50 /** RAII class for handling thread setup/release. */
51 class JavaThreadContext final {
52 public:
JavaThreadContext()53     JavaThreadContext()
54     {
55         JNIEnv* env = nullptr;
56         javaVm_ = java_internal::GetJavaVM();
57 
58 #ifndef NDEBUG
59         // Check that the thread was not already attached.
60         // It's not really a problem as another attach is a no-op, but we will be detaching the
61         // thread later and it may be unexpected for the user.
62         jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
63         CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
64 #endif
65 
66         javaVm_->AttachCurrentThread(&env, nullptr);
67     }
68 
~JavaThreadContext()69     ~JavaThreadContext()
70     {
71         javaVm_->DetachCurrentThread();
72     }
73     JavaVM* javaVm_ { nullptr };
74 };
75 #endif // PLATFORM_HAS_JAVA
76 
77 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
78 class TaskResult final : public IThreadPool::IResult {
79 public:
80     // Task state which can be waited and marked as done.
81     class State {
82     public:
Done()83         void Done()
84         {
85             {
86                 auto lock = std::lock_guard(mutex_);
87                 done_ = true;
88             }
89             cv_.notify_all();
90         }
91 
Wait()92         void Wait()
93         {
94             auto lock = std::unique_lock(mutex_);
95             cv_.wait(lock, [this]() { return done_; });
96         }
97 
IsDone() const98         bool IsDone() const
99         {
100             auto lock = std::lock_guard(mutex_);
101             return done_;
102         }
103 
104     private:
105         mutable std::mutex mutex_;
106         std::condition_variable cv_;
107         bool done_ { false };
108     };
109 
TaskResult(std::shared_ptr<State> && future)110     explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
111 
Wait()112     void Wait() final
113     {
114         if (future_) {
115             future_->Wait();
116         }
117     }
IsDone() const118     bool IsDone() const final
119     {
120         if (future_) {
121             return future_->IsDone();
122         }
123         return true;
124     }
125 
126 protected:
Destroy()127     void Destroy() final
128     {
129         delete this;
130     }
131 
132 private:
133     std::shared_ptr<State> future_;
134 };
135 
136 // -- ThreadPool
137 class ThreadPool final : public IThreadPool {
138 public:
ThreadPool(size_t threadCount)139     explicit ThreadPool(size_t threadCount)
140         : threadCount_(threadCount), threads_(make_unique<ThreadContext[]>(threadCount))
141     {
142         CORE_ASSERT(threads_);
143 
144         // Create thread containers.
145         auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
146         for (ThreadContext& context : threads) {
147             // Set-up thread function.
148             context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
149         }
150     }
151 
152     ThreadPool(const ThreadPool&) = delete;
153     ThreadPool(ThreadPool&&) = delete;
154     ThreadPool& operator=(const ThreadPool&) = delete;
155     ThreadPool& operator=(ThreadPool&&) = delete;
156 
Push(ITask::Ptr function)157     IResult::Ptr Push(ITask::Ptr function) override
158     {
159         auto taskState = std::make_shared<TaskResult::State>();
160         if (taskState) {
161             if (function) {
162                 {
163                     std::lock_guard lock(mutex_);
164                     q_.Push(Task(move(function), taskState));
165                 }
166                 cv_.notify_one();
167             } else {
168                 // mark as done if the there was no function.
169                 taskState->Done();
170             }
171         }
172         return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
173     }
174 
PushNoWait(ITask::Ptr function)175     void PushNoWait(ITask::Ptr function) override
176     {
177         if (function) {
178             {
179                 std::lock_guard lock(mutex_);
180                 q_.Push(Task(move(function)));
181             }
182             cv_.notify_one();
183         }
184     }
185 
GetNumberOfThreads() const186     uint32_t GetNumberOfThreads() const override
187     {
188         return static_cast<uint32_t>(threadCount_);
189     }
190 
191     // IInterface
GetInterface(const BASE_NS::Uid & uid) const192     const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
193     {
194         if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
195             return this;
196         }
197         return nullptr;
198     }
199 
GetInterface(const BASE_NS::Uid & uid)200     IInterface* GetInterface(const BASE_NS::Uid& uid) override
201     {
202         if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
203             return this;
204         }
205         return nullptr;
206     }
207 
Ref()208     void Ref() override
209     {
210         refcnt_.fetch_add(1, std::memory_order_relaxed);
211     }
212 
Unref()213     void Unref() override
214     {
215         if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
216             std::atomic_thread_fence(std::memory_order_acquire);
217             delete this;
218         }
219     }
220 
221 protected:
~ThreadPool()222     ~ThreadPool() final
223     {
224         Stop(true);
225     }
226 
227 private:
228     // Helper which holds a pointer to a queued task function and the result state.
229     struct Task {
230         ITask::Ptr function_;
231         std::shared_ptr<TaskResult::State> state_;
232 
233         ~Task() = default;
234         Task() = default;
Task__anondcf869ec0110::ThreadPool::Task235         explicit Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
236             : function_(move(function)), state_(CORE_NS::move(state))
237         {
238             CORE_ASSERT(this->function_ && this->state_);
239         }
Task__anondcf869ec0110::ThreadPool::Task240         explicit Task(ITask::Ptr&& function) : function_(move(function))
241         {
242             CORE_ASSERT(this->function_);
243         }
244         Task(Task&&) = default;
245         Task& operator=(Task&&) = default;
246         Task(const Task&) = delete;
247         Task& operator=(const Task&) = delete;
248 
operator ()__anondcf869ec0110::ThreadPool::Task249         void operator()() const
250         {
251             (*function_)();
252             if (state_) {
253                 state_->Done();
254             }
255         }
256     };
257 
258     template<typename T>
259     class Queue {
260     public:
Push(T && value)261         bool Push(T&& value)
262         {
263             q_.push(move(value));
264             return true;
265         }
266 
Pop(T & v)267         bool Pop(T& v)
268         {
269             if (q_.empty()) {
270                 v = {};
271                 return false;
272             }
273             v = CORE_NS::move(q_.front());
274             q_.pop();
275             return true;
276         }
277 
278     private:
279         std::queue<T> q_;
280     };
281 
282     struct ThreadContext {
283         std::thread thread;
284         bool exit { false };
285     };
286 
Clear()287     void Clear()
288     {
289         Task f;
290         std::lock_guard lock(mutex_);
291         while (q_.Pop(f)) {
292             // Intentionally empty.
293         }
294     }
295 
296     // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
297     // If this doesn't change the class can be simplified a bit.
Stop(bool waitForAllTasksToComplete)298     void Stop(bool waitForAllTasksToComplete)
299     {
300         if (isStop_) {
301             return;
302         }
303         if (waitForAllTasksToComplete) {
304             // Wait all tasks to complete before returning.
305             if (isDone_) {
306                 return;
307             }
308             std::lock_guard lock(mutex_);
309             isDone_ = true;
310         } else {
311             isStop_ = true;
312 
313             // Ask all the threads to stop and not process any more tasks.
314             auto threads = array_view(threads_.get(), threadCount_);
315             {
316                 auto lock = std::lock_guard(mutex_);
317                 for (auto& context : threads) {
318                     context.exit = true;
319                 }
320             }
321             Clear();
322         }
323 
324         // Trigger all waiting threads.
325         cv_.notify_all();
326 
327         // Wait for all threads to finish.
328         auto threads = array_view(threads_.get(), threadCount_);
329         for (auto& context : threads) {
330             if (context.thread.joinable()) {
331                 context.thread.join();
332             }
333         }
334 
335         Clear();
336     }
337 
ThreadProc(ThreadContext & context)338     void ThreadProc(ThreadContext& context)
339     {
340 #ifdef PLATFORM_HAS_JAVA
341         // RAII class for handling thread setup/release.
342         JavaThreadContext javaContext;
343 #endif
344 
345         // Get function to process.
346         Task func;
347         bool isPop = [this, &func]() {
348             std::lock_guard lock(mutex_);
349             return q_.Pop(func);
350         }();
351 
352         while (true) {
353             while (isPop) {
354                 // Run task function.
355                 func();
356 
357                 // If the thread is wanted to stop, return even if the queue is not empty yet.
358                 std::lock_guard lock(mutex_);
359                 if (context.exit) {
360                     return;
361                 }
362 
363                 // Get next function.
364                 isPop = q_.Pop(func);
365             }
366 
367             // The queue is empty here, wait for the next task.
368             std::unique_lock lock(mutex_);
369 
370             // Try to wait for next task to process.
371             cv_.wait(lock, [this, &func, &isPop, &context]() {
372                 isPop = q_.Pop(func);
373                 return isPop || isDone_ || context.exit;
374             });
375 
376             if (!isPop) {
377                 return;
378             }
379         }
380     }
381 
382     size_t threadCount_ { 0 };
383     unique_ptr<ThreadContext[]> threads_;
384 
385     Queue<Task> q_;
386     bool isDone_ { false };
387     bool isStop_ { false };
388 
389     std::mutex mutex_;
390     std::condition_variable cv_;
391     std::atomic<int32_t> refcnt_ { 0 };
392 };
393 } // namespace
394 
GetNumberOfCores() const395 uint32_t TaskQueueFactory::GetNumberOfCores() const
396 {
397     uint32_t result = std::thread::hardware_concurrency();
398     if (result == 0) {
399         // If not detectable, default to 4.
400         result = 4;
401     }
402 
403     return result;
404 }
405 
CreateThreadPool(const uint32_t threadCount) const406 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
407 {
408     return IThreadPool::Ptr { new ThreadPool(threadCount) };
409 }
410 
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const411 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
412 {
413     return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
414 }
415 
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const416 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
417 {
418     return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
419 }
420 
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const421 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
422 {
423     return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
424 }
425 
426 // IInterface
GetInterface(const BASE_NS::Uid & uid) const427 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
428 {
429     if (uid == ITaskQueueFactory::UID) {
430         return static_cast<const ITaskQueueFactory*>(this);
431     }
432     return nullptr;
433 }
434 
GetInterface(const BASE_NS::Uid & uid)435 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
436 {
437     if (uid == ITaskQueueFactory::UID) {
438         return static_cast<ITaskQueueFactory*>(this);
439     }
440     return nullptr;
441 }
442 
Ref()443 void TaskQueueFactory::Ref() {}
444 
Unref()445 void TaskQueueFactory::Unref() {}
446 CORE_END_NAMESPACE()
447