1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #define LOG_TAG "RdStatement"
16 #include "rd_statement.h"
17 
#include <charconv>
#include <chrono>
#include <cinttypes>
#include <iomanip>
#include <sstream>
22 #include "logger.h"
23 #include "raw_data_parser.h"
24 #include "rdb_errno.h"
25 #include "rd_connection.h"
26 #include "rd_utils.h"
27 #include "sqlite_global_config.h"
28 #include "sqlite_utils.h"
29 #include "rdb_fault_hiview_reporter.h"
30 #include "sqlite_global_config.h"
31 
32 namespace OHOS {
33 namespace NativeRdb {
34 using namespace OHOS::Rdb;
35 using Reportor = RdbFaultHiViewReporter;
RdStatement()36 RdStatement::RdStatement()
37 {
38 }
39 
~RdStatement()40 RdStatement::~RdStatement()
41 {
42     Finalize();
43 }
44 
45 constexpr size_t PRAGMA_VERSION_SQL_LEN = __builtin_strlen(GlobalExpr::PRAGMA_VERSION);
46 
/**
 * Skips space characters starting at curIdx and consumes `symbol` if it is the
 * next non-space character.
 *
 * @param str    text being parsed.
 * @param symbol character expected after optional spaces.
 * @param curIdx in: parse position; out: one past the symbol (only on success).
 * @return true when the symbol was found and consumed; curIdx is untouched otherwise.
 */
static bool TryEatSymbol(const std::string &str, char symbol, size_t &curIdx)
{
    // Only ' ' counts as a blank here, not tabs or newlines.
    const size_t next = str.find_first_not_of(' ', curIdx);
    if (next == std::string::npos || str[next] != symbol) {
        return false;
    }
    curIdx = next + 1;
    return true;
}
63 
/**
 * Parses a non-negative decimal integer that may be preceded by spaces.
 *
 * @param str       text being parsed.
 * @param outNumber receives the parsed value (only on success).
 * @param curIdx    in: parse position; out: first index after the last digit (only on success).
 * @return true when at least one digit was read and the value fits in an int;
 *         false otherwise, leaving outNumber and curIdx unchanged.
 */
static bool TryEatNumber(const std::string &str, int &outNumber, size_t &curIdx)
{
    size_t idx = curIdx;
    while (idx < str.length() && str[idx] == ' ') {
        idx++;
    }
    const size_t digitBegin = idx;
    // Explicit '0'..'9' range: std::isdigit on a negative char (non-ASCII byte) is undefined.
    while (idx < str.length() && str[idx] >= '0' && str[idx] <= '9') {
        idx++;
    }
    if (idx == digitBegin) {
        return false;
    }
    // from_chars reports overflow instead of invoking undefined behaviour like atoi.
    int value = 0;
    const char *first = str.data() + digitBegin;
    const char *last = str.data() + idx;
    const std::from_chars_result result = std::from_chars(first, last, value);
    if (result.ec != std::errc() || result.ptr != last) {
        return false;
    }
    outNumber = value;
    curIdx = idx;
    return true;
}
90 
/**
 * Checks whether everything from curIdx to the end of str is space characters.
 *
 * @param str    text being parsed.
 * @param curIdx position to start checking from (positions past the end count as "nothing left").
 * @return true when only ' ' characters (or nothing) remain.
 */
static bool EndWithNull(const std::string &str, size_t curIdx)
{
    return str.find_first_not_of(' ', curIdx) == std::string::npos;
}
103 
Prepare(GRD_DB * db,const std::string & newSql)104 int RdStatement::Prepare(GRD_DB *db, const std::string &newSql)
105 {
106     if (newSql.find(GlobalExpr::PRAGMA_VERSION) == 0) {
107         // Indicates that sql is start with pragma version
108         if (newSql.length() == PRAGMA_VERSION_SQL_LEN) {
109             // Indicates that sql is to get version
110             sql_ = newSql;
111             readOnly_ = true;
112             return E_OK;
113         }
114         size_t curIdx = PRAGMA_VERSION_SQL_LEN;
115         int version = 0;
116         if ((!TryEatSymbol(newSql, '=', curIdx)) || (!TryEatNumber(newSql, version, curIdx)) ||
117             (!EndWithNull(newSql, curIdx) && !TryEatSymbol(newSql, ';', curIdx))) {
118             return E_INCORRECT_SQL;
119         }
120 
121         readOnly_ = false;
122         sql_ = newSql;
123         return setPragmas_["user_version"](version);
124     }
125     if (sql_.compare(newSql) == 0) {
126         return E_OK;
127     }
128     GRD_SqlStmt *tmpStmt = nullptr;
129     int ret = RdUtils::RdSqlPrepare(db, newSql.c_str(), newSql.length(), &tmpStmt, nullptr);
130     if (ret != E_OK) {
131         if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
132             Reportor::ReportFault(Reportor::Create(*config_, ret));
133         }
134         if (tmpStmt != nullptr) {
135             (void)RdUtils::RdSqlFinalize(tmpStmt);
136         }
137         LOG_ERROR("Prepare sql for stmt ret is %{public}d", ret);
138         return ret;
139     }
140     Finalize(); // Finalize original stmt
141     sql_ = newSql;
142     stmtHandle_ = tmpStmt;
143     columnCount_ = RdUtils::RdSqlColCnt(tmpStmt);
144     readOnly_ = SqliteUtils::GetSqlStatementType(newSql) == SqliteUtils::STATEMENT_SELECT;
145     if (readOnly_) {
146         isStepInPrepare_ = true;
147         ret = Step();
148         if (ret != E_OK && ret != E_NO_MORE_ROWS) {
149             return ret;
150         }
151         GetProperties();
152         if (ret == E_NO_MORE_ROWS) {
153             Reset();
154         }
155     }
156     return E_OK;
157 }
158 
Finalize()159 int RdStatement::Finalize()
160 {
161     if (stmtHandle_ == nullptr) {
162         return E_OK;
163     }
164     int ret = RdUtils::RdSqlFinalize(stmtHandle_);
165     if (ret != E_OK) {
166         LOG_ERROR("Finalize ret is %{public}d", ret);
167         return ret;
168     }
169     stmtHandle_ = nullptr;
170     sql_ = "";
171     columnCount_ = 0;
172     readOnly_ = false;
173     config_ = nullptr;
174     return E_OK;
175 }
176 
InnerBindBlobTypeArgs(const ValueObject & arg,uint32_t index) const177 int RdStatement::InnerBindBlobTypeArgs(const ValueObject &arg, uint32_t index) const
178 {
179     int ret = E_OK;
180     switch (arg.GetType()) {
181         case ValueObjectType::TYPE_BLOB: {
182             std::vector<uint8_t> blob;
183             arg.GetBlob(blob);
184             ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(blob.data()), blob.size(),
185                 nullptr);
186             break;
187         }
188         case ValueObjectType::TYPE_BOOL: {
189             bool boolVal = false;
190             arg.GetBool(boolVal);
191             ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, boolVal ? 1 : 0);
192             break;
193         }
194         case ValueObjectType::TYPE_ASSET: {
195             ValueObject::Asset asset;
196             arg.GetAsset(asset);
197             auto rawData = RawDataParser::PackageRawData(asset);
198             ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(rawData.data()),
199                 rawData.size(), nullptr);
200             break;
201         }
202         case ValueObjectType::TYPE_ASSETS: {
203             ValueObject::Assets assets;
204             arg.GetAssets(assets);
205             auto rawData = RawDataParser::PackageRawData(assets);
206             ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(rawData.data()),
207                 rawData.size(), nullptr);
208             break;
209         }
210         case ValueObjectType::TYPE_VECS: {
211             ValueObject::FloatVector vectors;
212             arg.GetVecs(vectors);
213             ret = RdUtils::RdSqlBindFloatVector(stmtHandle_, index,
214                 static_cast<float *>(vectors.data()), vectors.size(), nullptr);
215             break;
216         }
217         default: {
218             std::string str;
219             arg.GetString(str);
220             ret = RdUtils::RdSqlBindText(stmtHandle_, index, str.c_str(), str.length(), nullptr);
221             break;
222         }
223     }
224     return ret;
225 }
226 
IsValid(int index) const227 int RdStatement::IsValid(int index) const
228 {
229     if (stmtHandle_ == nullptr) {
230         LOG_ERROR("statement already close.");
231         return E_ALREADY_CLOSED;
232     }
233     if (index < 0 || index >= columnCount_) {
234         LOG_ERROR("index (%{public}d) >= columnCount (%{public}d)", index, columnCount_);
235         return E_COLUMN_OUT_RANGE;
236     }
237     return E_OK;
238 }
239 
Prepare(const std::string & sql)240 int32_t RdStatement::Prepare(const std::string& sql)
241 {
242     if (dbHandle_ == nullptr) {
243         return E_ERROR;
244     }
245     return Prepare(dbHandle_, sql);
246 }
247 
Bind(const std::vector<ValueObject> & args)248 int32_t RdStatement::Bind(const std::vector<ValueObject>& args)
249 {
250     std::vector<std::reference_wrapper<ValueObject>> refArgs;
251     for (auto &object : args) {
252         refArgs.emplace_back(std::ref(const_cast<ValueObject&>(object)));
253     }
254     return Bind(refArgs);
255 }
256 
Bind(const std::vector<std::reference_wrapper<ValueObject>> & args)257 int32_t RdStatement::Bind(const std::vector<std::reference_wrapper<ValueObject>>& args)
258 {
259     uint32_t index = 1;
260     int ret = E_OK;
261     for (auto &arg : args) {
262         switch (arg.get().GetType()) {
263             case ValueObjectType::TYPE_NULL: {
264                 ret = RdUtils::RdSqlBindNull(stmtHandle_, index);
265                 break;
266             }
267             case ValueObjectType::TYPE_INT: {
268                 int64_t value = 0;
269                 arg.get().GetLong(value);
270                 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, value);
271                 break;
272             }
273             case ValueObjectType::TYPE_DOUBLE: {
274                 double doubleVal = 0;
275                 arg.get().GetDouble(doubleVal);
276                 ret = RdUtils::RdSqlBindDouble(stmtHandle_, index, doubleVal);
277                 break;
278             }
279             default: {
280                 ret = InnerBindBlobTypeArgs(arg, index);
281                 break;
282             }
283         }
284         if (ret != E_OK) {
285             LOG_ERROR("bind ret is %{public}d", ret);
286             return ret;
287         }
288         index++;
289     }
290     return E_OK;
291 }
292 
Count()293 std::pair<int32_t, int32_t> RdStatement::Count()
294 {
295     return { E_NOT_SUPPORT, INVALID_COUNT };
296 }
297 
Step()298 int32_t RdStatement::Step()
299 {
300     if (stmtHandle_ == nullptr) {
301         return E_OK;
302     }
303     if (isStepInPrepare_ && stepCnt_ == 1) {
304         stepCnt_++;
305         return E_OK;
306     }
307     int ret = RdUtils::RdSqlStep(stmtHandle_);
308     if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
309         Reportor::ReportFault(Reportor::Create(*config_, ret));
310     }
311     stepCnt_++;
312     return ret;
313 }
314 
Reset()315 int32_t RdStatement::Reset()
316 {
317     if (stmtHandle_ == nullptr) {
318         return E_OK;
319     }
320     stepCnt_ = 0;
321     isStepInPrepare_ = false;
322     return RdUtils::RdSqlReset(stmtHandle_);
323 }
324 
Execute(const std::vector<ValueObject> & args)325 int32_t RdStatement::Execute(const std::vector<ValueObject>& args)
326 {
327     std::vector<std::reference_wrapper<ValueObject>> refArgs;
328     for (auto &object : args) {
329         refArgs.emplace_back(std::ref(const_cast<ValueObject&>(object)));
330     }
331     return Execute(refArgs);
332 }
333 
Execute(const std::vector<std::reference_wrapper<ValueObject>> & args)334 int32_t RdStatement::Execute(const std::vector<std::reference_wrapper<ValueObject>>& args)
335 {
336     if (!readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
337         // It has already set version in prepare procedure
338         // Current modification is only temporary for unification between rd and sqlite,
339         // rd kernal will support pragma in later version
340         return E_OK;
341     }
342     int ret = Bind(args);
343     if (ret != E_OK) {
344         LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
345         return ret;
346     }
347     ret = Step();
348     if (ret != E_OK && ret != E_NO_MORE_ROWS) {
349         LOG_ERROR("RdConnection Execute : err %{public}d", ret);
350     }
351     return ret;
352 }
353 
ExecuteForValue(const std::vector<ValueObject> & args)354 std::pair<int, ValueObject> RdStatement::ExecuteForValue(const std::vector<ValueObject>& args)
355 {
356     int ret = E_OK;
357     if (readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
358         int version = 0;
359         ret = getPragmas_["user_version"](version);
360         if (ret != E_OK) {
361             LOG_ERROR("RdConnection unable to GetVersion : err %{public}d", ret);
362             return { ret, ValueObject() };
363         }
364         return { ret, ValueObject(version) };
365     }
366     ret = Bind(args);
367     if (ret != E_OK) {
368         LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
369         return { ret, ValueObject() };
370     }
371     ret = Step();
372     if (ret != E_OK && ret != E_NO_MORE_ROWS) {
373         LOG_ERROR("RdConnection Execute : err %{public}d", ret);
374         return { ret, ValueObject() };
375     }
376     return GetColumn(0);
377 }
378 
Changes() const379 int32_t RdStatement::Changes() const
380 {
381     return 0;
382 }
383 
LastInsertRowId() const384 int64_t RdStatement::LastInsertRowId() const
385 {
386     return 0;
387 }
388 
GetColumnCount() const389 int32_t RdStatement::GetColumnCount() const
390 {
391     return columnCount_;
392 }
393 
GetColumnName(int32_t index) const394 std::pair<int32_t, std::string> RdStatement::GetColumnName(int32_t index) const
395 {
396     int ret = IsValid(index);
397     if (ret != E_OK) {
398         return { ret, "" };
399     }
400     const char* name = RdUtils::RdSqlColName(stmtHandle_, index);
401     if (name == nullptr) {
402         LOG_ERROR("column_name is null.");
403         return { E_ERROR, "" };
404     }
405     return { E_OK, name };
406 }
407 
GetColumnType(int32_t index) const408 std::pair<int32_t, int32_t> RdStatement::GetColumnType(int32_t index) const
409 {
410     int ret = IsValid(index);
411     if (ret != E_OK) {
412         return { ret, static_cast<int32_t>(ColumnType::TYPE_NULL) };
413     }
414     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
415     switch (type) {
416         case ColumnType::TYPE_INTEGER:
417         case ColumnType::TYPE_FLOAT:
418         case ColumnType::TYPE_NULL:
419         case ColumnType::TYPE_STRING:
420         case ColumnType::TYPE_BLOB:
421         case ColumnType::TYPE_FLOAT32_ARRAY:
422             break;
423         default:
424             LOG_ERROR("invalid type %{public}d.", type);
425             return { E_ERROR, static_cast<int32_t>(ColumnType::TYPE_NULL) };
426     }
427     return { ret, static_cast<int32_t>(type) };
428 }
429 
GetSize(int32_t index) const430 std::pair<int32_t, size_t> RdStatement::GetSize(int32_t index) const
431 {
432     int ret = IsValid(index);
433     if (ret != E_OK) {
434         return { ret, 0 };
435     }
436     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
437     if (type == ColumnType::TYPE_BLOB || type == ColumnType::TYPE_STRING || type == ColumnType::TYPE_NULL ||
438         type == ColumnType::TYPE_FLOAT32_ARRAY) {
439         return { E_OK, static_cast<size_t>(RdUtils::RdSqlColBytes(stmtHandle_, index)) };
440     }
441     return { E_INVALID_COLUMN_TYPE, 0 };
442 }
443 
GetColumn(int32_t index) const444 std::pair<int32_t, ValueObject> RdStatement::GetColumn(int32_t index) const
445 {
446     ValueObject object;
447     int ret = IsValid(index);
448     if (ret != E_OK) {
449         return { ret, object };
450     }
451 
452     ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
453     switch (type) {
454         case ColumnType::TYPE_FLOAT:
455             object = RdUtils::RdSqlColDouble(stmtHandle_, index);
456             break;
457         case ColumnType::TYPE_INTEGER:
458             object = static_cast<int64_t>(RdUtils::RdSqlColInt64(stmtHandle_, index));
459             break;
460         case ColumnType::TYPE_STRING:
461             object = reinterpret_cast<const char *>(RdUtils::RdSqlColText(stmtHandle_, index));
462             break;
463         case ColumnType::TYPE_NULL:
464             break;
465         case ColumnType::TYPE_FLOAT32_ARRAY: {
466             uint32_t dim = 0;
467             auto vectors =
468                 reinterpret_cast<const float *>(RdUtils::RdSqlColumnFloatVector(stmtHandle_, index, &dim));
469             std::vector<float> vecData;
470             if (dim > 0 || vectors != nullptr) {
471                 vecData.resize(dim);
472                 vecData.assign(vectors, vectors + dim);
473             }
474             object = std::move(vecData);
475             break;
476         }
477         case ColumnType::TYPE_BLOB: {
478             int size = RdUtils::RdSqlColBytes(stmtHandle_, index);
479             auto blob = static_cast<const uint8_t *>(RdUtils::RdSqlColBlob(stmtHandle_, index));
480             std::vector<uint8_t> rawData;
481             if (size > 0 || blob != nullptr) {
482                 rawData.resize(size);
483                 rawData.assign(blob, blob + size);
484             }
485             object = std::move(rawData);
486             break;
487         }
488         default:
489             break;
490     }
491     return { ret, std::move(object) };
492 }
493 
ReadOnly() const494 bool RdStatement::ReadOnly() const
495 {
496     return readOnly_;
497 }
498 
SupportBlockInfo() const499 bool RdStatement::SupportBlockInfo() const
500 {
501     return false;
502 }
503 
FillBlockInfo(SharedBlockInfo * info) const504 int32_t RdStatement::FillBlockInfo(SharedBlockInfo* info) const
505 {
506     return E_NOT_SUPPORT;
507 }
508 
GetProperties()509 void RdStatement::GetProperties()
510 {
511     columnCount_ = RdUtils::RdSqlColCnt(stmtHandle_);
512 }
513 } // namespace NativeRdb
514 } // namespace OHOS
515