1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "class_registry.h"
17 
18 #include <base/util/uid_util.h>
19 
20 #include "meta/base/interface_utils.h"
21 
META_BEGIN_NAMESPACE()22 META_BEGIN_NAMESPACE()
23 
24 void ClassRegistry::Clear()
25 {
26     std::unique_lock lock { mutex_ };
27     objectFactories_.clear();
28 }
29 
Unregister(const IObjectFactory::Ptr & fac)30 bool ClassRegistry::Unregister(const IObjectFactory::Ptr& fac)
31 {
32     if (!fac) {
33         CORE_LOG_E("ClassRegistry: Cannot unregister a null object factory");
34         return false;
35     }
36     size_t erased = 0;
37     {
38         std::unique_lock lock { mutex_ };
39         erased = objectFactories_.erase(fac->GetClassInfo());
40     }
41     if (erased) {
42         onUnregistered_->Invoke({ fac });
43         return true;
44     }
45     return false;
46 }
47 
Register(const IObjectFactory::Ptr & fac)48 bool ClassRegistry::Register(const IObjectFactory::Ptr& fac)
49 {
50     if (!fac) {
51         CORE_LOG_E("ClassRegistry: Cannot register a null object factory");
52         return false;
53     }
54     {
55         std::unique_lock lock { mutex_ };
56         auto& info = fac->GetClassInfo();
57         auto& i = objectFactories_[info];
58         if (i) {
59             CORE_LOG_W("ClassRegistry: Cannot register a class that was already registered [name=%s, uid=%s]",
60                 info.Name().data(), info.Id().ToString().c_str());
61             return false;
62         }
63         i = fac;
64     }
65     onRegistered_->Invoke({ fac });
66     return true;
67 }
68 
GetObjectFactory(const BASE_NS::Uid & uid) const69 IObjectFactory::ConstPtr ClassRegistry::GetObjectFactory(const BASE_NS::Uid& uid) const
70 {
71     std::shared_lock lock { mutex_ };
72     auto it = objectFactories_.find(uid);
73     return it != objectFactories_.end() ? it->second : nullptr;
74 }
75 
GetClassName(BASE_NS::Uid uid) const76 BASE_NS::string ClassRegistry::GetClassName(BASE_NS::Uid uid) const
77 {
78     std::shared_lock lock { mutex_ };
79     auto it = objectFactories_.find(uid);
80     return it != objectFactories_.end() ? BASE_NS::string(it->second->GetClassInfo().Name())
81                                         : BASE_NS::string("Unknown class id [") + BASE_NS::to_string(uid) + "]";
82 }
83 
GetAllTypes(ObjectCategoryBits category,bool strict,bool excludeDeprecated) const84 BASE_NS::vector<IClassInfo::ConstPtr> ClassRegistry::GetAllTypes(
85     ObjectCategoryBits category, bool strict, bool excludeDeprecated) const
86 {
87     std::shared_lock lock { mutex_ };
88     BASE_NS::vector<IClassInfo::ConstPtr> infos;
89     for (auto&& v : objectFactories_) {
90         const auto& factory = v.second;
91         if (excludeDeprecated && (factory->GetClassInfo().category & ObjectCategoryBits::DEPRECATED)) {
92             // Omit DEPRECATED classes if excludeDeprecated flag is true
93             continue;
94         }
95         if (CheckCategoryBits(factory->GetClassInfo().category, category, strict)) {
96             infos.emplace_back(factory);
97         }
98     }
99     return infos;
100 }
101 
/**
 * Collects the class infos of all registered factories whose interfaces match.
 * @param interfaceUids Interface uids to match, checked via CheckInterfaces.
 * @param strict Passed through to CheckInterfaces.
 * @param excludeDeprecated When true, classes carrying ObjectCategoryBits::DEPRECATED are skipped.
 * @return Matching class infos, in the registry's iteration order.
 *         Classes carrying ObjectCategoryBits::INTERNAL are never included.
 */
BASE_NS::vector<IClassInfo::ConstPtr> ClassRegistry::GetAllTypes(
    const BASE_NS::vector<BASE_NS::Uid>& interfaceUids, bool strict, bool excludeDeprecated) const
{
    std::shared_lock guard { mutex_ };
    BASE_NS::vector<IClassInfo::ConstPtr> matches;
    for (auto&& entry : objectFactories_) {
        const auto& fac = entry.second;
        const auto& info = fac->GetClassInfo();
        // Internal classes are always hidden; deprecated ones only when the caller asks for it.
        if ((info.category & ObjectCategoryBits::INTERNAL) ||
            (excludeDeprecated && (info.category & ObjectCategoryBits::DEPRECATED))) {
            continue;
        }
        if (CheckInterfaces(fac->GetClassInterfaces(), interfaceUids, strict)) {
            matches.push_back(fac);
        }
    }
    return matches;
}
123 
124 META_END_NAMESPACE()
125