1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "access_token_db.h"

#include <algorithm>
#include <atomic>
#include <cinttypes>
#include <mutex>
#include <utility>

#include "accesstoken_log.h"
#include "access_token_error.h"
#include "access_token_open_callback.h"
#include "rdb_helper.h"
#include "time_util.h"
#include "token_field_const.h"
28 
29 namespace OHOS {
30 namespace Security {
31 namespace AccessToken {
32 namespace {
33 static constexpr OHOS::HiviewDFX::HiLogLabel LABEL = {LOG_CORE, SECURITY_DOMAIN_ACCESSTOKEN, "AccessTokenDb"};
34 
35 constexpr const char* DATABASE_NAME = "access_token.db";
36 constexpr const char* ACCESSTOKEN_SERVICE_NAME = "accesstoken_service";
37 std::recursive_mutex g_instanceMutex;
38 }
39 
GetInstance()40 AccessTokenDb& AccessTokenDb::GetInstance()
41 {
42     static AccessTokenDb* instance = nullptr;
43     if (instance == nullptr) {
44         std::lock_guard<std::recursive_mutex> lock(g_instanceMutex);
45         if (instance == nullptr) {
46             instance = new AccessTokenDb();
47         }
48     }
49     return *instance;
50 }
51 
AccessTokenDb()52 AccessTokenDb::AccessTokenDb()
53 {
54     std::string dbPath = std::string(DATABASE_PATH) + std::string(DATABASE_NAME);
55     NativeRdb::RdbStoreConfig config(dbPath);
56     config.SetSecurityLevel(NativeRdb::SecurityLevel::S3);
57     config.SetAllowRebuild(true);
58     config.SetHaMode(NativeRdb::HAMode::MAIN_REPLICA); // Real-time dual-write backup database
59     config.SetServiceName(std::string(ACCESSTOKEN_SERVICE_NAME));
60     AccessTokenOpenCallback callback;
61     int32_t res = NativeRdb::E_OK;
62     // pragma user_version will done by rdb, they store path and db_ as pair in RdbStoreManager
63     db_ = NativeRdb::RdbHelper::GetRdbStore(config, DATABASE_VERSION_4, callback, res);
64     if ((res != NativeRdb::E_OK) || (db_ == nullptr)) {
65         ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to init rdb, res is %{public}d.", res);
66         return;
67     }
68 }
69 
RestoreAndInsertIfCorrupt(const int32_t resultCode,int64_t & outInsertNum,const std::string & tableName,const std::vector<NativeRdb::ValuesBucket> & buckets)70 int32_t AccessTokenDb::RestoreAndInsertIfCorrupt(const int32_t resultCode, int64_t& outInsertNum,
71     const std::string& tableName, const std::vector<NativeRdb::ValuesBucket>& buckets)
72 {
73     if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
74         return resultCode;
75     }
76 
77     ACCESSTOKEN_LOG_WARN(LABEL, "Detech database corrupt, restore from backup!");
78     int32_t res = db_->Restore("");
79     if (res != NativeRdb::E_OK) {
80         ACCESSTOKEN_LOG_ERROR(LABEL, "Db restore failed, res is %{public}d.", res);
81         return res;
82     }
83     ACCESSTOKEN_LOG_INFO(LABEL, "Database restore success, try insert again!");
84 
85     res = db_->BatchInsert(outInsertNum, tableName, buckets);
86     if (res != NativeRdb::E_OK) {
87         ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to batch insert into table %{public}s again, res is %{public}d.",
88             tableName.c_str(), res);
89         return res;
90     }
91 
92     return 0;
93 }
94 
Add(const AtmDataType type,const std::vector<GenericValues> & values)95 int32_t AccessTokenDb::Add(const AtmDataType type, const std::vector<GenericValues>& values)
96 {
97     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
98     ACCESSTOKEN_LOG_INFO(LABEL, "Add type is %{public}u.", type);
99 
100     std::string tableName;
101     AccessTokenDbUtil::GetTableNameByType(type, tableName);
102     if (tableName.empty()) {
103         return AccessTokenError::ERR_PARAM_INVALID;
104     }
105 
106     size_t addSize = values.size();
107     if (addSize == 0) {
108         ACCESSTOKEN_LOG_INFO(LABEL, "Insert values is empty.");
109         return 0;
110     }
111 
112     std::vector<NativeRdb::ValuesBucket> buckets;
113     AccessTokenDbUtil::ToRdbValueBuckets(values, buckets);
114 
115     int64_t outInsertNum = 0;
116     {
117         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
118         if (db_ == nullptr) {
119             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
120         }
121 
122         int32_t res = db_->BatchInsert(outInsertNum, tableName, buckets);
123         if (res != NativeRdb::E_OK) {
124             ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to batch insert into table %{public}s, res is %{public}d.",
125                 tableName.c_str(), res);
126             int32_t result = RestoreAndInsertIfCorrupt(res, outInsertNum, tableName, buckets);
127             if (result != NativeRdb::E_OK) {
128                 return result;
129             }
130         }
131     }
132 
133     if (outInsertNum <= 0) { // this is rdb bug, adapt it
134         ACCESSTOKEN_LOG_ERROR(LABEL, "Insert count %{public}" PRId64 " abnormal.", outInsertNum);
135         return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
136     }
137 
138     int64_t endTime = TimeUtil::GetCurrentTimestamp();
139     ACCESSTOKEN_LOG_INFO(LABEL, "Add call cast %{public}" PRId64 ", batch insert %{public}" PRId64
140         " records to table %{public}s.", endTime - beginTime, outInsertNum, tableName.c_str());
141 
142     return 0;
143 }
144 
/**
 * Recovery path for a failed Delete.
 *
 * When resultCode is E_SQLITE_CORRUPT, the database is restored from its backup and the delete
 * is retried once. The retry overwrites deletedRows.
 *
 * @param resultCode  error returned by the failed Delete
 * @param deletedRows receives the deleted-row count of the retry
 * @param predicates  delete conditions (also carry the table name used in logs)
 * @return 0 after a successful restore and retry. Otherwise the Restore/Delete error, or
 *         resultCode unchanged when it is not a corruption error.
 */
