1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #ifdef RELATIONAL_STORE
16 #include "prepared_stmt.h"
17 #include "db_constant.h"
18
19 namespace DistributedDB {
PreparedStmt(ExecutorOperation opCode,const std::string & sql,const std::vector<std::string> & bindArgs)20 PreparedStmt::PreparedStmt(ExecutorOperation opCode, const std::string &sql, const std::vector<std::string> &bindArgs)
21 : opCode_(opCode), sql_(sql), bindArgs_(bindArgs) {}
22
23
SetOpCode(ExecutorOperation opCode)24 void PreparedStmt::SetOpCode(ExecutorOperation opCode)
25 {
26 opCode_ = opCode;
27 }
28
SetSql(std::string sql)29 void PreparedStmt::SetSql(std::string sql)
30 {
31 sql_ = std::move(sql);
32 }
33
SetBindArgs(std::vector<std::string> bindArgs)34 void PreparedStmt::SetBindArgs(std::vector<std::string> bindArgs)
35 {
36 bindArgs_ = std::move(bindArgs);
37 }
38
GetOpCode() const39 PreparedStmt::ExecutorOperation PreparedStmt::GetOpCode() const
40 {
41 return opCode_;
42 }
43
GetSql() const44 const std::string &PreparedStmt::GetSql() const
45 {
46 return sql_;
47 }
48
GetBindArgs() const49 const std::vector<std::string> &PreparedStmt::GetBindArgs() const
50 {
51 return bindArgs_;
52 }
53
IsValid() const54 bool PreparedStmt::IsValid() const
55 {
56 return opCode_ == ExecutorOperation::QUERY && !sql_.empty() && bindArgs_.size() <= DBConstant::MAX_SQL_ARGS_COUNT;
57 }
58
CalcLength() const59 uint32_t PreparedStmt::CalcLength() const
60 {
61 uint32_t length = Parcel::GetIntLen() + // current version
62 Parcel::GetIntLen() + // opcode_
63 Parcel::GetStringLen(sql_) + // sql_
64 Parcel::GetIntLen(); // bindArgs_.size()
65 for (const auto &bindArg : bindArgs_) {
66 length += Parcel::GetStringLen(bindArg); // bindArgs_
67 if (length > INT32_MAX) {
68 return 0u;
69 }
70 }
71 return Parcel::GetEightByteAlign(length);
72 }
73
74 // Before call this func. You should check if the object is valid.
Serialize(Parcel & parcel) const75 int PreparedStmt::Serialize(Parcel &parcel) const
76 {
77 // version
78 (void)parcel.WriteInt(CURRENT_VERSION);
79
80 // opcode
81 (void)parcel.WriteInt(static_cast<int>(opCode_));
82
83 // string
84 (void)parcel.WriteString(sql_);
85
86 // bindArgs
87 (void)parcel.WriteInt(static_cast<int>(bindArgs_.size()));
88 for (const auto &bindArg : bindArgs_) {
89 (void)parcel.WriteString(bindArg);
90 if (parcel.IsError()) {
91 return -E_PARSE_FAIL;
92 }
93 }
94
95 parcel.EightByteAlign();
96 if (parcel.IsError()) {
97 return -E_PARSE_FAIL;
98 }
99 return E_OK;
100 }
101
DeSerialize(Parcel & parcel)102 int PreparedStmt::DeSerialize(Parcel &parcel)
103 {
104 // clear the object
105 bindArgs_.clear();
106
107 // version
108 int version = 0;
109 (void)parcel.ReadInt(version);
110 if (parcel.IsError() || version <= 0 || version > CURRENT_VERSION) {
111 return -E_PARSE_FAIL;
112 }
113
114 // VERSION 1
115 if (version >= VERSION_1) {
116 // opcode
117 int opCode = 0;
118 (void)parcel.ReadInt(opCode);
119 if (parcel.IsError() || opCode <= MIN_LIMIT || opCode >= MAX_LIMIT) {
120 return -E_PARSE_FAIL;
121 }
122 opCode_ = static_cast<ExecutorOperation>(opCode);
123
124 // sql
125 (void)parcel.ReadString(sql_);
126
127 // bindArgs
128 int argsCount = 0;
129 (void)parcel.ReadInt(argsCount);
130 if (parcel.IsError() || argsCount < 0 || argsCount > static_cast<int>(DBConstant::MAX_SQL_ARGS_COUNT)) {
131 return -E_PARSE_FAIL;
132 }
133 for (int i = 0; i < argsCount; ++i) {
134 std::string bindArg;
135 (void)parcel.ReadString(bindArg);
136 if (parcel.IsError()) {
137 return -E_PARSE_FAIL;
138 }
139 bindArgs_.emplace_back(std::move(bindArg));
140 }
141 }
142
143 parcel.EightByteAlign();
144 if (parcel.IsError()) {
145 return -E_PARSE_FAIL;
146 }
147 return E_OK;
148 }
149 }
150 #endif