1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "remote_executor_packet.h"
17 
18 namespace DistributedDB {
namespace {
    // Bit of the request packet flag tested by IsNeedResponse() and set by SetNeedResponse().
    constexpr uint8_t REQUEST_FLAG_RESPONSE_ACK = 0x01u;
    // Bit of the ack packet flag tested by IsLastAck() and set by SetLastAck().
    constexpr uint8_t ACK_FLAG_LAST_ACK = 0x01u;
    // Bit of the ack packet flag set by SetSecurityOption(); the ack DeSerialization()
    // only reads secLabel/secFlag from the parcel when this bit is present.
    constexpr uint8_t ACK_FLAG_SECURITY_OPTION = 0x02u;
}
24 
GetVersion() const25 uint32_t RemoteExecutorRequestPacket::GetVersion() const
26 {
27     return version_;
28 }
29 
SetVersion(uint32_t version)30 void RemoteExecutorRequestPacket::SetVersion(uint32_t version)
31 {
32     version_ = version;
33 }
34 
GetFlag() const35 uint32_t RemoteExecutorRequestPacket::GetFlag() const
36 {
37     return flag_;
38 }
39 
SetFlag(uint32_t flag)40 void RemoteExecutorRequestPacket::SetFlag(uint32_t flag)
41 {
42     flag_ = flag;
43 }
44 
GetPreparedStmt() const45 const PreparedStmt &RemoteExecutorRequestPacket::GetPreparedStmt() const
46 {
47     return preparedStmt_;
48 }
49 
IsNeedResponse() const50 bool RemoteExecutorRequestPacket::IsNeedResponse() const
51 {
52     return (flag_ & REQUEST_FLAG_RESPONSE_ACK) != 0;
53 }
54 
SetNeedResponse()55 void RemoteExecutorRequestPacket::SetNeedResponse()
56 {
57     flag_ |= REQUEST_FLAG_RESPONSE_ACK;
58 }
59 
SetExtraConditions(const std::map<std::string,std::string> & extraConditions)60 void RemoteExecutorRequestPacket::SetExtraConditions(const std::map<std::string, std::string> &extraConditions)
61 {
62     extraConditions_ = extraConditions;
63 }
64 
GetExtraConditions() const65 std::map<std::string, std::string> RemoteExecutorRequestPacket::GetExtraConditions() const
66 {
67     return extraConditions_;
68 }
69 
CalculateLen() const70 uint32_t RemoteExecutorRequestPacket::CalculateLen() const
71 {
72     uint32_t len = Parcel::GetUInt32Len(); // version
73     len += Parcel::GetUInt32Len();  // flag
74     uint32_t tmpLen = preparedStmt_.CalcLength();
75     if ((len + tmpLen) > static_cast<uint32_t>(INT32_MAX) || tmpLen == 0u) {
76         LOGE("[RemoteExecutorRequestPacket][CalculateLen] Prepared statement is too large");
77         return 0;
78     }
79     len += tmpLen;
80     len += Parcel::GetUInt32Len(); // conditions count
81     for (const auto &entry : extraConditions_) {
82         // each condition len never greater than 256
83         len += Parcel::GetStringLen(entry.first);
84         len += Parcel::GetStringLen(entry.second);
85         if (len > static_cast<uint32_t>(INT32_MAX)) {
86             LOGE("[RemoteExecutorRequestPacket][CalculateLen] conditions is too large");
87             return 0;
88         }
89     }
90     len = Parcel::GetEightByteAlign(len); // 8-byte align
91     len += Parcel::GetIntLen();
92     return len;
93 }
94 
Serialization(Parcel & parcel) const95 int RemoteExecutorRequestPacket::Serialization(Parcel &parcel) const
96 {
97     (void) parcel.WriteUInt32(version_);
98     (void) parcel.WriteUInt32(flag_);
99     (void) preparedStmt_.Serialize(parcel);
100     if (parcel.IsError()) {
101         LOGE("[RemoteExecutorRequestPacket] Serialization failed");
102         return -E_INVALID_ARGS;
103     }
104     if (extraConditions_.size() > DBConstant::MAX_CONDITION_COUNT) {
105         LOGE("[RemoteExecutorRequestPacket] Serialization failed with too much condition");
106         return -E_INVALID_ARGS;
107     }
108     parcel.WriteUInt32(static_cast<uint32_t>(extraConditions_.size()));
109     for (const auto &entry : extraConditions_) {
110         if (entry.first.length() > DBConstant::MAX_CONDITION_KEY_LEN ||
111             entry.second.length() > DBConstant::MAX_CONDITION_VALUE_LEN) {
112             LOGE("[RemoteExecutorRequestPacket] Serialization failed with too long key or value");
113             return -E_INVALID_ARGS;
114         }
115         parcel.WriteString(entry.first);
116         parcel.WriteString(entry.second);
117     }
118     parcel.EightByteAlign();
119     parcel.WriteInt(secLabel_);
120     if (parcel.IsError()) {
121         return -E_PARSE_FAIL;
122     }
123     return E_OK;
124 }
125 
DeSerialization(Parcel & parcel)126 int RemoteExecutorRequestPacket::DeSerialization(Parcel &parcel)
127 {
128     (void) parcel.ReadUInt32(version_);
129     (void) parcel.ReadUInt32(flag_);
130     (void) preparedStmt_.DeSerialize(parcel);
131     if (parcel.IsError()) {
132         LOGE("[RemoteExecutorRequestPacket] DeSerialization failed");
133         return -E_INVALID_ARGS;
134     }
135     if (version_ < REQUEST_PACKET_VERSION_V2) {
136         return E_OK;
137     }
138     uint32_t conditionSize = 0u;
139     (void) parcel.ReadUInt32(conditionSize);
140     if (conditionSize > DBConstant::MAX_CONDITION_COUNT) {
141         return -E_INVALID_ARGS;
142     }
143     for (uint32_t i = 0; i < conditionSize; i++) {
144         std::string conditionKey;
145         std::string conditionVal;
146         (void) parcel.ReadString(conditionKey);
147         (void) parcel.ReadString(conditionVal);
148         if (conditionKey.length() > DBConstant::MAX_CONDITION_KEY_LEN ||
149             conditionVal.length() > DBConstant::MAX_CONDITION_VALUE_LEN) {
150             return -E_INVALID_ARGS;
151         }
152         extraConditions_[conditionKey] = conditionVal;
153     }
154     parcel.EightByteAlign();
155     if (version_ >= REQUEST_PACKET_VERSION_V3) {
156         parcel.ReadInt(secLabel_);
157     }
158     if (parcel.IsError()) {
159         return -E_PARSE_FAIL;
160     }
161     return E_OK;
162 }
163 
SetOpCode(PreparedStmt::ExecutorOperation opCode)164 void RemoteExecutorRequestPacket::SetOpCode(PreparedStmt::ExecutorOperation opCode)
165 {
166     preparedStmt_.SetOpCode(opCode);
167 }
168 
SetSql(const std::string & sql)169 void RemoteExecutorRequestPacket::SetSql(const std::string &sql)
170 {
171     preparedStmt_.SetSql(sql);
172 }
173 
SetBindArgs(const std::vector<std::string> & bindArgs)174 void RemoteExecutorRequestPacket::SetBindArgs(const std::vector<std::string> &bindArgs)
175 {
176     preparedStmt_.SetBindArgs(bindArgs);
177 }
178 
SetSecLabel(int32_t secLabel)179 void RemoteExecutorRequestPacket::SetSecLabel(int32_t secLabel)
180 {
181     secLabel_ = secLabel;
182 }
183 
GetSecLabel() const184 int32_t RemoteExecutorRequestPacket::GetSecLabel() const
185 {
186     return secLabel_;
187 }
188 
Create()189 RemoteExecutorRequestPacket* RemoteExecutorRequestPacket::Create()
190 {
191     return new (std::nothrow) RemoteExecutorRequestPacket();
192 }
193 
Release(RemoteExecutorRequestPacket * & packet)194 void RemoteExecutorRequestPacket::Release(RemoteExecutorRequestPacket *&packet)
195 {
196     delete packet;
197     packet = nullptr;
198 }
199 
GetVersion() const200 uint32_t RemoteExecutorAckPacket::GetVersion() const
201 {
202     return version_;
203 }
204 
SetVersion(uint32_t version)205 void RemoteExecutorAckPacket::SetVersion(uint32_t version)
206 {
207     version_ = version;
208 }
209 
GetFlag() const210 uint32_t RemoteExecutorAckPacket::GetFlag() const
211 {
212     return flag_;
213 }
214 
SetFlag(uint32_t flag)215 void RemoteExecutorAckPacket::SetFlag(uint32_t flag)
216 {
217     flag_ = flag;
218 }
219 
GetAckCode() const220 int32_t RemoteExecutorAckPacket::GetAckCode() const
221 {
222     return ackCode_;
223 }
224 
SetAckCode(int32_t ackCode)225 void RemoteExecutorAckPacket::SetAckCode(int32_t ackCode)
226 {
227     ackCode_ = ackCode;
228 }
229 
MoveInRowDataSet(RelationalRowDataSet && rowDataSet)230 void RemoteExecutorAckPacket::MoveInRowDataSet(RelationalRowDataSet &&rowDataSet)
231 {
232     rowDataSet_ = std::move(rowDataSet);
233 }
234 
MoveOutRowDataSet() const235 RelationalRowDataSet &&RemoteExecutorAckPacket::MoveOutRowDataSet() const
236 {
237     return std::move(rowDataSet_);
238 }
239 
IsLastAck() const240 bool RemoteExecutorAckPacket::IsLastAck() const
241 {
242     return (flag_ & ACK_FLAG_LAST_ACK) != 0;
243 }
244 
SetLastAck()245 void RemoteExecutorAckPacket::SetLastAck()
246 {
247     flag_ |= ACK_FLAG_LAST_ACK;
248 }
249 
CalculateLen() const250 uint32_t RemoteExecutorAckPacket::CalculateLen() const
251 {
252     uint32_t len = Parcel::GetUInt32Len(); // version
253     len += Parcel::GetIntLen();    // ackCode
254     len += Parcel::GetUInt32Len();  // flag
255     len = Parcel::GetEightByteAlign(len);
256     len += static_cast<uint32_t>(rowDataSet_.CalcLength());
257     len += Parcel::GetIntLen(); // secLabel
258     len += Parcel::GetIntLen(); // secFlag
259     return len;
260 }
261 
Serialization(Parcel & parcel) const262 int RemoteExecutorAckPacket::Serialization(Parcel &parcel) const
263 {
264     (void) parcel.WriteUInt32(version_);
265     (void) parcel.WriteInt(ackCode_);
266     (void) parcel.WriteUInt32(flag_);
267     parcel.EightByteAlign();
268     (void) rowDataSet_.Serialize(parcel);
269     (void) parcel.WriteInt(secLabel_);
270     (void) parcel.WriteInt(secFlag_);
271     if (parcel.IsError()) {
272         LOGE("[RemoteExecutorAckPacket] Serialization failed");
273         return -E_INVALID_ARGS;
274     }
275     return E_OK;
276 }
277 
DeSerialization(Parcel & parcel)278 int RemoteExecutorAckPacket::DeSerialization(Parcel &parcel)
279 {
280     (void) parcel.ReadUInt32(version_);
281     (void) parcel.ReadInt(ackCode_);
282     (void) parcel.ReadUInt32(flag_);
283     parcel.EightByteAlign();
284     if (parcel.IsError()) {
285         LOGE("[RemoteExecutorAckPacket] DeSerialization failed");
286         return -E_INVALID_ARGS;
287     }
288     int errCode = rowDataSet_.DeSerialize(parcel);
289     if (errCode != E_OK) {
290         return errCode;
291     }
292     if ((flag_ & ACK_FLAG_SECURITY_OPTION) != 0) {
293         (void) parcel.ReadInt(secLabel_);
294         (void) parcel.ReadInt(secFlag_);
295     } else {
296         secLabel_ = NOT_SUPPORT_SEC_CLASSIFICATION;
297     }
298     if (parcel.IsError()) {
299         LOGE("[RemoteExecutorAckPacket] DeSerialization failed");
300         return -E_INVALID_ARGS;
301     }
302     return E_OK;
303 }
304 
GetSecurityOption() const305 SecurityOption RemoteExecutorAckPacket::GetSecurityOption() const
306 {
307     SecurityOption option = {secLabel_, secFlag_};
308     return option;
309 }
310 
SetSecurityOption(const SecurityOption & option)311 void RemoteExecutorAckPacket::SetSecurityOption(const SecurityOption &option)
312 {
313     secLabel_ = option.securityLabel;
314     secFlag_ = option.securityFlag;
315     flag_ |= ACK_FLAG_SECURITY_OPTION;
316 }
317 }