1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "param_check_utils.h"
17 
18 #include "cloud/cloud_db_constant.h"
19 #include "cloud/cloud_storage_utils.h"
20 #include "db_common.h"
21 #include "db_constant.h"
22 #include "db_errno.h"
23 #include "log_print.h"
24 #include "platform_specific.h"
25 
26 namespace DistributedDB {
CheckDataDir(const std::string & dataDir,std::string & canonicalDir)27 bool ParamCheckUtils::CheckDataDir(const std::string &dataDir, std::string &canonicalDir)
28 {
29     if (dataDir.empty() || (dataDir.length() > DBConstant::MAX_DATA_DIR_LENGTH)) {
30         LOGE("Invalid data directory[%zu]", dataDir.length());
31         return false;
32     }
33 
34     // After normalizing the path, determine whether the path is a legal path considered by the program.
35     // There has been guaranteed by the upper layer, So there is no need trustlist set here.
36     return (OS::GetRealPath(dataDir, canonicalDir) == E_OK);
37 }
38 
IsStoreIdSafe(const std::string & storeId,bool allowStoreIdWithDot)39 bool ParamCheckUtils::IsStoreIdSafe(const std::string &storeId, bool allowStoreIdWithDot)
40 {
41     if (storeId.empty() || (storeId.length() > DBConstant::MAX_STORE_ID_LENGTH)) {
42         LOGE("Invalid store id[%zu]", storeId.length());
43         return false;
44     }
45 
46     auto iter = std::find_if_not(storeId.begin(), storeId.end(),
47         [allowStoreIdWithDot](char value) {
48         return (std::isalnum(value) || value == '_') || (allowStoreIdWithDot && value == '.');
49     });
50     if (iter != storeId.end()) {
51         LOGE("Invalid store id format");
52         return false;
53     }
54     return true;
55 }
56 
CheckStoreParameter(const std::string & storeId,const std::string & appId,const std::string & userId,bool isIgnoreUserIdCheck,const std::string & subUser)57 bool ParamCheckUtils::CheckStoreParameter(const std::string &storeId, const std::string &appId,
58     const std::string &userId, bool isIgnoreUserIdCheck, const std::string &subUser)
59 {
60     return CheckStoreParameter({userId, appId, storeId}, isIgnoreUserIdCheck, subUser);
61 }
62 
CheckStoreParameter(const StoreInfo & info,bool isIgnoreUserIdCheck,const std::string & subUser,bool allowStoreIdWithDot)63 bool ParamCheckUtils::CheckStoreParameter(const StoreInfo &info, bool isIgnoreUserIdCheck,
64     const std::string &subUser, bool allowStoreIdWithDot)
65 {
66     const auto &storeId = info.storeId;
67     const auto &userId = info.userId;
68     const auto &appId = info.appId;
69     if (!IsStoreIdSafe(storeId, allowStoreIdWithDot)) {
70         return false;
71     }
72     if (!isIgnoreUserIdCheck) {
73         if (userId.empty() || userId.length() > DBConstant::MAX_USER_ID_LENGTH) {
74             LOGE("Invalid user info[%zu][%zu]", userId.length(), appId.length());
75             return false;
76         }
77         if (userId.find(DBConstant::ID_CONNECTOR) != std::string::npos) {
78             LOGE("Invalid userId character in the store para info.");
79             return false;
80         }
81     }
82     if (appId.empty() || appId.length() > DBConstant::MAX_APP_ID_LENGTH) {
83         LOGE("Invalid app info[%zu][%zu]", userId.length(), appId.length());
84         return false;
85     }
86     // subUser allow empty
87     if (subUser.length() > DBConstant::MAX_SUB_USER_LENGTH) {
88         LOGE("Invalid subUser info[%zu][%zu]", userId.length(), subUser.length());
89         return false;
90     }
91 
92     if ((appId.find(DBConstant::ID_CONNECTOR) != std::string::npos) ||
93         (storeId.find(DBConstant::ID_CONNECTOR) != std::string::npos) ||
94         (subUser.find(DBConstant::ID_CONNECTOR) != std::string::npos)) {
95         LOGE("Invalid character in the store para info.");
96         return false;
97     }
98     return true;
99 }
100 
CheckEncryptedParameter(CipherType cipher,const CipherPassword & passwd)101 bool ParamCheckUtils::CheckEncryptedParameter(CipherType cipher, const CipherPassword &passwd)
102 {
103     if (cipher != CipherType::DEFAULT && cipher != CipherType::AES_256_GCM) {
104         LOGE("Invalid cipher type!");
105         return false;
106     }
107 
108     return (passwd.GetSize() != 0);
109 }
110 
CheckConflictNotifierType(int conflictType)111 bool ParamCheckUtils::CheckConflictNotifierType(int conflictType)
112 {
113     if (conflictType <= 0) {
114         return false;
115     }
116     // Divide the type into different types.
117     if (conflictType >= CONFLICT_NATIVE_ALL) {
118         conflictType -= CONFLICT_NATIVE_ALL;
119     }
120     if (conflictType >= CONFLICT_FOREIGN_KEY_ORIG) {
121         conflictType -= CONFLICT_FOREIGN_KEY_ORIG;
122     }
123     if (conflictType >= CONFLICT_FOREIGN_KEY_ONLY) {
124         conflictType -= CONFLICT_FOREIGN_KEY_ONLY;
125     }
126     return (conflictType == 0);
127 }
128 
CheckSecOption(const SecurityOption & secOption)129 bool ParamCheckUtils::CheckSecOption(const SecurityOption &secOption)
130 {
131     if (secOption.securityLabel > S4 || secOption.securityLabel < NOT_SET) {
132         LOGE("[DBCommon] SecurityLabel is invalid, label is [%d].", secOption.securityLabel);
133         return false;
134     }
135     if (secOption.securityFlag != 0) {
136         if ((secOption.securityLabel != S3 && secOption.securityLabel != S4) || secOption.securityFlag != SECE) {
137             LOGE("[DBCommon] SecurityFlag is invalid.");
138             return false;
139         }
140     }
141     return true;
142 }
143 
CheckObserver(const Key & key,unsigned int mode)144 bool ParamCheckUtils::CheckObserver(const Key &key, unsigned int mode)
145 {
146     if (key.size() > DBConstant::MAX_KEY_SIZE) {
147         return false;
148     }
149     uint64_t rawMode = DBCommon::EraseBit(mode, DBConstant::OBSERVER_CHANGES_MASK);
150     if (rawMode == OBSERVER_CHANGES_NATIVE || rawMode == OBSERVER_CHANGES_FOREIGN ||
151         rawMode == OBSERVER_CHANGES_LOCAL_ONLY || rawMode == OBSERVER_CHANGES_CLOUD ||
152         rawMode == (OBSERVER_CHANGES_NATIVE | OBSERVER_CHANGES_FOREIGN)) {
153             return true;
154     }
155     return false;
156 }
157 
IsS3SECEOpt(const SecurityOption & secOpt)158 bool ParamCheckUtils::IsS3SECEOpt(const SecurityOption &secOpt)
159 {
160     SecurityOption S3SeceOpt = {SecurityLabel::S3, SecurityFlag::SECE};
161     return (secOpt == S3SeceOpt);
162 }
163 
CheckAndTransferAutoLaunchParam(const AutoLaunchParam & param,bool checkDir,SchemaObject & schemaObject,std::string & canonicalDir)164 int ParamCheckUtils::CheckAndTransferAutoLaunchParam(const AutoLaunchParam &param, bool checkDir,
165     SchemaObject &schemaObject, std::string &canonicalDir)
166 {
167     if ((param.option.notifier && !ParamCheckUtils::CheckConflictNotifierType(param.option.conflictType)) ||
168         (!param.option.notifier && param.option.conflictType != 0)) {
169         LOGE("[AutoLaunch] CheckConflictNotifierType is invalid.");
170         return -E_INVALID_ARGS;
171     }
172     if (!ParamCheckUtils::CheckStoreParameter(param.storeId, param.appId, param.userId)) {
173         LOGE("[AutoLaunch] CheckStoreParameter is invalid.");
174         return -E_INVALID_ARGS;
175     }
176 
177     const AutoLaunchOption &option = param.option;
178     if (!ParamCheckUtils::CheckSecOption(option.secOption)) {
179         LOGE("[AutoLaunch] CheckSecOption is invalid.");
180         return -E_INVALID_ARGS;
181     }
182 
183     if (option.isEncryptedDb) {
184         if (!ParamCheckUtils::CheckEncryptedParameter(option.cipher, option.passwd)) {
185             LOGE("[AutoLaunch] CheckEncryptedParameter is invalid.");
186             return -E_INVALID_ARGS;
187         }
188     }
189 
190     if (!param.option.schema.empty()) {
191         schemaObject.ParseFromSchemaString(param.option.schema);
192         if (!schemaObject.IsSchemaValid()) {
193             LOGE("[AutoLaunch] ParseFromSchemaString is invalid.");
194             return -E_INVALID_SCHEMA;
195         }
196     }
197 
198     if (!checkDir) {
199         canonicalDir = param.option.dataDir;
200         return E_OK;
201     }
202 
203     if (!ParamCheckUtils::CheckDataDir(param.option.dataDir, canonicalDir)) {
204         LOGE("[AutoLaunch] CheckDataDir is invalid.");
205         return -E_INVALID_ARGS;
206     }
207     return E_OK;
208 }
209 
GetValidCompressionRate(uint8_t compressionRate)210 uint8_t ParamCheckUtils::GetValidCompressionRate(uint8_t compressionRate)
211 {
212     // Valid when between 1 and 100. When compressionRate is invalid, change it to default rate.
213     if (compressionRate < 1 || compressionRate > DBConstant::DEFAULT_COMPTRESS_RATE) {
214         LOGD("Invalid compression rate:%" PRIu8, compressionRate);
215         compressionRate = DBConstant::DEFAULT_COMPTRESS_RATE;
216     }
217     return compressionRate;
218 }
219 
CheckRelationalTableName(const std::string & tableName)220 bool ParamCheckUtils::CheckRelationalTableName(const std::string &tableName)
221 {
222     if (!DBCommon::CheckIsAlnumOrUnderscore(tableName)) {
223         return false;
224     }
225     return tableName.compare(0, DBConstant::SYSTEM_TABLE_PREFIX.size(), DBConstant::SYSTEM_TABLE_PREFIX) != 0;
226 }
227 
CheckTableReference(const std::vector<TableReferenceProperty> & tableReferenceProperty)228 bool ParamCheckUtils::CheckTableReference(const std::vector<TableReferenceProperty> &tableReferenceProperty)
229 {
230     if (tableReferenceProperty.empty()) {
231         LOGI("[CheckTableReference] tableReferenceProperty is empty");
232         return true;
233     }
234 
235     std::vector<std::vector<int>> dependency;
236     std::map<std::string, int, CaseInsensitiveComparator> tableName2Int;
237     int index = 0;
238     for (const auto &item : tableReferenceProperty) {
239         if (item.sourceTableName.empty() || item.targetTableName.empty() || item.columns.empty()) {
240             LOGE("[CheckTableReference] table name or column is empty");
241             return false;
242         }
243         std::vector<int> vec;
244         for (const auto &tableName : { item.sourceTableName, item.targetTableName }) {
245             if (tableName2Int.find(tableName) != tableName2Int.end()) {
246                 vec.push_back(tableName2Int.at(tableName));
247             } else {
248                 vec.push_back(index);
249                 tableName2Int[tableName] = index;
250                 index++;
251             }
252         }
253         if (std::find(dependency.begin(), dependency.end(), vec) != dependency.end()) {
254             LOGE("[CheckTableReference] set multiple reference for two tables is not support.");
255             return false;
256         }
257         dependency.emplace_back(vec);
258     }
259 
260     if (DBCommon::IsCircularDependency(index, dependency)) {
261         LOGE("[CheckTableReference] circular reference is not support.");
262         return false;
263     }
264     return true;
265 }
266 
CheckSharedTableName(const DataBaseSchema & schema)267 bool ParamCheckUtils::CheckSharedTableName(const DataBaseSchema &schema)
268 {
269     DataBaseSchema lowerSchema = schema;
270     TransferSchemaToLower(lowerSchema);
271     std::set<std::string> tableNames;
272     std::set<std::string> sharedTableNames;
273     for (const auto &tableSchema : lowerSchema.tables) {
274         if (tableSchema.sharedTableName.empty()) {
275             continue;
276         }
277         if (tableSchema.sharedTableName == tableSchema.name) {
278             LOGE("[CheckSharedTableName] Shared table name and table name are same.");
279             return false;
280         }
281         if (sharedTableNames.find(tableSchema.sharedTableName) != sharedTableNames.end() ||
282             sharedTableNames.find(tableSchema.name) != sharedTableNames.end() ||
283             tableNames.find(tableSchema.sharedTableName) != tableNames.end() ||
284             tableNames.find(tableSchema.name) != tableNames.end()) {
285             LOGE("[CheckSharedTableName] Shared table names or table names are duplicate.");
286             return false;
287         }
288         if (!CheckRelationalTableName(tableSchema.sharedTableName)) {
289             return false;
290         }
291         tableNames.insert(tableSchema.name);
292         sharedTableNames.insert(tableSchema.sharedTableName);
293         std::set<std::string> fields;
294         for (const auto &field : tableSchema.fields) {
295             if (fields.find(field.colName) != fields.end() || field.colName == CloudDbConstant::CLOUD_OWNER ||
296                 field.colName == CloudDbConstant::CLOUD_PRIVILEGE) {
297                 LOGE("[CheckSharedTableName] fields are duplicate.");
298                 return false;
299             }
300             fields.insert(field.colName);
301         }
302     }
303     return true;
304 }
305 
TransferSchemaToLower(DataBaseSchema & schema)306 void ParamCheckUtils::TransferSchemaToLower(DataBaseSchema &schema)
307 {
308     for (auto &tableSchema : schema.tables) {
309         std::transform(tableSchema.name.begin(), tableSchema.name.end(), tableSchema.name.begin(), ::tolower);
310         std::transform(tableSchema.sharedTableName.begin(), tableSchema.sharedTableName.end(),
311             tableSchema.sharedTableName.begin(), ::tolower);
312         for (auto &field : tableSchema.fields) {
313             std::transform(field.colName.begin(), field.colName.end(), field.colName.begin(), ::tolower);
314         }
315     }
316 }
317 } // namespace DistributedDB