1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef META_BASE_SHARED_PTR_H
17 #define META_BASE_SHARED_PTR_H
18 
19 #include <meta/base/meta_types.h>
20 
21 #include "shared_ptr_internals.h"
22 
BASE_BEGIN_NAMESPACE()23 BASE_BEGIN_NAMESPACE()
24 
25 /**
26  * @brief C++ standard like weak_ptr.
27  */
28 template<typename T>
29 class weak_ptr final : public Internals::PtrCountedBase<T> {
30 public:
31     using element_type = BASE_NS::remove_extent_t<T>;
32     weak_ptr() = default;
33     ~weak_ptr()
34     {
35         if (this->control_) {
36             this->control_->ReleaseWeak();
37         }
38     };
39 
40     weak_ptr(nullptr_t) {}
41     weak_ptr(const shared_ptr<T>& p) : Internals::PtrCountedBase<T>(p)
42     {
43         if (this->control_) {
44             this->control_->AddWeak();
45         }
46     }
47     weak_ptr(const weak_ptr& p) noexcept : Internals::PtrCountedBase<T>(p)
48     {
49         if (this->control_) {
50             this->control_->AddWeak();
51         }
52     }
53     weak_ptr(weak_ptr&& p) noexcept : Internals::PtrCountedBase<T>(p)
54     {
55         p.InternalReset();
56     }
57 
58     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
59     shared_ptr<T> lock() const
60     {
61         return shared_ptr<T>(*this);
62     }
63     weak_ptr& operator=(weak_ptr&& p) noexcept
64     {
65         if (this != &p) {
66             reset();
67             this->control_ = p.control_;
68             this->pointer_ = p.pointer_;
69             p.InternalReset();
70         }
71         return *this;
72     }
73     weak_ptr& operator=(const weak_ptr& p) noexcept
74     {
75         if (this != &p) {
76             reset();
77             this->control_ = p.control_;
78             this->pointer_ = p.pointer_;
79             if (this->control_) {
80                 this->control_->AddWeak();
81             }
82         }
83         return *this;
84     }
85     weak_ptr& operator=(const shared_ptr<T>& p)
86     {
87         reset();
88         this->control_ = p.control_;
89         this->pointer_ = p.pointer_;
90         if (this->control_) {
91             this->control_->AddWeak();
92         }
93         return *this;
94     }
95 
96     weak_ptr& operator=(nullptr_t) noexcept
97     {
98         reset();
99         return *this;
100     }
101     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
102     void reset()
103     {
104         if (this->control_) {
105             this->control_->ReleaseWeak();
106             this->InternalReset();
107         }
108     }
109 
110     /*"implicit" casting constructors */
111     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
112     weak_ptr(const shared_ptr<U>& p)
113         // handle casting by using functionality in shared_ptr. (creates an aliased shared_ptr to original.)
114         : weak_ptr(shared_ptr<T>(p))
115     {}
116 
117     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
118     weak_ptr(const weak_ptr<U>& p) : weak_ptr(shared_ptr<T>(p.lock()))
119     {}
120 
121     /* "implicit" casting move */
122     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
123     weak_ptr(weak_ptr<U>&& p) noexcept : weak_ptr(shared_ptr<T>(p.lock()))
124     {
125         p.reset();
126     }
127 
128     /* "implicit" casting operators */
129     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
130     weak_ptr& operator=(const shared_ptr<U>& p)
131     {
132         // handle casting by using functionality in shared_ptr. (creates an aliased shared_ptr to original.)
133         *this = shared_ptr<T>(p);
134         return *this;
135     }
136 
137     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
138     weak_ptr& operator=(const weak_ptr<U>& p)
139     {
140         // first lock the given weak ptr. (to see if it has expired, and to get a pointer that can be cast)
141         *this = shared_ptr<T>(p.lock());
142         return *this;
143     }
144 
145     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
146     bool expired() const noexcept
147     {
148         return !this->control_ || this->control_->GetStrongCount() == 0;
149     }
150 
151 private:
152     friend class shared_ptr<T>;
153     template<typename>
154     friend class weak_ptr;
155 };
156 
157 /**
158  * @brief C++ standard like shared_ptr with IInterface support for reference counting.
159  */
160 template<typename T>
161 class shared_ptr final : public Internals::PtrCountedBase<T> {
162 public:
163     using element_type = BASE_NS::remove_extent_t<T>;
164     using weak_type = weak_ptr<T>;
165 
166     constexpr shared_ptr() noexcept = default;
shared_ptr(nullptr_t)167     constexpr shared_ptr(nullptr_t) noexcept {}
shared_ptr(const shared_ptr & p)168     shared_ptr(const shared_ptr& p) noexcept : Internals::PtrCountedBase<T>(p)
169     {
170         if (this->control_) {
171             this->control_->AddStrongCopy();
172         }
173     }
shared_ptr(shared_ptr && p)174     shared_ptr(shared_ptr&& p) noexcept : Internals::PtrCountedBase<T>(p)
175     {
176         p.InternalReset();
177     }
shared_ptr(T * ptr)178     explicit shared_ptr(T* ptr)
179     {
180         if (ptr) {
181             ConstructBlock(ptr);
182         }
183     }
184 
185     template<typename Deleter>
shared_ptr(T * ptr,Deleter deleter)186     shared_ptr(T* ptr, Deleter deleter)
187     {
188         if (ptr) {
189             ConstructBlock(ptr, BASE_NS::move(deleter));
190         }
191     }
192 
shared_ptr(const weak_type & p)193     explicit shared_ptr(const weak_type& p) noexcept : Internals::PtrCountedBase<T>(p)
194     {
195         if (this->control_) {
196             if (!this->control_->AddStrongLock()) {
197                 this->InternalReset();
198             }
199         }
200     }
201     template<class Y>
shared_ptr(const shared_ptr<Y> & r,T * ptr)202     shared_ptr(const shared_ptr<Y>& r, T* ptr) noexcept : Internals::PtrCountedBase<T>(r.control_)
203     {
204         if (this->control_ && ptr) {
205             this->control_->AddStrongCopy();
206             this->pointer_ = const_cast<deletableType*>(ptr);
207         } else {
208             this->InternalReset();
209         }
210     }
211     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(shared_ptr<U> && p)212     shared_ptr(shared_ptr<U>&& p) noexcept : Internals::PtrCountedBase<T>(p.control_)
213     {
214         if (this->control_) {
215             void* ptr = nullptr;
216             if constexpr (BASE_NS::is_same_v<T, const BASE_NS::remove_const_t<U>> ||
217                           !META_NS::HasGetInterfaceMethod_v<U>) {
218                 ptr = p.get();
219             } else {
220                 // make a proper interface cast here.
221                 if constexpr (BASE_NS::is_const_v<T>) {
222                     ptr = const_cast<void*>(static_cast<const void*>(p->GetInterface(T::UID)));
223                 } else {
224                     ptr = static_cast<void*>(p->GetInterface(T::UID));
225                 }
226             }
227             if (ptr) {
228                 this->pointer_ = static_cast<deletableType*>(ptr);
229                 p.InternalReset();
230             } else {
231                 this->InternalReset();
232                 p.reset();
233             }
234         }
235     }
236     template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(const shared_ptr<U> & p)237     shared_ptr(const shared_ptr<U>& p) noexcept : shared_ptr(shared_ptr<U>(p)) // use the above move constructor
238     {}
239 
240     template<class U, class D, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(unique_ptr<U,D> && p)241     shared_ptr(unique_ptr<U, D>&& p) noexcept
242     {
243         if (p) {
244             ConstructBlock(p.release(), BASE_NS::move(p.get_deleter()));
245         }
246     }
247 
~shared_ptr()248     ~shared_ptr()
249     {
250         if (this->control_) {
251             this->control_->Release();
252         }
253     }
254     T* operator->() const noexcept
255     {
256         return get();
257     }
258     T& operator*() const noexcept
259     {
260         return *get();
261     }
262     explicit operator bool() const
263     {
264         return get();
265     }
266     bool operator==(const shared_ptr& other) const noexcept
267     {
268         return get() == other.get();
269     }
270     bool operator!=(const shared_ptr& other) const noexcept
271     {
272         return !(*this == other);
273     }
274     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset()275     void reset()
276     {
277         if (this->control_) {
278             this->control_->Release();
279             this->InternalReset();
280         }
281     }
282     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset(T * ptr)283     void reset(T* ptr)
284     {
285         if (ptr != this->pointer_) {
286             reset();
287             if (ptr) {
288                 ConstructBlock(ptr);
289             }
290         }
291     }
292     template<typename Deleter>
293     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset(T * ptr,Deleter deleter)294     void reset(T* ptr, Deleter deleter)
295     {
296         if (ptr != this->pointer_) {
297             reset();
298             if (ptr) {
299                 ConstructBlock(ptr, BASE_NS::move(deleter));
300             }
301         }
302     }
303 
304     shared_ptr& operator=(nullptr_t) noexcept
305     {
306         reset();
307         return *this;
308     }
309     shared_ptr& operator=(const shared_ptr& o)
310     {
311         if (this != &o) {
312             reset();
313             this->control_ = o.control_;
314             this->pointer_ = o.pointer_;
315             if (this->control_) {
316                 this->control_->AddStrongCopy();
317             }
318         }
319         return *this;
320     }
321     shared_ptr& operator=(shared_ptr&& o) noexcept
322     {
323         if (this != &o) {
324             reset();
325             this->control_ = o.control_;
326             this->pointer_ = o.pointer_;
327             o.InternalReset();
328         }
329         return *this;
330     }
331     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
swap(shared_ptr & p)332     void swap(shared_ptr& p)
333     {
334         auto tp = p.pointer_;
335         auto tc = p.control_;
336         p.pointer_ = this->pointer_;
337         p.control_ = this->control_;
338         this->pointer_ = tp;
339         this->control_ = tc;
340     }
341     // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
get()342     element_type* get() const noexcept
343     {
344         return this->pointer_;
345     }
346 
347 private:
348     using deletableType = BASE_NS::remove_const_t<T>;
349 
ConstructBlock(T * ptr)350     void ConstructBlock(T* ptr)
351     {
352         static_assert(sizeof(T), "type has to be complete when constructing control block");
353         if constexpr (BASE_NS::is_convertible_v<deletableType*, CORE_NS::IInterface*>) {
354             this->control_ = new Internals::RefCountedObjectStorageBlock(ptr);
355         } else {
356             this->control_ = new Internals::StorageBlock(ptr);
357         }
358         this->pointer_ = ptr;
359     }
360     template<typename Deleter>
ConstructBlock(T * ptr,Deleter deleter)361     void ConstructBlock(T* ptr, Deleter deleter)
362     {
363         this->control_ = new Internals::StorageBlockWithDeleter(ptr, BASE_NS::move(deleter));
364         this->pointer_ = ptr;
365     }
366 
367     template<typename>
368     friend class weak_ptr;
369     template<typename>
370     friend class shared_ptr;
371 };
372 
BASE_END_NAMESPACE()373 BASE_END_NAMESPACE()
374 
375 // NOLINTBEGIN(readability-identifier-naming) to keep std like syntax
376 template<class U, class T>
377 BASE_NS::shared_ptr<U> static_pointer_cast(const BASE_NS::shared_ptr<T>& ptr)
378 {
379     if (ptr) {
380         return BASE_NS::shared_ptr<U>(ptr, static_cast<U*>(ptr.get()));
381     }
382     return {};
383 }
384 
385 template<class U, class T>
interface_pointer_cast(const BASE_NS::shared_ptr<T> & ptr)386 BASE_NS::shared_ptr<U> interface_pointer_cast(const BASE_NS::shared_ptr<T>& ptr)
387 {
388     static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
389     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
390     if (ptr) {
391         if constexpr (BASE_NS::is_same_v<U, T>) {
392             // same type.
393             return ptr;
394         } else {
395             return BASE_NS::shared_ptr<U>(ptr, static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID))));
396         }
397     }
398     return {};
399 }
400 
401 template<class U, class T>
interface_pointer_cast(const BASE_NS::shared_ptr<const T> & ptr)402 BASE_NS::shared_ptr<const U> interface_pointer_cast(const BASE_NS::shared_ptr<const T>& ptr)
403 {
404     static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
405     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
406     if (ptr) {
407         if constexpr (BASE_NS::is_same_v<U, T>) {
408             // same type.
409             return ptr;
410         } else {
411             return BASE_NS::shared_ptr<const U>(
412                 ptr, static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID))));
413         }
414     }
415     return {};
416 }
417 
418 template<class U, class T>
interface_pointer_cast(const BASE_NS::weak_ptr<T> & weak)419 BASE_NS::shared_ptr<U> interface_pointer_cast(const BASE_NS::weak_ptr<T>& weak)
420 {
421     return interface_pointer_cast<U>(weak.lock());
422 }
423 
424 template<class U, class T>
interface_pointer_cast(const BASE_NS::weak_ptr<const T> & weak)425 BASE_NS::shared_ptr<const U> interface_pointer_cast(const BASE_NS::weak_ptr<const T>& weak)
426 {
427     return interface_pointer_cast<const U>(weak.lock());
428 }
429 
430 template<class U, class T>
interface_cast(const BASE_NS::shared_ptr<T> & ptr)431 U* interface_cast(const BASE_NS::shared_ptr<T>& ptr)
432 {
433     static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
434     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
435     if (ptr) {
436         if constexpr (BASE_NS::is_same_v<U, T>) {
437             // same type.
438             return ptr.get();
439         } else {
440             return static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID)));
441         }
442     }
443     return {};
444 }
445 
446 template<class U, class T>
interface_cast(const BASE_NS::shared_ptr<const T> & ptr)447 const U* interface_cast(const BASE_NS::shared_ptr<const T>& ptr)
448 {
449     static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
450     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
451     if (ptr) {
452         if constexpr (BASE_NS::is_same_v<U, T>) {
453             // same type.
454             return ptr.get();
455         } else {
456             return static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID)));
457         }
458     }
459     return {};
460 }
461 
462 template<class U>
interface_cast(CORE_NS::IInterface * ptr)463 U* interface_cast(CORE_NS::IInterface* ptr)
464 {
465     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
466     if (ptr) {
467         return static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID)));
468     }
469     return {};
470 }
471 
472 template<class U>
interface_cast(const CORE_NS::IInterface * ptr)473 const U* interface_cast(const CORE_NS::IInterface* ptr)
474 {
475     static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
476     if (ptr) {
477         return static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID)));
478     }
479     return {};
480 }
481 // NOLINTEND(readability-identifier-naming) to keep std like syntax
482 
483 template<typename T, typename... Args>
CreateShared(Args &&...args)484 BASE_NS::shared_ptr<T> CreateShared(Args&&... args)
485 {
486     return BASE_NS::shared_ptr<T>(new T(BASE_NS::forward<Args>(args)...));
487 }
488 
// Apply META_TYPE to the strong and weak smart pointers of IInterface, in const and non-const forms.
// NOTE(review): META_TYPE is presumably provided by <meta/base/meta_types.h>; confirm there what it generates.
META_TYPE(BASE_NS::shared_ptr<const CORE_NS::IInterface>)
META_TYPE(BASE_NS::shared_ptr<CORE_NS::IInterface>)
META_TYPE(BASE_NS::weak_ptr<const CORE_NS::IInterface>)
META_TYPE(BASE_NS::weak_ptr<CORE_NS::IInterface>)
493 
494 #endif
495