1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #define LOG_TAG "RdStatement"
16 #include "rd_statement.h"
17
#include <chrono>
#include <cinttypes>
#include <iomanip>
#include <limits>
#include <sstream>
#include "logger.h"
#include "raw_data_parser.h"
#include "rdb_errno.h"
#include "rd_connection.h"
#include "rd_utils.h"
#include "sqlite_global_config.h"
#include "sqlite_utils.h"
#include "rdb_fault_hiview_reporter.h"
#include "sqlite_global_config.h"
31
32 namespace OHOS {
33 namespace NativeRdb {
34 using namespace OHOS::Rdb;
35 using Reportor = RdbFaultHiViewReporter;
RdStatement()36 RdStatement::RdStatement()
37 {
38 }
39
~RdStatement()40 RdStatement::~RdStatement()
41 {
42 Finalize();
43 }
44
45 constexpr size_t PRAGMA_VERSION_SQL_LEN = __builtin_strlen(GlobalExpr::PRAGMA_VERSION);
46
/**
 * Skips spaces starting at curIdx and checks whether the next character is `symbol`.
 *
 * @param str     text being scanned.
 * @param symbol  character expected after any run of spaces.
 * @param curIdx  in: where to start scanning; out: index just past `symbol` when it is found,
 *                left untouched otherwise.
 * @return true when `symbol` was found and consumed, false otherwise.
 */
static bool TryEatSymbol(const std::string &str, char symbol, size_t &curIdx)
{
    // First character that is not a space; npos when only spaces (or nothing) remain.
    const size_t pos = str.find_first_not_of(' ', curIdx);
    if (pos == std::string::npos || str[pos] != symbol) {
        return false;
    }
    curIdx = pos + 1;
    return true;
}
63
/**
 * Parses a non-negative decimal integer from `str`, starting at curIdx.
 * Spaces in front of the first digit are skipped; parsing stops at the first non-digit character.
 *
 * @param str        text being scanned.
 * @param outNumber  receives the parsed value on success; left untouched on failure.
 * @param curIdx     in: where to start scanning; out: index of the first character after the
 *                   digits on success, left untouched on failure.
 * @return non-zero (true) when at least one digit was read and the value fits in an int,
 *         zero (false) when no digit was found or the value would overflow an int.
 */
static int TryEatNumber(const std::string &str, int &outNumber, size_t &curIdx)
{
    size_t idx = curIdx;
    while (idx < str.length() && str[idx] == ' ') {
        idx++;
    }

    constexpr int maxValue = std::numeric_limits<int>::max();
    int value = 0;
    bool hasMeetDigit = false;
    // Plain range comparison instead of isdigit(): isdigit() is undefined for negative char values.
    for (; idx < str.length() && str[idx] >= '0' && str[idx] <= '9'; idx++) {
        const int digit = str[idx] - '0';
        // Reject the input before value * 10 + digit can exceed the int range.
        if (value > (maxValue - digit) / 10) {
            return false;
        }
        value = value * 10 + digit;
        hasMeetDigit = true;
    }
    if (!hasMeetDigit) {
        return false;
    }
    outNumber = value;
    curIdx = idx;
    return true;
}
90
/**
 * Tells whether everything from curIdx to the end of `str` is spaces (or nothing at all).
 *
 * @param str     text being scanned.
 * @param curIdx  index at which the check starts.
 * @return non-zero (true) when only spaces remain, zero (false) as soon as any other character is found.
 */
static int EndWithNull(const std::string &str, size_t curIdx)
{
    return str.find_first_not_of(' ', curIdx) == std::string::npos;
}
103
Prepare(GRD_DB * db,const std::string & newSql)104 int RdStatement::Prepare(GRD_DB *db, const std::string &newSql)
105 {
106 if (newSql.find(GlobalExpr::PRAGMA_VERSION) == 0) {
107 // Indicates that sql is start with pragma version
108 if (newSql.length() == PRAGMA_VERSION_SQL_LEN) {
109 // Indicates that sql is to get version
110 sql_ = newSql;
111 readOnly_ = true;
112 return E_OK;
113 }
114 size_t curIdx = PRAGMA_VERSION_SQL_LEN;
115 int version = 0;
116 if ((!TryEatSymbol(newSql, '=', curIdx)) || (!TryEatNumber(newSql, version, curIdx)) ||
117 (!EndWithNull(newSql, curIdx) && !TryEatSymbol(newSql, ';', curIdx))) {
118 return E_INCORRECT_SQL;
119 }
120
121 readOnly_ = false;
122 sql_ = newSql;
123 return setPragmas_["user_version"](version);
124 }
125 if (sql_.compare(newSql) == 0) {
126 return E_OK;
127 }
128 GRD_SqlStmt *tmpStmt = nullptr;
129 int ret = RdUtils::RdSqlPrepare(db, newSql.c_str(), newSql.length(), &tmpStmt, nullptr);
130 if (ret != E_OK) {
131 if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
132 Reportor::ReportFault(Reportor::Create(*config_, ret));
133 }
134 if (tmpStmt != nullptr) {
135 (void)RdUtils::RdSqlFinalize(tmpStmt);
136 }
137 LOG_ERROR("Prepare sql for stmt ret is %{public}d", ret);
138 return ret;
139 }
140 Finalize(); // Finalize original stmt
141 sql_ = newSql;
142 stmtHandle_ = tmpStmt;
143 columnCount_ = RdUtils::RdSqlColCnt(tmpStmt);
144 readOnly_ = SqliteUtils::GetSqlStatementType(newSql) == SqliteUtils::STATEMENT_SELECT;
145 if (readOnly_) {
146 isStepInPrepare_ = true;
147 ret = Step();
148 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
149 return ret;
150 }
151 GetProperties();
152 if (ret == E_NO_MORE_ROWS) {
153 Reset();
154 }
155 }
156 return E_OK;
157 }
158
Finalize()159 int RdStatement::Finalize()
160 {
161 if (stmtHandle_ == nullptr) {
162 return E_OK;
163 }
164 int ret = RdUtils::RdSqlFinalize(stmtHandle_);
165 if (ret != E_OK) {
166 LOG_ERROR("Finalize ret is %{public}d", ret);
167 return ret;
168 }
169 stmtHandle_ = nullptr;
170 sql_ = "";
171 columnCount_ = 0;
172 readOnly_ = false;
173 config_ = nullptr;
174 return E_OK;
175 }
176
InnerBindBlobTypeArgs(const ValueObject & arg,uint32_t index) const177 int RdStatement::InnerBindBlobTypeArgs(const ValueObject &arg, uint32_t index) const
178 {
179 int ret = E_OK;
180 switch (arg.GetType()) {
181 case ValueObjectType::TYPE_BLOB: {
182 std::vector<uint8_t> blob;
183 arg.GetBlob(blob);
184 ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(blob.data()), blob.size(),
185 nullptr);
186 break;
187 }
188 case ValueObjectType::TYPE_BOOL: {
189 bool boolVal = false;
190 arg.GetBool(boolVal);
191 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, boolVal ? 1 : 0);
192 break;
193 }
194 case ValueObjectType::TYPE_ASSET: {
195 ValueObject::Asset asset;
196 arg.GetAsset(asset);
197 auto rawData = RawDataParser::PackageRawData(asset);
198 ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(rawData.data()),
199 rawData.size(), nullptr);
200 break;
201 }
202 case ValueObjectType::TYPE_ASSETS: {
203 ValueObject::Assets assets;
204 arg.GetAssets(assets);
205 auto rawData = RawDataParser::PackageRawData(assets);
206 ret = RdUtils::RdSqlBindBlob(stmtHandle_, index, static_cast<const void *>(rawData.data()),
207 rawData.size(), nullptr);
208 break;
209 }
210 case ValueObjectType::TYPE_VECS: {
211 ValueObject::FloatVector vectors;
212 arg.GetVecs(vectors);
213 ret = RdUtils::RdSqlBindFloatVector(stmtHandle_, index,
214 static_cast<float *>(vectors.data()), vectors.size(), nullptr);
215 break;
216 }
217 default: {
218 std::string str;
219 arg.GetString(str);
220 ret = RdUtils::RdSqlBindText(stmtHandle_, index, str.c_str(), str.length(), nullptr);
221 break;
222 }
223 }
224 return ret;
225 }
226
IsValid(int index) const227 int RdStatement::IsValid(int index) const
228 {
229 if (stmtHandle_ == nullptr) {
230 LOG_ERROR("statement already close.");
231 return E_ALREADY_CLOSED;
232 }
233 if (index < 0 || index >= columnCount_) {
234 LOG_ERROR("index (%{public}d) >= columnCount (%{public}d)", index, columnCount_);
235 return E_COLUMN_OUT_RANGE;
236 }
237 return E_OK;
238 }
239
Prepare(const std::string & sql)240 int32_t RdStatement::Prepare(const std::string& sql)
241 {
242 if (dbHandle_ == nullptr) {
243 return E_ERROR;
244 }
245 return Prepare(dbHandle_, sql);
246 }
247
Bind(const std::vector<ValueObject> & args)248 int32_t RdStatement::Bind(const std::vector<ValueObject>& args)
249 {
250 std::vector<std::reference_wrapper<ValueObject>> refArgs;
251 for (auto &object : args) {
252 refArgs.emplace_back(std::ref(const_cast<ValueObject&>(object)));
253 }
254 return Bind(refArgs);
255 }
256
Bind(const std::vector<std::reference_wrapper<ValueObject>> & args)257 int32_t RdStatement::Bind(const std::vector<std::reference_wrapper<ValueObject>>& args)
258 {
259 uint32_t index = 1;
260 int ret = E_OK;
261 for (auto &arg : args) {
262 switch (arg.get().GetType()) {
263 case ValueObjectType::TYPE_NULL: {
264 ret = RdUtils::RdSqlBindNull(stmtHandle_, index);
265 break;
266 }
267 case ValueObjectType::TYPE_INT: {
268 int64_t value = 0;
269 arg.get().GetLong(value);
270 ret = RdUtils::RdSqlBindInt64(stmtHandle_, index, value);
271 break;
272 }
273 case ValueObjectType::TYPE_DOUBLE: {
274 double doubleVal = 0;
275 arg.get().GetDouble(doubleVal);
276 ret = RdUtils::RdSqlBindDouble(stmtHandle_, index, doubleVal);
277 break;
278 }
279 default: {
280 ret = InnerBindBlobTypeArgs(arg, index);
281 break;
282 }
283 }
284 if (ret != E_OK) {
285 LOG_ERROR("bind ret is %{public}d", ret);
286 return ret;
287 }
288 index++;
289 }
290 return E_OK;
291 }
292
Count()293 std::pair<int32_t, int32_t> RdStatement::Count()
294 {
295 return { E_NOT_SUPPORT, INVALID_COUNT };
296 }
297
Step()298 int32_t RdStatement::Step()
299 {
300 if (stmtHandle_ == nullptr) {
301 return E_OK;
302 }
303 if (isStepInPrepare_ && stepCnt_ == 1) {
304 stepCnt_++;
305 return E_OK;
306 }
307 int ret = RdUtils::RdSqlStep(stmtHandle_);
308 if (ret == E_SQLITE_CORRUPT && config_ != nullptr) {
309 Reportor::ReportFault(Reportor::Create(*config_, ret));
310 }
311 stepCnt_++;
312 return ret;
313 }
314
Reset()315 int32_t RdStatement::Reset()
316 {
317 if (stmtHandle_ == nullptr) {
318 return E_OK;
319 }
320 stepCnt_ = 0;
321 isStepInPrepare_ = false;
322 return RdUtils::RdSqlReset(stmtHandle_);
323 }
324
Execute(const std::vector<ValueObject> & args)325 int32_t RdStatement::Execute(const std::vector<ValueObject>& args)
326 {
327 std::vector<std::reference_wrapper<ValueObject>> refArgs;
328 for (auto &object : args) {
329 refArgs.emplace_back(std::ref(const_cast<ValueObject&>(object)));
330 }
331 return Execute(refArgs);
332 }
333
Execute(const std::vector<std::reference_wrapper<ValueObject>> & args)334 int32_t RdStatement::Execute(const std::vector<std::reference_wrapper<ValueObject>>& args)
335 {
336 if (!readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
337 // It has already set version in prepare procedure
338 // Current modification is only temporary for unification between rd and sqlite,
339 // rd kernal will support pragma in later version
340 return E_OK;
341 }
342 int ret = Bind(args);
343 if (ret != E_OK) {
344 LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
345 return ret;
346 }
347 ret = Step();
348 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
349 LOG_ERROR("RdConnection Execute : err %{public}d", ret);
350 }
351 return ret;
352 }
353
ExecuteForValue(const std::vector<ValueObject> & args)354 std::pair<int, ValueObject> RdStatement::ExecuteForValue(const std::vector<ValueObject>& args)
355 {
356 int ret = E_OK;
357 if (readOnly_ && strcmp(sql_.c_str(), GlobalExpr::PRAGMA_VERSION) == 0) {
358 int version = 0;
359 ret = getPragmas_["user_version"](version);
360 if (ret != E_OK) {
361 LOG_ERROR("RdConnection unable to GetVersion : err %{public}d", ret);
362 return { ret, ValueObject() };
363 }
364 return { ret, ValueObject(version) };
365 }
366 ret = Bind(args);
367 if (ret != E_OK) {
368 LOG_ERROR("RdConnection unable to prepare and bind stmt : err %{public}d", ret);
369 return { ret, ValueObject() };
370 }
371 ret = Step();
372 if (ret != E_OK && ret != E_NO_MORE_ROWS) {
373 LOG_ERROR("RdConnection Execute : err %{public}d", ret);
374 return { ret, ValueObject() };
375 }
376 return GetColumn(0);
377 }
378
Changes() const379 int32_t RdStatement::Changes() const
380 {
381 return 0;
382 }
383
LastInsertRowId() const384 int64_t RdStatement::LastInsertRowId() const
385 {
386 return 0;
387 }
388
GetColumnCount() const389 int32_t RdStatement::GetColumnCount() const
390 {
391 return columnCount_;
392 }
393
GetColumnName(int32_t index) const394 std::pair<int32_t, std::string> RdStatement::GetColumnName(int32_t index) const
395 {
396 int ret = IsValid(index);
397 if (ret != E_OK) {
398 return { ret, "" };
399 }
400 const char* name = RdUtils::RdSqlColName(stmtHandle_, index);
401 if (name == nullptr) {
402 LOG_ERROR("column_name is null.");
403 return { E_ERROR, "" };
404 }
405 return { E_OK, name };
406 }
407
GetColumnType(int32_t index) const408 std::pair<int32_t, int32_t> RdStatement::GetColumnType(int32_t index) const
409 {
410 int ret = IsValid(index);
411 if (ret != E_OK) {
412 return { ret, static_cast<int32_t>(ColumnType::TYPE_NULL) };
413 }
414 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
415 switch (type) {
416 case ColumnType::TYPE_INTEGER:
417 case ColumnType::TYPE_FLOAT:
418 case ColumnType::TYPE_NULL:
419 case ColumnType::TYPE_STRING:
420 case ColumnType::TYPE_BLOB:
421 case ColumnType::TYPE_FLOAT32_ARRAY:
422 break;
423 default:
424 LOG_ERROR("invalid type %{public}d.", type);
425 return { E_ERROR, static_cast<int32_t>(ColumnType::TYPE_NULL) };
426 }
427 return { ret, static_cast<int32_t>(type) };
428 }
429
GetSize(int32_t index) const430 std::pair<int32_t, size_t> RdStatement::GetSize(int32_t index) const
431 {
432 int ret = IsValid(index);
433 if (ret != E_OK) {
434 return { ret, 0 };
435 }
436 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
437 if (type == ColumnType::TYPE_BLOB || type == ColumnType::TYPE_STRING || type == ColumnType::TYPE_NULL ||
438 type == ColumnType::TYPE_FLOAT32_ARRAY) {
439 return { E_OK, static_cast<size_t>(RdUtils::RdSqlColBytes(stmtHandle_, index)) };
440 }
441 return { E_INVALID_COLUMN_TYPE, 0 };
442 }
443
GetColumn(int32_t index) const444 std::pair<int32_t, ValueObject> RdStatement::GetColumn(int32_t index) const
445 {
446 ValueObject object;
447 int ret = IsValid(index);
448 if (ret != E_OK) {
449 return { ret, object };
450 }
451
452 ColumnType type = RdUtils::RdSqlColType(stmtHandle_, index);
453 switch (type) {
454 case ColumnType::TYPE_FLOAT:
455 object = RdUtils::RdSqlColDouble(stmtHandle_, index);
456 break;
457 case ColumnType::TYPE_INTEGER:
458 object = static_cast<int64_t>(RdUtils::RdSqlColInt64(stmtHandle_, index));
459 break;
460 case ColumnType::TYPE_STRING:
461 object = reinterpret_cast<const char *>(RdUtils::RdSqlColText(stmtHandle_, index));
462 break;
463 case ColumnType::TYPE_NULL:
464 break;
465 case ColumnType::TYPE_FLOAT32_ARRAY: {
466 uint32_t dim = 0;
467 auto vectors =
468 reinterpret_cast<const float *>(RdUtils::RdSqlColumnFloatVector(stmtHandle_, index, &dim));
469 std::vector<float> vecData;
470 if (dim > 0 || vectors != nullptr) {
471 vecData.resize(dim);
472 vecData.assign(vectors, vectors + dim);
473 }
474 object = std::move(vecData);
475 break;
476 }
477 case ColumnType::TYPE_BLOB: {
478 int size = RdUtils::RdSqlColBytes(stmtHandle_, index);
479 auto blob = static_cast<const uint8_t *>(RdUtils::RdSqlColBlob(stmtHandle_, index));
480 std::vector<uint8_t> rawData;
481 if (size > 0 || blob != nullptr) {
482 rawData.resize(size);
483 rawData.assign(blob, blob + size);
484 }
485 object = std::move(rawData);
486 break;
487 }
488 default:
489 break;
490 }
491 return { ret, std::move(object) };
492 }
493
ReadOnly() const494 bool RdStatement::ReadOnly() const
495 {
496 return readOnly_;
497 }
498
SupportBlockInfo() const499 bool RdStatement::SupportBlockInfo() const
500 {
501 return false;
502 }
503
FillBlockInfo(SharedBlockInfo * info) const504 int32_t RdStatement::FillBlockInfo(SharedBlockInfo* info) const
505 {
506 return E_NOT_SUPPORT;
507 }
508
GetProperties()509 void RdStatement::GetProperties()
510 {
511 columnCount_ = RdUtils::RdSqlColCnt(stmtHandle_);
512 }
513 } // namespace NativeRdb
514 } // namespace OHOS
515