1 /*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "context_pool.h"

#include <cerrno>
#include <fcntl.h>
#include <mutex>
#include <set>
#include <singleton.h>
#include <unistd.h>
#include <unordered_map>
#include <vector>

#include "iam_logger.h"
#include "iam_para2str.h"
#include "iam_check.h"
27
28 #define LOG_TAG "USER_AUTH_SA"
29
30 namespace OHOS {
31 namespace UserIam {
32 namespace UserAuth {
33 namespace {
// Upper bound on the number of contexts the pool holds at once; Insert() rejects new entries at this size.
constexpr uint32_t MAX_CONTEXT_NUM = 100;
GenerateRand(uint8_t * data,size_t len)35 bool GenerateRand(uint8_t *data, size_t len)
36 {
37 int fd = open("/dev/random", O_RDONLY);
38 if (fd < 0) {
39 IAM_LOGE("open read file fail");
40 return false;
41 }
42 ssize_t readLen = read(fd, data, len);
43 close(fd);
44 if (readLen < 0) {
45 IAM_LOGE("read file failed");
46 return false;
47 }
48 return static_cast<size_t>(readLen) == len;
49 }
50 }
51 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
52 public:
53 bool Insert(const std::shared_ptr<Context> &context) override;
54 bool Delete(uint64_t contextId) override;
55 void CancelAll() const override;
56 std::weak_ptr<Context> Select(uint64_t contextId) const override;
57 std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
58 std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
59 bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
60 bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
61
62 private:
63 void CheckPreemptContext(const std::shared_ptr<Context> &context);
64 mutable std::recursive_mutex poolMutex_;
65 std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
66 std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
67 };
68
CheckPreemptContext(const std::shared_ptr<Context> & context)69 void ContextPoolImpl::CheckPreemptContext(const std::shared_ptr<Context> &context)
70 {
71 if (context->GetContextType() != ContextType::CONTEXT_SIMPLE_AUTH) {
72 return;
73 }
74 for (auto iter = contextMap_.begin(); iter != contextMap_.end(); iter++) {
75 if (iter->second == nullptr) {
76 IAM_LOGE("context is nullptr");
77 break;
78 }
79 if (iter->second->GetCallerName() == context->GetCallerName() &&
80 iter->second->GetAuthType() == context->GetAuthType() &&
81 iter->second->GetUserId() == context->GetUserId()) {
82 IAM_LOGE("contextId:%{public}hx is preempted, newContextId:%{public}hx, mapSize:%{public}zu,"
83 "callerName:%{public}s, userId:%{public}d, authType:%{public}d", static_cast<uint16_t>(iter->first),
84 static_cast<uint16_t>(context->GetContextId()), contextMap_.size(), context->GetCallerName().c_str(),
85 context->GetUserId(), context->GetAuthType());
86 iter->second->Stop();
87 break;
88 }
89 }
90 }
91
Insert(const std::shared_ptr<Context> & context)92 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
93 {
94 if (context == nullptr) {
95 IAM_LOGE("context is nullptr");
96 return false;
97 }
98 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
99 if (contextMap_.size() >= MAX_CONTEXT_NUM) {
100 IAM_LOGE("context pool is full");
101 return false;
102 }
103 CheckPreemptContext(context);
104 uint64_t contextId = context->GetContextId();
105 auto result = contextMap_.try_emplace(contextId, context);
106 if (!result.second) {
107 return false;
108 }
109 for (const auto &listener : listenerSet_) {
110 if (listener != nullptr) {
111 listener->OnContextPoolInsert(context);
112 }
113 }
114 return true;
115 }
116
Delete(uint64_t contextId)117 bool ContextPoolImpl::Delete(uint64_t contextId)
118 {
119 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
120 auto iter = contextMap_.find(contextId);
121 if (iter == contextMap_.end()) {
122 IAM_LOGE("context not found");
123 return false;
124 }
125 auto tempContext = iter->second;
126 contextMap_.erase(iter);
127 for (const auto &listener : listenerSet_) {
128 if (listener != nullptr) {
129 listener->OnContextPoolDelete(tempContext);
130 }
131 }
132 return true;
133 }
134
CancelAll() const135 void ContextPoolImpl::CancelAll() const
136 {
137 IAM_LOGI("start");
138 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
139 for (const auto &context : contextMap_) {
140 if (context.second == nullptr) {
141 continue;
142 }
143 IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
144 if (!context.second->Stop()) {
145 IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
146 }
147 }
148 }
149
Select(uint64_t contextId) const150 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
151 {
152 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
153 std::weak_ptr<Context> result;
154 auto iter = contextMap_.find(contextId);
155 if (iter != contextMap_.end()) {
156 result = iter->second;
157 }
158 return result;
159 }
160
Select(ContextType contextType) const161 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
162 {
163 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
164 std::vector<std::weak_ptr<Context>> result;
165 for (const auto &context : contextMap_) {
166 if (context.second == nullptr) {
167 continue;
168 }
169 if (context.second->GetContextType() == contextType) {
170 result.emplace_back(context.second);
171 }
172 }
173 return result;
174 }
175
SelectScheduleNodeByScheduleId(uint64_t scheduleId)176 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
177 {
178 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
179 for (const auto &context : contextMap_) {
180 if (context.second == nullptr) {
181 continue;
182 }
183 auto node = context.second->GetScheduleNode(scheduleId);
184 if (node != nullptr) {
185 return node;
186 }
187 }
188
189 IAM_LOGE("not found");
190 return nullptr;
191 }
192
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)193 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
194 {
195 if (listener == nullptr) {
196 IAM_LOGE("listener is nullptr");
197 return false;
198 }
199 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
200 listenerSet_.insert(listener);
201 return true;
202 }
203
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)204 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
205 {
206 std::lock_guard<std::recursive_mutex> lock(poolMutex_);
207 return listenerSet_.erase(listener) == 1;
208 }
209
Instance()210 ContextPool &ContextPool::Instance()
211 {
212 return ContextPoolImpl::GetInstance();
213 }
214
GetNewContextId()215 uint64_t ContextPool::GetNewContextId()
216 {
217 static constexpr uint32_t MAX_TRY_TIMES = 10;
218 static std::mutex mutex;
219 std::lock_guard<std::mutex> lock(mutex);
220 uint64_t contextId = 0;
221 unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
222 for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
223 bool genRandRet = GenerateRand(contextIdPtr, sizeof(uint64_t));
224 if (!genRandRet) {
225 IAM_LOGE("generate rand fail");
226 return 0;
227 }
228 if (contextId == 0 || contextId == REUSE_AUTH_RESULT_CONTEXT_ID ||
229 ContextPool::Instance().Select(contextId).lock() != nullptr) {
230 IAM_LOGE("invalid or duplicate context id");
231 continue;
232 }
233 break;
234 }
235 return contextId;
236 }
237 } // namespace UserAuth
238 } // namespace UserIam
239 } // namespace OHOS
240