int32_t AccessTokenDb::RestoreAndDeleteIfCorrupt(const int32_t resultCode, int32_t& deletedRows,
    const NativeRdb::RdbPredicates& predicates)
{
    // Only corruption is recoverable here; anything else goes straight back to the caller.
    if (resultCode == NativeRdb::E_SQLITE_CORRUPT) {
        ACCESSTOKEN_LOG_WARN(LABEL, "Detech database corrupt, restore from backup!");
        int32_t ret = db_->Restore("");
        if (ret != NativeRdb::E_OK) {
            ACCESSTOKEN_LOG_ERROR(LABEL, "Db restore failed, res is %{public}d.", ret);
            return ret;
        }
        ACCESSTOKEN_LOG_INFO(LABEL, "Database restore success, try delete again!");

        ret = db_->Delete(deletedRows, predicates);
        if (ret != NativeRdb::E_OK) {
            ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to delete record from table %{public}s again, res is %{public}d.",
                predicates.GetTableName().c_str(), ret);
            return ret;
        }
        return 0;
    }
    return resultCode;
}
169 
Remove(const AtmDataType type,const GenericValues & conditionValue)170 int32_t AccessTokenDb::Remove(const AtmDataType type, const GenericValues& conditionValue)
171 {
172     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
173     ACCESSTOKEN_LOG_INFO(LABEL, "Remove type is %{public}u.", type);
174 
175     std::string tableName;
176     AccessTokenDbUtil::GetTableNameByType(type, tableName);
177     if (tableName.empty()) {
178         return AccessTokenError::ERR_PARAM_INVALID;
179     }
180 
181     NativeRdb::RdbPredicates predicates(tableName);
182     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
183 
184     int32_t deletedRows = 0;
185     {
186         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
187         if (db_ == nullptr) {
188             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
189         }
190 
191         int32_t res = db_->Delete(deletedRows, predicates);
192         if (res != NativeRdb::E_OK) {
193             ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to delete record from table %{public}s, res is %{public}d.",
194                 tableName.c_str(), res);
195             int32_t result = RestoreAndDeleteIfCorrupt(res, deletedRows, predicates);
196             if (result != NativeRdb::E_OK) {
197                 return result;
198             }
199         }
200     }
201 
202     int64_t endTime = TimeUtil::GetCurrentTimestamp();
203     ACCESSTOKEN_LOG_INFO(LABEL, "Remove call cast %{public}" PRId64
204         ", delete %{public}d records from table %{public}s.", endTime - beginTime, deletedRows, tableName.c_str());
205 
206     return 0;
207 }
208 
/**
 * Recovery path for a failed Update.
 *
 * When resultCode is E_SQLITE_CORRUPT, the database is restored from its backup and the update
 * is retried once. The retry overwrites changedRows.
 *
 * @param resultCode  error returned by the failed Update
 * @param changedRows receives the changed-row count of the retry
 * @param bucket      new column values
 * @param predicates  update conditions (also carry the table name used in logs)
 * @return 0 after a successful restore and retry. Otherwise the Restore/Update error, or
 *         resultCode unchanged when it is not a corruption error.
 */
