1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "context_pool.h"
17 
#include <fcntl.h>
#include <unistd.h>

#include <cerrno>
#include <mutex>
#include <set>
#include <unordered_map>
#include <vector>

#include <singleton.h>
23 
24 #include "iam_logger.h"
25 #include "iam_para2str.h"
26 #include "iam_check.h"
27 
28 #define LOG_TAG "USER_AUTH_SA"
29 
30 namespace OHOS {
31 namespace UserIam {
32 namespace UserAuth {
33 namespace {
34 const uint32_t MAX_CONTEXT_NUM = 100;
GenerateRand(uint8_t * data,size_t len)35 bool GenerateRand(uint8_t *data, size_t len)
36 {
37     int fd = open("/dev/random", O_RDONLY);
38     if (fd < 0) {
39         IAM_LOGE("open read file fail");
40         return false;
41     }
42     ssize_t readLen = read(fd, data, len);
43     close(fd);
44     if (readLen < 0) {
45         IAM_LOGE("read file failed");
46         return false;
47     }
48     return static_cast<size_t>(readLen) == len;
49 }
50 }
51 class ContextPoolImpl final : public ContextPool, public Singleton<ContextPoolImpl> {
52 public:
53     bool Insert(const std::shared_ptr<Context> &context) override;
54     bool Delete(uint64_t contextId) override;
55     void CancelAll() const override;
56     std::weak_ptr<Context> Select(uint64_t contextId) const override;
57     std::vector<std::weak_ptr<Context>> Select(ContextType contextType) const override;
58     std::shared_ptr<ScheduleNode> SelectScheduleNodeByScheduleId(uint64_t scheduleId) override;
59     bool RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
60     bool DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener) override;
61 
62 private:
63     void CheckPreemptContext(const std::shared_ptr<Context> &context);
64     mutable std::recursive_mutex poolMutex_;
65     std::unordered_map<uint64_t, std::shared_ptr<Context>> contextMap_;
66     std::set<std::shared_ptr<ContextPoolListener>> listenerSet_;
67 };
68 
CheckPreemptContext(const std::shared_ptr<Context> & context)69 void ContextPoolImpl::CheckPreemptContext(const std::shared_ptr<Context> &context)
70 {
71     if (context->GetContextType() != ContextType::CONTEXT_SIMPLE_AUTH) {
72         return;
73     }
74     for (auto iter = contextMap_.begin(); iter != contextMap_.end(); iter++) {
75         if (iter->second == nullptr) {
76             IAM_LOGE("context is nullptr");
77             break;
78         }
79         if (iter->second->GetCallerName() == context->GetCallerName() &&
80             iter->second->GetAuthType() == context->GetAuthType() &&
81             iter->second->GetUserId() == context->GetUserId()) {
82             IAM_LOGE("contextId:%{public}hx is preempted, newContextId:%{public}hx, mapSize:%{public}zu,"
83                 "callerName:%{public}s, userId:%{public}d, authType:%{public}d", static_cast<uint16_t>(iter->first),
84                 static_cast<uint16_t>(context->GetContextId()), contextMap_.size(), context->GetCallerName().c_str(),
85                 context->GetUserId(), context->GetAuthType());
86             iter->second->Stop();
87             break;
88         }
89     }
90 }
91 
Insert(const std::shared_ptr<Context> & context)92 bool ContextPoolImpl::Insert(const std::shared_ptr<Context> &context)
93 {
94     if (context == nullptr) {
95         IAM_LOGE("context is nullptr");
96         return false;
97     }
98     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
99     if (contextMap_.size() >= MAX_CONTEXT_NUM) {
100         IAM_LOGE("context pool is full");
101         return false;
102     }
103     CheckPreemptContext(context);
104     uint64_t contextId = context->GetContextId();
105     auto result = contextMap_.try_emplace(contextId, context);
106     if (!result.second) {
107         return false;
108     }
109     for (const auto &listener : listenerSet_) {
110         if (listener != nullptr) {
111             listener->OnContextPoolInsert(context);
112         }
113     }
114     return true;
115 }
116 
Delete(uint64_t contextId)117 bool ContextPoolImpl::Delete(uint64_t contextId)
118 {
119     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
120     auto iter = contextMap_.find(contextId);
121     if (iter == contextMap_.end()) {
122         IAM_LOGE("context not found");
123         return false;
124     }
125     auto tempContext = iter->second;
126     contextMap_.erase(iter);
127     for (const auto &listener : listenerSet_) {
128         if (listener != nullptr) {
129             listener->OnContextPoolDelete(tempContext);
130         }
131     }
132     return true;
133 }
134 
CancelAll() const135 void ContextPoolImpl::CancelAll() const
136 {
137     IAM_LOGI("start");
138     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
139     for (const auto &context : contextMap_) {
140         if (context.second == nullptr) {
141             continue;
142         }
143         IAM_LOGI("cancel context %{public}s", GET_MASKED_STRING(context.second->GetContextId()).c_str());
144         if (!context.second->Stop()) {
145             IAM_LOGE("cancel context %{public}s fail", GET_MASKED_STRING(context.second->GetContextId()).c_str());
146         }
147     }
148 }
149 
Select(uint64_t contextId) const150 std::weak_ptr<Context> ContextPoolImpl::Select(uint64_t contextId) const
151 {
152     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
153     std::weak_ptr<Context> result;
154     auto iter = contextMap_.find(contextId);
155     if (iter != contextMap_.end()) {
156         result = iter->second;
157     }
158     return result;
159 }
160 
Select(ContextType contextType) const161 std::vector<std::weak_ptr<Context>> ContextPoolImpl::Select(ContextType contextType) const
162 {
163     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
164     std::vector<std::weak_ptr<Context>> result;
165     for (const auto &context : contextMap_) {
166         if (context.second == nullptr) {
167             continue;
168         }
169         if (context.second->GetContextType() == contextType) {
170             result.emplace_back(context.second);
171         }
172     }
173     return result;
174 }
175 
SelectScheduleNodeByScheduleId(uint64_t scheduleId)176 std::shared_ptr<ScheduleNode> ContextPoolImpl::SelectScheduleNodeByScheduleId(uint64_t scheduleId)
177 {
178     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
179     for (const auto &context : contextMap_) {
180         if (context.second == nullptr) {
181             continue;
182         }
183         auto node = context.second->GetScheduleNode(scheduleId);
184         if (node != nullptr) {
185             return node;
186         }
187     }
188 
189     IAM_LOGE("not found");
190     return nullptr;
191 }
192 
RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)193 bool ContextPoolImpl::RegisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
194 {
195     if (listener == nullptr) {
196         IAM_LOGE("listener is nullptr");
197         return false;
198     }
199     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
200     listenerSet_.insert(listener);
201     return true;
202 }
203 
DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> & listener)204 bool ContextPoolImpl::DeregisterContextPoolListener(const std::shared_ptr<ContextPoolListener> &listener)
205 {
206     std::lock_guard<std::recursive_mutex> lock(poolMutex_);
207     return listenerSet_.erase(listener) == 1;
208 }
209 
Instance()210 ContextPool &ContextPool::Instance()
211 {
212     return ContextPoolImpl::GetInstance();
213 }
214 
GetNewContextId()215 uint64_t ContextPool::GetNewContextId()
216 {
217     static constexpr uint32_t MAX_TRY_TIMES = 10;
218     static std::mutex mutex;
219     std::lock_guard<std::mutex> lock(mutex);
220     uint64_t contextId = 0;
221     unsigned char *contextIdPtr = static_cast<unsigned char *>(static_cast<void *>(&contextId));
222     for (uint32_t i = 0; i < MAX_TRY_TIMES; i++) {
223         bool genRandRet = GenerateRand(contextIdPtr, sizeof(uint64_t));
224         if (!genRandRet) {
225             IAM_LOGE("generate rand fail");
226             return 0;
227         }
228         if (contextId == 0 || contextId == REUSE_AUTH_RESULT_CONTEXT_ID ||
229             ContextPool::Instance().Select(contextId).lock() != nullptr) {
230             IAM_LOGE("invalid or duplicate context id");
231             continue;
232         }
233         break;
234     }
235     return contextId;
236 }
237 } // namespace UserAuth
238 } // namespace UserIam
239 } // namespace OHOS
240