1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #ifdef RELATIONAL_STORE
16 #include "prepared_stmt.h"
17 #include "db_constant.h"
18 
19 namespace DistributedDB {
PreparedStmt(ExecutorOperation opCode,const std::string & sql,const std::vector<std::string> & bindArgs)20 PreparedStmt::PreparedStmt(ExecutorOperation opCode, const std::string &sql, const std::vector<std::string> &bindArgs)
21     : opCode_(opCode), sql_(sql), bindArgs_(bindArgs) {}
22 
23 
SetOpCode(ExecutorOperation opCode)24 void PreparedStmt::SetOpCode(ExecutorOperation opCode)
25 {
26     opCode_ = opCode;
27 }
28 
SetSql(std::string sql)29 void PreparedStmt::SetSql(std::string sql)
30 {
31     sql_ = std::move(sql);
32 }
33 
SetBindArgs(std::vector<std::string> bindArgs)34 void PreparedStmt::SetBindArgs(std::vector<std::string> bindArgs)
35 {
36     bindArgs_ = std::move(bindArgs);
37 }
38 
GetOpCode() const39 PreparedStmt::ExecutorOperation PreparedStmt::GetOpCode() const
40 {
41     return opCode_;
42 }
43 
GetSql() const44 const std::string &PreparedStmt::GetSql() const
45 {
46     return sql_;
47 }
48 
GetBindArgs() const49 const std::vector<std::string> &PreparedStmt::GetBindArgs() const
50 {
51     return bindArgs_;
52 }
53 
IsValid() const54 bool PreparedStmt::IsValid() const
55 {
56     return opCode_ == ExecutorOperation::QUERY && !sql_.empty() && bindArgs_.size() <= DBConstant::MAX_SQL_ARGS_COUNT;
57 }
58 
CalcLength() const59 uint32_t PreparedStmt::CalcLength() const
60 {
61     uint32_t length = Parcel::GetIntLen() +         // current version
62                       Parcel::GetIntLen() +         // opcode_
63                       Parcel::GetStringLen(sql_) +  // sql_
64                       Parcel::GetIntLen();          // bindArgs_.size()
65     for (const auto &bindArg : bindArgs_) {
66         length += Parcel::GetStringLen(bindArg);    // bindArgs_
67         if (length > INT32_MAX) {
68             return 0u;
69         }
70     }
71     return Parcel::GetEightByteAlign(length);
72 }
73 
74 // Before call this func. You should check if the object is valid.
Serialize(Parcel & parcel) const75 int PreparedStmt::Serialize(Parcel &parcel) const
76 {
77     // version
78     (void)parcel.WriteInt(CURRENT_VERSION);
79 
80     // opcode
81     (void)parcel.WriteInt(static_cast<int>(opCode_));
82 
83     // string
84     (void)parcel.WriteString(sql_);
85 
86     // bindArgs
87     (void)parcel.WriteInt(static_cast<int>(bindArgs_.size()));
88     for (const auto &bindArg : bindArgs_) {
89         (void)parcel.WriteString(bindArg);
90         if (parcel.IsError()) {
91             return -E_PARSE_FAIL;
92         }
93     }
94 
95     parcel.EightByteAlign();
96     if (parcel.IsError()) {
97         return -E_PARSE_FAIL;
98     }
99     return E_OK;
100 }
101 
DeSerialize(Parcel & parcel)102 int PreparedStmt::DeSerialize(Parcel &parcel)
103 {
104     // clear the object
105     bindArgs_.clear();
106 
107     // version
108     int version = 0;
109     (void)parcel.ReadInt(version);
110     if (parcel.IsError() || version <= 0 || version > CURRENT_VERSION) {
111         return -E_PARSE_FAIL;
112     }
113 
114     // VERSION 1
115     if (version >= VERSION_1) {
116         // opcode
117         int opCode = 0;
118         (void)parcel.ReadInt(opCode);
119         if (parcel.IsError() || opCode <= MIN_LIMIT || opCode >= MAX_LIMIT) {
120             return -E_PARSE_FAIL;
121         }
122         opCode_ = static_cast<ExecutorOperation>(opCode);
123 
124         // sql
125         (void)parcel.ReadString(sql_);
126 
127         // bindArgs
128         int argsCount = 0;
129         (void)parcel.ReadInt(argsCount);
130         if (parcel.IsError() || argsCount < 0 || argsCount > static_cast<int>(DBConstant::MAX_SQL_ARGS_COUNT)) {
131             return -E_PARSE_FAIL;
132         }
133         for (int i = 0; i < argsCount; ++i) {
134             std::string bindArg;
135             (void)parcel.ReadString(bindArg);
136             if (parcel.IsError()) {
137                 return -E_PARSE_FAIL;
138             }
139             bindArgs_.emplace_back(std::move(bindArg));
140         }
141     }
142 
143     parcel.EightByteAlign();
144     if (parcel.IsError()) {
145         return -E_PARSE_FAIL;
146     }
147     return E_OK;
148 }
149 }
150 #endif