int32_t AccessTokenDb::RestoreAndUpdateIfCorrupt(const int32_t resultCode, int32_t& changedRows,
    const NativeRdb::ValuesBucket& bucket, const NativeRdb::RdbPredicates& predicates)
{
    if (resultCode != NativeRdb::E_SQLITE_CORRUPT) {
        return resultCode; // not recoverable by a restore
    }

    ACCESSTOKEN_LOG_WARN(LABEL, "Detech database corrupt, restore from backup!");
    const int32_t restoreRes = db_->Restore("");
    if (restoreRes != NativeRdb::E_OK) {
        ACCESSTOKEN_LOG_ERROR(LABEL, "Db restore failed, res is %{public}d.", restoreRes);
        return restoreRes;
    }
    ACCESSTOKEN_LOG_INFO(LABEL, "Database restore success, try update again!");

    const int32_t retryRes = db_->Update(changedRows, bucket, predicates);
    if (retryRes == NativeRdb::E_OK) {
        return 0;
    }
    ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to update record from table %{public}s again, res is %{public}d.",
        predicates.GetTableName().c_str(), retryRes);
    return retryRes;
}
233 
Modify(const AtmDataType type,const GenericValues & modifyValue,const GenericValues & conditionValue)234 int32_t AccessTokenDb::Modify(const AtmDataType type, const GenericValues& modifyValue,
235     const GenericValues& conditionValue)
236 {
237     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
238     ACCESSTOKEN_LOG_INFO(LABEL, "Modify type is %{public}u.", type);
239 
240     std::string tableName;
241     AccessTokenDbUtil::GetTableNameByType(type, tableName);
242     if (tableName.empty()) {
243         return AccessTokenError::ERR_PARAM_INVALID;
244     }
245 
246     NativeRdb::ValuesBucket bucket;
247 
248     AccessTokenDbUtil::ToRdbValueBucket(modifyValue, bucket);
249     if (bucket.IsEmpty()) {
250         return AccessTokenError::ERR_PARAM_INVALID;
251     }
252 
253     NativeRdb::RdbPredicates predicates(tableName);
254     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
255 
256     int32_t changedRows = 0;
257     {
258         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
259         if (db_ == nullptr) {
260             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
261         }
262 
263         int32_t res = db_->Update(changedRows, bucket, predicates);
264         if (res != NativeRdb::E_OK) {
265             ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to update record from table %{public}s, res is %{public}d.",
266                 tableName.c_str(), res);
267             int32_t result = RestoreAndUpdateIfCorrupt(res, changedRows, bucket, predicates);
268             if (result != NativeRdb::E_OK) {
269                 return result;
270             }
271         }
272     }
273 
274     int64_t endTime = TimeUtil::GetCurrentTimestamp();
275     ACCESSTOKEN_LOG_INFO(LABEL, "Modify call cast %{public}" PRId64
276         ", update %{public}d records from table %{public}s.", endTime - beginTime, changedRows, tableName.c_str());
277 
278     return 0;
279 }
280 
RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates & predicates,const std::vector<std::string> & columns,std::shared_ptr<NativeRdb::AbsSharedResultSet> & queryResultSet)281 int32_t AccessTokenDb::RestoreAndQueryIfCorrupt(const NativeRdb::RdbPredicates& predicates,
282     const std::vector<std::string>& columns, std::shared_ptr<NativeRdb::AbsSharedResultSet>& queryResultSet)
283 {
284     int32_t count = 0;
285     int32_t res = queryResultSet->GetRowCount(count);
286     if (res != NativeRdb::E_OK) {
287         if (res == NativeRdb::E_SQLITE_CORRUPT) {
288             queryResultSet->Close();
289             queryResultSet = nullptr;
290 
291             ACCESSTOKEN_LOG_WARN(LABEL, "Detech database corrupt, restore from backup!");
292             int32_t res = db_->Restore("");
293             if (res != NativeRdb::E_OK) {
294                 ACCESSTOKEN_LOG_ERROR(LABEL, "Db restore failed, res is %{public}d.", res);
295                 return res;
296             }
297             ACCESSTOKEN_LOG_INFO(LABEL, "Database restore success, try query again!");
298 
299             queryResultSet = db_->Query(predicates, columns);
300             if (queryResultSet == nullptr) {
301                 ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to find records from table %{public}s again.",
302                     predicates.GetTableName().c_str());
303                 return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
304             }
305         } else {
306             ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to get result count.");
307             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
308         }
309     }
310 
311     return 0;
312 }
313 
Find(AtmDataType type,const GenericValues & conditionValue,std::vector<GenericValues> & results)314 int32_t AccessTokenDb::Find(AtmDataType type, const GenericValues& conditionValue,
315     std::vector<GenericValues>& results)
316 {
317     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
318     ACCESSTOKEN_LOG_INFO(LABEL, "Find type is %{public}u.", type);
319 
320     std::string tableName;
321     AccessTokenDbUtil::GetTableNameByType(type, tableName);
322     if (tableName.empty()) {
323         return AccessTokenError::ERR_PARAM_INVALID;
324     }
325 
326     NativeRdb::RdbPredicates predicates(tableName);
327     AccessTokenDbUtil::ToRdbPredicates(conditionValue, predicates);
328 
329     std::vector<std::string> columns; // empty columns means query all columns
330     int count = 0;
331     {
332         OHOS::Utils::UniqueReadGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
333         if (db_ == nullptr) {
334             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
335         }
336 
337         auto queryResultSet = db_->Query(predicates, columns);
338         if (queryResultSet == nullptr) {
339             ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to find records from table %{public}s.",
340                 tableName.c_str());
341             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
342         }
343 
344         int32_t res = RestoreAndQueryIfCorrupt(predicates, columns, queryResultSet);
345         if (res != 0) {
346             return res;
347         }
348 
349         while (queryResultSet->GoToNextRow() == NativeRdb::E_OK) {
350             GenericValues value;
351             AccessTokenDbUtil::ResultToGenericValues(queryResultSet, value);
352             if (value.GetAllKeys().empty()) {
353                 continue;
354             }
355 
356             results.emplace_back(value);
357             count++;
358         }
359     }
360 
361     int64_t endTime = TimeUtil::GetCurrentTimestamp();
362     ACCESSTOKEN_LOG_INFO(LABEL, "Find call cast %{public}" PRId64
363         ", query %{public}d records from table %{public}s.", endTime - beginTime, count, tableName.c_str());
364 
365     return 0;
366 }
367 
DeleteAndAddSingleTable(const GenericValues delCondition,const std::string & tableName,const std::vector<GenericValues> & addValues)368 int32_t AccessTokenDb::DeleteAndAddSingleTable(const GenericValues delCondition, const std::string& tableName,
369     const std::vector<GenericValues>& addValues)
370 {
371     NativeRdb::RdbPredicates predicates(tableName);
372     AccessTokenDbUtil::ToRdbPredicates(delCondition, predicates); // fill predicates with delCondition
373     int32_t deletedRows = 0;
374     int32_t res = db_->Delete(deletedRows, predicates);
375     if (res != NativeRdb::E_OK) {
376         ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to delete record from table %{public}s, res is %{public}d.",
377             tableName.c_str(), res);
378         int32_t result = RestoreAndDeleteIfCorrupt(res, deletedRows, predicates);
379         if (result != NativeRdb::E_OK) {
380             return result;
381         }
382     }
383     ACCESSTOKEN_LOG_INFO(LABEL, "Delete %{public}d record from table %{public}s", deletedRows, tableName.c_str());
384 
385     // if nothing to insert, no need to call BatchInsert
386     if (addValues.empty()) {
387         return 0;
388     }
389 
390     std::vector<NativeRdb::ValuesBucket> buckets;
391     AccessTokenDbUtil::ToRdbValueBuckets(addValues, buckets); // fill buckets with addValues
392     int64_t outInsertNum = 0;
393     res = db_->BatchInsert(outInsertNum, tableName, buckets);
394     if (res != NativeRdb::E_OK) {
395         ACCESSTOKEN_LOG_ERROR(LABEL, "Failed to batch insert into table %{public}s, res is %{public}d.",
396             tableName.c_str(), res);
397         int32_t result = RestoreAndInsertIfCorrupt(res, outInsertNum, tableName, buckets);
398         if (result != NativeRdb::E_OK) {
399             return result;
400         }
401     }
402     if (outInsertNum <= 0) { // rdb bug, adapt it
403         ACCESSTOKEN_LOG_ERROR(LABEL, "Insert count %{public}" PRId64 " abnormal.", outInsertNum);
404         return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
405     }
406     ACCESSTOKEN_LOG_INFO(LABEL, "Batch insert %{public}" PRId64 " records to table %{public}s.", outInsertNum,
407         tableName.c_str());
408 
409     return 0;
410 }
411 
DeleteAndAddRecord(AccessTokenID tokenId,const std::vector<GenericValues> & hapInfoValues,const std::vector<GenericValues> & permDefValues,const std::vector<GenericValues> & permStateValues)412 int32_t AccessTokenDb::DeleteAndAddRecord(AccessTokenID tokenId, const std::vector<GenericValues>& hapInfoValues,
413     const std::vector<GenericValues>& permDefValues, const std::vector<GenericValues>& permStateValues)
414 {
415     GenericValues conditionValue;
416     conditionValue.Put(TokenFiledConst::FIELD_TOKEN_ID, static_cast<int32_t>(tokenId));
417 
418     std::string hapTableName;
419     AccessTokenDbUtil::GetTableNameByType(AtmDataType::ACCESSTOKEN_HAP_INFO, hapTableName);
420     int32_t res = DeleteAndAddSingleTable(conditionValue, hapTableName, hapInfoValues);
421     if (res != NativeRdb::E_OK) {
422         return res;
423     }
424 
425     std::string defTableName;
426     AccessTokenDbUtil::GetTableNameByType(AtmDataType::ACCESSTOKEN_PERMISSION_DEF, defTableName);
427     res = DeleteAndAddSingleTable(conditionValue, defTableName, permDefValues);
428     if (res != NativeRdb::E_OK) {
429         return res;
430     }
431 
432     std::string stateTableName;
433     AccessTokenDbUtil::GetTableNameByType(AtmDataType::ACCESSTOKEN_PERMISSION_STATE, stateTableName);
434     return DeleteAndAddSingleTable(conditionValue, stateTableName, permStateValues);
435 }
436 
DeleteAndInsertHap(AccessTokenID tokenId,const std::vector<GenericValues> & hapInfoValues,const std::vector<GenericValues> & permDefValues,const std::vector<GenericValues> & permStateValues)437 int32_t AccessTokenDb::DeleteAndInsertHap(AccessTokenID tokenId, const std::vector<GenericValues>& hapInfoValues,
438     const std::vector<GenericValues>& permDefValues, const std::vector<GenericValues>& permStateValues)
439 {
440     int64_t beginTime = TimeUtil::GetCurrentTimestamp();
441 
442     {
443         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> lock(this->rwLock_);
444         if (db_ == nullptr) {
445             return AccessTokenError::ERR_DATABASE_OPERATE_FAILED;
446         }
447 
448         db_->BeginTransaction();
449 
450         int32_t res = DeleteAndAddRecord(tokenId, hapInfoValues, permDefValues, permStateValues);
451         if (res != NativeRdb::E_OK) {
452             db_->RollBack();
453             return res;
454         }
455 
456         db_->Commit();
457     }
458 
459     int64_t endTime = TimeUtil::GetCurrentTimestamp();
460     ACCESSTOKEN_LOG_INFO(LABEL, "DeleteAndInsertNative cost %{public}" PRId64 ".", endTime - beginTime);
461 
462     return 0;
463 }
464 } // namespace AccessToken
465 } // namespace Security
466 } // namespace OHOS
467