1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "sqlite_relational_utils.h"
17 #include "db_errno.h"
18 #include "cloud/cloud_db_types.h"
19 #include "sqlite_utils.h"
20 #include "cloud/cloud_storage_utils.h"
21 #include "runtime_context.h"
22 #include "cloud/cloud_db_constant.h"
23 
24 namespace DistributedDB {
GetDataValueByType(sqlite3_stmt * statement,int cid,DataValue & value)25 int SQLiteRelationalUtils::GetDataValueByType(sqlite3_stmt *statement, int cid, DataValue &value)
26 {
27     if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
28         return -E_INVALID_ARGS;
29     }
30 
31     int errCode = E_OK;
32     int storageType = sqlite3_column_type(statement, cid);
33     switch (storageType) {
34         case SQLITE_INTEGER:
35             value = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
36             break;
37         case SQLITE_FLOAT:
38             value = sqlite3_column_double(statement, cid);
39             break;
40         case SQLITE_BLOB: {
41             std::vector<uint8_t> blobValue;
42             errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
43             if (errCode != E_OK) {
44                 return errCode;
45             }
46             auto blob = new (std::nothrow) Blob;
47             if (blob == nullptr) {
48                 return -E_OUT_OF_MEMORY;
49             }
50             blob->WriteBlob(blobValue.data(), static_cast<uint32_t>(blobValue.size()));
51             errCode = value.Set(blob);
52             if (errCode != E_OK) {
53                 delete blob;
54                 blob = nullptr;
55             }
56             break;
57         }
58         case SQLITE_NULL:
59             break;
60         case SQLITE3_TEXT: {
61             std::string str;
62             (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
63             value = str;
64             if (value.GetType() != StorageType::STORAGE_TYPE_TEXT) {
65                 errCode = -E_OUT_OF_MEMORY;
66             }
67             break;
68         }
69         default:
70             break;
71     }
72     return errCode;
73 }
74 
GetSelectValues(sqlite3_stmt * stmt)75 std::vector<DataValue> SQLiteRelationalUtils::GetSelectValues(sqlite3_stmt *stmt)
76 {
77     std::vector<DataValue> values;
78     for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
79         DataValue value;
80         (void)GetDataValueByType(stmt, cid, value);
81         values.emplace_back(std::move(value));
82     }
83     return values;
84 }
85 
GetCloudValueByType(sqlite3_stmt * statement,int type,int cid,Type & cloudValue)86 int SQLiteRelationalUtils::GetCloudValueByType(sqlite3_stmt *statement, int type, int cid, Type &cloudValue)
87 {
88     if (statement == nullptr || cid < 0 || cid >= sqlite3_column_count(statement)) {
89         return -E_INVALID_ARGS;
90     }
91     switch (sqlite3_column_type(statement, cid)) {
92         case SQLITE_INTEGER: {
93             if (type == TYPE_INDEX<bool>) {
94                 cloudValue = static_cast<bool>(sqlite3_column_int(statement, cid));
95                 break;
96             }
97             cloudValue = static_cast<int64_t>(sqlite3_column_int64(statement, cid));
98             break;
99         }
100         case SQLITE_FLOAT: {
101             cloudValue = sqlite3_column_double(statement, cid);
102             break;
103         }
104         case SQLITE_BLOB: {
105             std::vector<uint8_t> blobValue;
106             int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
107             if (errCode != E_OK) {
108                 return errCode;
109             }
110             cloudValue = blobValue;
111             break;
112         }
113         case SQLITE3_TEXT: {
114             bool isBlob = (type == TYPE_INDEX<Bytes> || type == TYPE_INDEX<Asset> || type == TYPE_INDEX<Assets>);
115             if (isBlob) {
116                 std::vector<uint8_t> blobValue;
117                 int errCode = SQLiteUtils::GetColumnBlobValue(statement, cid, blobValue);
118                 if (errCode != E_OK) {
119                     return errCode;
120                 }
121                 cloudValue = blobValue;
122                 break;
123             }
124             std::string str;
125             (void)SQLiteUtils::GetColumnTextValue(statement, cid, str);
126             cloudValue = str;
127             break;
128         }
129         default: {
130             cloudValue = Nil();
131         }
132     }
133     return E_OK;
134 }
135 
CalCloudValueLen(Type & cloudValue,uint32_t & totalSize)136 void SQLiteRelationalUtils::CalCloudValueLen(Type &cloudValue, uint32_t &totalSize)
137 {
138     switch (cloudValue.index()) {
139         case TYPE_INDEX<int64_t>:
140             totalSize += sizeof(int64_t);
141             break;
142         case TYPE_INDEX<double>:
143             totalSize += sizeof(double);
144             break;
145         case TYPE_INDEX<std::string>:
146             totalSize += std::get<std::string>(cloudValue).size();
147             break;
148         case TYPE_INDEX<bool>:
149             totalSize += sizeof(int32_t);
150             break;
151         case TYPE_INDEX<Bytes>:
152         case TYPE_INDEX<Asset>:
153         case TYPE_INDEX<Assets>:
154             totalSize += std::get<Bytes>(cloudValue).size();
155             break;
156         default: {
157             break;
158         }
159     }
160 }
161 
BindStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)162 int SQLiteRelationalUtils::BindStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
163 {
164     int errCode = E_OK;
165     switch (typeVal.index()) {
166         case TYPE_INDEX<int64_t>: {
167             int64_t value = 0;
168             (void)CloudStorageUtils::GetValueFromType(typeVal, value);
169             errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
170             break;
171         }
172         case TYPE_INDEX<bool>: {
173             bool value = false;
174             (void)CloudStorageUtils::GetValueFromType<bool>(typeVal, value);
175             errCode = SQLiteUtils::BindInt64ToStatement(statement, cid, value);
176             break;
177         }
178         case TYPE_INDEX<double>: {
179             double value = 0.0;
180             (void)CloudStorageUtils::GetValueFromType<double>(typeVal, value);
181             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_double(statement, cid, value));
182             break;
183         }
184         case TYPE_INDEX<std::string>: {
185             std::string value;
186             (void)CloudStorageUtils::GetValueFromType<std::string>(typeVal, value);
187             errCode = SQLiteUtils::BindTextToStatement(statement, cid, value);
188             break;
189         }
190         default: {
191             errCode = BindExtendStatementByType(statement, cid, typeVal);
192             break;
193         }
194     }
195     return errCode;
196 }
197 
BindExtendStatementByType(sqlite3_stmt * statement,int cid,Type & typeVal)198 int SQLiteRelationalUtils::BindExtendStatementByType(sqlite3_stmt *statement, int cid, Type &typeVal)
199 {
200     int errCode = E_OK;
201     switch (typeVal.index()) {
202         case TYPE_INDEX<Bytes>: {
203             Bytes value;
204             (void)CloudStorageUtils::GetValueFromType<Bytes>(typeVal, value);
205             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, value);
206             break;
207         }
208         case TYPE_INDEX<Asset>: {
209             Asset value;
210             (void)CloudStorageUtils::GetValueFromType<Asset>(typeVal, value);
211             Bytes val;
212             errCode = RuntimeContext::GetInstance()->AssetToBlob(value, val);
213             if (errCode != E_OK) {
214                 break;
215             }
216             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
217             break;
218         }
219         case TYPE_INDEX<Assets>: {
220             Assets value;
221             (void)CloudStorageUtils::GetValueFromType<Assets>(typeVal, value);
222             Bytes val;
223             errCode = RuntimeContext::GetInstance()->AssetsToBlob(value, val);
224             if (errCode != E_OK) {
225                 break;
226             }
227             errCode = SQLiteUtils::BindBlobToStatement(statement, cid, val);
228             break;
229         }
230         default: {
231             errCode = SQLiteUtils::MapSQLiteErrno(sqlite3_bind_null(statement, cid));
232             break;
233         }
234     }
235     return errCode;
236 }
237 
GetSelectVBucket(sqlite3_stmt * stmt,VBucket & bucket)238 int SQLiteRelationalUtils::GetSelectVBucket(sqlite3_stmt *stmt, VBucket &bucket)
239 {
240     if (stmt == nullptr) {
241         return -E_INVALID_ARGS;
242     }
243     for (int cid = 0, colCount = sqlite3_column_count(stmt); cid < colCount; ++cid) {
244         Type typeVal;
245         int errCode = GetTypeValByStatement(stmt, cid, typeVal);
246         if (errCode != E_OK) {
247             LOGE("get typeVal from stmt failed");
248             return errCode;
249         }
250         const char *colName = sqlite3_column_name(stmt, cid);
251         bucket.insert_or_assign(colName, std::move(typeVal));
252     }
253     return E_OK;
254 }
255 
GetDbFileName(sqlite3 * db,std::string & fileName)256 bool SQLiteRelationalUtils::GetDbFileName(sqlite3 *db, std::string &fileName)
257 {
258     if (db == nullptr) {
259         return false;
260     }
261 
262     auto dbFilePath = sqlite3_db_filename(db, nullptr);
263     if (dbFilePath == nullptr) {
264         return false;
265     }
266     fileName = std::string(dbFilePath);
267     return true;
268 }
269 
GetTypeValByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)270 int SQLiteRelationalUtils::GetTypeValByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
271 {
272     if (stmt == nullptr || cid < 0 || cid >= sqlite3_column_count(stmt)) {
273         return -E_INVALID_ARGS;
274     }
275     int errCode = E_OK;
276     switch (sqlite3_column_type(stmt, cid)) {
277         case SQLITE_INTEGER: {
278             const char *declType = sqlite3_column_decltype(stmt, cid);
279             if (declType == nullptr) { // LCOV_EXCL_BR_LINE
280                 typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
281                 break;
282             }
283             if (strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOL.c_str()) == 0 ||
284                 strcasecmp(declType, SchemaConstant::KEYWORD_TYPE_BOOLEAN.c_str()) == 0) { // LCOV_EXCL_BR_LINE
285                 typeVal = static_cast<bool>(sqlite3_column_int(stmt, cid));
286                 break;
287             }
288             typeVal = static_cast<int64_t>(sqlite3_column_int64(stmt, cid));
289             break;
290         }
291         case SQLITE_FLOAT: {
292             typeVal = sqlite3_column_double(stmt, cid);
293             break;
294         }
295         case SQLITE_BLOB: {
296             errCode = GetBlobByStatement(stmt, cid, typeVal);
297             break;
298         }
299         case SQLITE3_TEXT: {
300             errCode = GetBlobByStatement(stmt, cid, typeVal);
301             if (errCode != E_OK || typeVal.index() != TYPE_INDEX<Nil>) { // LCOV_EXCL_BR_LINE
302                 break;
303             }
304             std::string str;
305             (void)SQLiteUtils::GetColumnTextValue(stmt, cid, str);
306             typeVal = str;
307             break;
308         }
309         default: {
310             typeVal = Nil();
311         }
312     }
313     return errCode;
314 }
315 
GetBlobByStatement(sqlite3_stmt * stmt,int cid,Type & typeVal)316 int SQLiteRelationalUtils::GetBlobByStatement(sqlite3_stmt *stmt, int cid, Type &typeVal)
317 {
318     const char *declType = sqlite3_column_decltype(stmt, cid);
319     int errCode = E_OK;
320     if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSET) == 0) { // LCOV_EXCL_BR_LINE
321         std::vector<uint8_t> blobValue;
322         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
323         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
324             return errCode;
325         }
326         Asset asset;
327         errCode = RuntimeContext::GetInstance()->BlobToAsset(blobValue, asset);
328         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
329             return errCode;
330         }
331         typeVal = asset;
332     } else if (declType != nullptr && strcasecmp(declType, CloudDbConstant::ASSETS) == 0) {
333         std::vector<uint8_t> blobValue;
334         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
335         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
336             return errCode;
337         }
338         Assets assets;
339         errCode = RuntimeContext::GetInstance()->BlobToAssets(blobValue, assets);
340         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
341             return errCode;
342         }
343         typeVal = assets;
344     } else if (sqlite3_column_type(stmt, cid) == SQLITE_BLOB) {
345         std::vector<uint8_t> blobValue;
346         errCode = SQLiteUtils::GetColumnBlobValue(stmt, cid, blobValue);
347         if (errCode != E_OK) { // LCOV_EXCL_BR_LINE
348             return errCode;
349         }
350         typeVal = blobValue;
351     }
352     return E_OK;
353 }
354 
SelectServerObserver(sqlite3 * db,const std::string & tableName,bool isChanged)355 int SQLiteRelationalUtils::SelectServerObserver(sqlite3 *db, const std::string &tableName, bool isChanged)
356 {
357     if (db == nullptr || tableName.empty()) {
358         return -E_INVALID_ARGS;
359     }
360     std::string sql;
361     if (isChanged) {
362         sql = "SELECT server_observer('" + tableName + "', 1);";
363     } else {
364         sql = "SELECT server_observer('" + tableName + "', 0);";
365     }
366     sqlite3_stmt *stmt = nullptr;
367     int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
368     if (errCode != E_OK) {
369         LOGE("get select server observer stmt failed. %d", errCode);
370         return errCode;
371     }
372     errCode = SQLiteUtils::StepWithRetry(stmt, false);
373     int ret = E_OK;
374     SQLiteUtils::ResetStatement(stmt, true, ret);
375     if (errCode != SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
376         LOGE("select server observer failed. %d", errCode);
377         return SQLiteUtils::MapSQLiteErrno(errCode);
378     }
379     return ret == E_OK ? E_OK : ret;
380 }
381 
AddUpgradeSqlToList(const TableInfo & tableInfo,const std::vector<std::pair<std::string,std::string>> & fieldList,std::vector<std::string> & sqlList)382 void SQLiteRelationalUtils::AddUpgradeSqlToList(const TableInfo &tableInfo,
383     const std::vector<std::pair<std::string, std::string>> &fieldList, std::vector<std::string> &sqlList)
384 {
385     for (const auto &[colName, colType] : fieldList) {
386         auto it = tableInfo.GetFields().find(colName);
387         if (it != tableInfo.GetFields().end()) {
388             continue;
389         }
390         sqlList.push_back("alter table " + tableInfo.GetTableName() + " add " + colName +
391             " " + colType + ";");
392     }
393 }
394 
AnalysisTrackerTable(sqlite3 * db,const TrackerTable & trackerTable,TableInfo & tableInfo)395 int SQLiteRelationalUtils::AnalysisTrackerTable(sqlite3 *db, const TrackerTable &trackerTable, TableInfo &tableInfo)
396 {
397     int errCode = SQLiteUtils::AnalysisSchema(db, trackerTable.GetTableName(), tableInfo, true);
398     if (errCode != E_OK) {
399         LOGE("analysis table schema failed %d.", errCode);
400         return errCode;
401     }
402     tableInfo.SetTrackerTable(trackerTable);
403     errCode = tableInfo.CheckTrackerTable();
404     if (errCode != E_OK) {
405         LOGE("check tracker table schema failed %d.", errCode);
406     }
407     return errCode;
408 }
409 
QueryCount(sqlite3 * db,const std::string & tableName,int64_t & count)410 int SQLiteRelationalUtils::QueryCount(sqlite3 *db, const std::string &tableName, int64_t &count)
411 {
412     std::string sql = "SELECT COUNT(1) FROM " + tableName ;
413     sqlite3_stmt *stmt = nullptr;
414     int errCode = SQLiteUtils::GetStatement(db, sql, stmt);
415     if (errCode != E_OK) {
416         LOGE("Query count failed. %d", errCode);
417         return errCode;
418     }
419     errCode = SQLiteUtils::StepWithRetry(stmt, false);
420     if (errCode == SQLiteUtils::MapSQLiteErrno(SQLITE_ROW)) {
421         count = static_cast<int64_t>(sqlite3_column_int64(stmt, 0));
422         errCode = E_OK;
423     } else {
424         LOGE("Failed to get the count. %d", errCode);
425     }
426     SQLiteUtils::ResetStatement(stmt, true, errCode);
427     return errCode;
428 }
429 } // namespace DistributedDB