1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "protocol_proto.h"
#include <algorithm>
#include <cstring>
#include <iterator>
#include <memory>
#include <mutex>
#include <new>
#include "db_common.h"
#include "endian_convert.h"
#include "hash.h"
#include "header_converter.h"
#include "log_print.h"
#include "macro_utils.h"
#include "securec.h"
#include "version.h"
28
29 namespace DistributedDB {
30 namespace {
31 const uint16_t MAGIC_CODE = 0xAAAA;
32 const uint16_t PROTOCOL_VERSION = 0;
33 // Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
34 const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
35 const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
36 const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
37 const uint8_t MAX_PADDING_LEN = 7;
38 const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
39 const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
40 const uint16_t MIN_FRAGMENT_COUNT = 2; // At least a frame will be splited into 2 parts
41 // LabelExchange(Ack) Frame Field Length
42 const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
43 const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
44 const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
45 // Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
46 const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
47 // Local func to set and get frame Type from packet Type field
SetFrameType(FrameType inFrameType,uint8_t & inPacketType)48 void SetFrameType(FrameType inFrameType, uint8_t &inPacketType)
49 {
50 inPacketType &= 0x0F; // Use 0x0F to clear high four bits
51 inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
52 }
53
GetFrameType(uint8_t inPacketType)54 FrameType GetFrameType(uint8_t inPacketType)
55 {
56 uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0xF0 to get high 4 bits
57 if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
58 return FrameType::INVALID_MAX_FRAME_TYPE;
59 }
60 return static_cast<FrameType>(frameType);
61 }
62
IsSendLabelExchange(uint8_t inPacketType)63 bool IsSendLabelExchange(uint8_t inPacketType)
64 {
65 return ((inPacketType & 0x08) >> 3) == 0; // Use 0x08 and remove low 3 bit, it is Communication negotiation mark
66 }
67
SetSendLabelExchange(uint8_t & inPacketType,bool sendLabelExchange)68 void SetSendLabelExchange(uint8_t &inPacketType, bool sendLabelExchange)
69 {
70 if (!sendLabelExchange) {
71 inPacketType |= 0x08; // mark 0x08 when not support communication
72 }
73 }
74 }
75
76 std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
77 std::shared_mutex ProtocolProto::msgIdMutex_;
78
GetAppLayerFrameHeaderLength()79 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
80 {
81 uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
82 return length;
83 }
84
GetLengthBeforeSerializedData()85 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
86 {
87 uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
88 return length;
89 }
90
GetCommLayerFrameHeaderLength()91 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
92 {
93 uint32_t length = sizeof(CommPhyHeader);
94 return length;
95 }
96
ToSerialBuffer(const Message * inMsg,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader,int & outErrorNo)97 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg,
98 std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader, int &outErrorNo)
99 {
100 if (inMsg == nullptr) {
101 outErrorNo = -E_INVALID_ARGS;
102 return nullptr;
103 }
104
105 uint32_t serializeLen = 0;
106 if (!onlyMsgHeader) {
107 int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
108 if (errCode != E_OK) {
109 outErrorNo = errCode;
110 return nullptr;
111 }
112 }
113 uint32_t headSize = 0;
114 int errCode = GetExtendHeadDataSize(extendHandle, headSize);
115 if (errCode != E_OK) {
116 outErrorNo = errCode;
117 return nullptr;
118 }
119
120 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
121 if (buffer == nullptr) {
122 outErrorNo = -E_OUT_OF_MEMORY;
123 return nullptr;
124 }
125 if (headSize > 0) {
126 buffer->SetExtendHeadLength(headSize);
127 }
128 // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
129 uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
130 errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
131 if (errCode != E_OK) {
132 LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
133 goto ERROR_HANDLE;
134 }
135 errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
136 if (errCode != E_OK) {
137 goto ERROR_HANDLE;
138 }
139
140 // Serialize the MessageHeader and data if need
141 errCode = SerializeMessage(buffer, inMsg);
142 if (errCode != E_OK) {
143 LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
144 goto ERROR_HANDLE;
145 }
146 outErrorNo = E_OK;
147 return buffer;
148 ERROR_HANDLE:
149 outErrorNo = errCode;
150 delete buffer;
151 buffer = nullptr;
152 return nullptr;
153 }
154
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)155 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
156 {
157 if (inBuff == nullptr) {
158 outErrorNo = -E_INVALID_ARGS;
159 return nullptr;
160 }
161 Message *outMsg = new (std::nothrow) Message();
162 if (outMsg == nullptr) {
163 outErrorNo = -E_OUT_OF_MEMORY;
164 return nullptr;
165 }
166 int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
167 if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
168 LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
169 outErrorNo = errCode;
170 delete outMsg;
171 outMsg = nullptr;
172 return nullptr;
173 }
174 // If messageId not register in this software version, we return errCode and the Message without an object.
175 outErrorNo = errCode;
176 return outMsg;
177 }
178
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)179 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
180 {
181 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
182 if (buffer == nullptr) {
183 outErrorNo = -E_OUT_OF_MEMORY;
184 return nullptr;
185 }
186
187 // Empty frame has no payload, only header
188 int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
189 if (errCode != E_OK) {
190 LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
191 outErrorNo = errCode;
192 delete buffer;
193 buffer = nullptr;
194 return nullptr;
195 }
196 outErrorNo = E_OK;
197 return buffer;
198 }
199
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)200 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
201 int &outErrorNo)
202 {
203 std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
204 SerialBuffer *buffer = ToSerialBuffer(inMsg, extendHandle, true, outErrorNo);
205 if (buffer == nullptr) {
206 // outErrorNo had already been set in ToSerialBuffer
207 return nullptr;
208 }
209 int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
210 if (errCode != E_OK) {
211 LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%.3s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
212 outErrorNo = errCode;
213 delete buffer;
214 buffer = nullptr;
215 return nullptr;
216 }
217 outErrorNo = E_OK;
218 return buffer;
219 }
220
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)221 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
222 const std::set<LabelType> &inLabels, int &outErrorNo)
223 {
224 // Size of inLabels won't be too large.
225 // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
226 uint64_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
227 inLabels.size() * COMM_LABEL_LENGTH;
228 if (payloadLen > INT32_MAX) {
229 outErrorNo = -E_INVALID_ARGS;
230 return nullptr;
231 }
232 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
233 if (buffer == nullptr) {
234 outErrorNo = -E_OUT_OF_MEMORY;
235 return nullptr;
236 }
237 int errCode = buffer->AllocBufferByPayloadLength(static_cast<uint32_t>(payloadLen),
238 GetCommLayerFrameHeaderLength());
239 if (errCode != E_OK) {
240 LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
241 outErrorNo = errCode;
242 delete buffer;
243 buffer = nullptr;
244 return nullptr;
245 }
246
247 auto payloadByteLen = buffer->GetWritableBytesForPayload();
248 auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
249 *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
250 fieldPtr++;
251 *fieldPtr = HostToNet(inDistinctValue);
252 fieldPtr++;
253 *fieldPtr = HostToNet(inSequenceId);
254 fieldPtr++;
255 *fieldPtr = HostToNet(static_cast<uint64_t>(inLabels.size()));
256 fieldPtr++;
257 // Note: don't worry, memory length had been carefully calculated above
258 auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
259 for (const auto &eachLabel : inLabels) {
260 for (const auto &eachByte : eachLabel) {
261 *bytePtr++ = eachByte;
262 }
263 }
264 outErrorNo = E_OK;
265 return buffer;
266 }
267
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)268 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
269 {
270 uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
271 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
272 if (buffer == nullptr) {
273 outErrorNo = -E_OUT_OF_MEMORY;
274 return nullptr;
275 }
276 int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
277 if (errCode != E_OK) {
278 LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
279 outErrorNo = errCode;
280 delete buffer;
281 buffer = nullptr;
282 return nullptr;
283 }
284
285 auto payloadByteLen = buffer->GetWritableBytesForPayload();
286 auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
287 *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
288 fieldPtr++;
289 *fieldPtr = HostToNet(inDistinctValue);
290 fieldPtr++;
291 *fieldPtr = HostToNet(inSequenceId);
292 fieldPtr++;
293 outErrorNo = E_OK;
294 return buffer;
295 }
296
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)297 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
298 std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
299 {
300 auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
301 if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
302 return E_OK;
303 }
304 uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
305 // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
306 auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
307 uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
308 // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
309 uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
310 // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
311 // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
312 uint16_t quotient = lengthToSplit / maxFragmentLen;
313 uint32_t remainder = lengthToSplit % maxFragmentLen;
314 // Finally we get the fragCount for this frame
315 uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
316 // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
317 auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
318 FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
319 return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
320 }
321
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)322 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
323 {
324 uint32_t frameLen = inResult.GetFrameLen();
325 uint16_t fragCount = inResult.GetFragCount();
326 uint16_t fragNo = inResult.GetFragNo();
327
328 // Firstly: Check frameLen
329 if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
330 LOGE("[Proto][ParsePhyOpt] FrameLen=%" PRIu32 " illegal.", frameLen);
331 return -E_PARSE_FAIL;
332 }
333
334 // Secondly: Check fragCount and fragNo
335 uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
336 if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
337 LOGE("[Proto][ParsePhyOpt] FragCount=%" PRIu32 " or fragNo=%" PRIu32 " illegal.", fragCount, fragNo);
338 return -E_PARSE_FAIL;
339 }
340
341 // Finally: Check length relation deeply
342 uint32_t quotient = lengthBeSplit / fragCount;
343 uint16_t remainder = lengthBeSplit % fragCount;
344 outFragLen = quotient;
345 outLastFragLen = quotient + remainder;
346 uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
347 if ((sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
348 inResult.GetPaddingLen()) != inResult.GetPacketLen()) {
349 LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%" PRIu32 ", FragCount=%" PRIu32 ", fragNo=%" PRIu32
350 ", PaddingLen=%" PRIu32 ", PacketLen=%" PRIu32, frameLen, fragCount, fragNo, inResult.GetPaddingLen(),
351 inResult.GetPacketLen());
352 return -E_PARSE_FAIL;
353 }
354
355 return E_OK;
356 }
357
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)358 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
359 uint32_t fragOffset, uint32_t fragLength)
360 {
361 // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
362 // Firstly: Check the length relation of source, even this check is not supposed to fail
363 if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
364 return -E_LENGTH_ERROR;
365 }
366 // Secondly: Check the length relation of destination, even this check is not supposed to fail
367 auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
368 if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
369 return -E_LENGTH_ERROR;
370 }
371 // Finally: Do Combination!
372 const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
373 uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
374 uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
375 errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
376 if (errCode != EOK) {
377 return -E_SECUREC_ERROR;
378 }
379 return E_OK;
380 }
381
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)382 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
383 {
384 std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
385 if (msgIdMapFunc_.count(msgId) != 0) {
386 return -E_ALREADY_REGISTER;
387 }
388 if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
389 return -E_INVALID_ARGS;
390 }
391 msgIdMapFunc_[msgId] = inFunc;
392 return E_OK;
393 }
394
UnRegTransformFunction(uint32_t msgId)395 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
396 {
397 std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
398 if (msgIdMapFunc_.count(msgId) != 0) {
399 msgIdMapFunc_.erase(msgId);
400 }
401 }
402
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)403 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
404 {
405 if (inBuff == nullptr) {
406 return -E_INVALID_ARGS;
407 }
408 auto headerByteLen = inBuff->GetWritableBytesForHeader();
409 if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
410 return -E_INVALID_ARGS;
411 }
412 auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
413
414 CommDivergeHeader divergeHeader;
415 divergeHeader.version = PROTOCOL_VERSION;
416 divergeHeader.reserved = 0;
417 divergeHeader.payLoadLen = payloadByteLen.second;
418 // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
419 for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
420 divergeHeader.commLabel[i] = inCommLabel[i];
421 }
422 HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
423
424 errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
425 headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
426 if (errCode != EOK) {
427 return -E_SECUREC_ERROR;
428 }
429 return E_OK;
430 }
431
432 namespace {
FillPhyHeaderLenInfo(uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen,CommPhyHeader & header)433 void FillPhyHeaderLenInfo(uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen, CommPhyHeader &header)
434 {
435 header.packetLen = packetLen;
436 header.checkSum = sum;
437 header.packetType |= type;
438 header.paddingLen = paddingLen;
439 }
440 }
441
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)442 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
443 {
444 if (inBuff == nullptr) {
445 return -E_INVALID_ARGS;
446 }
447 auto headerByteLen = inBuff->GetWritableBytesForHeader();
448 if (headerByteLen.second < sizeof(CommPhyHeader)) {
449 return -E_INVALID_ARGS;
450 }
451 auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
452 auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
453
454 uint32_t packetLen = bufferByteLen.second;
455 uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
456 uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
457 if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
458 SetFrameType(inInfo.frameType, packetType);
459 } else {
460 return -E_INVALID_ARGS;
461 }
462 SetSendLabelExchange(packetType, inInfo.sendLabelExchange);
463
464 CommPhyHeader phyHeader;
465 phyHeader.magic = MAGIC_CODE;
466 phyHeader.version = PROTOCOL_VERSION;
467 phyHeader.sourceId = inInfo.sourceId;
468 phyHeader.frameId = inInfo.frameId;
469 phyHeader.packetType = 0;
470 phyHeader.dbIntVer = DB_GLOBAL_VERSION;
471 FillPhyHeaderLenInfo(packetLen, 0, packetType, paddingLen, phyHeader); // Sum is calculated afterwards
472 HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
473
474 errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
475 if (retCode != EOK) {
476 return -E_SECUREC_ERROR;
477 }
478
479 uint64_t sumResult = 0;
480 int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
481 bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
482 if (errCode != E_OK) {
483 return -E_SUM_CALCULATE_FAIL;
484 }
485
486 auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
487 ptrPhyHeader->checkSum = HostToNet(sumResult);
488
489 return E_OK;
490 }
491
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)492 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
493 ParseResult &outResult)
494 {
495 if (bytes == nullptr || length > MAX_TOTAL_LEN) {
496 return -E_INVALID_ARGS;
497 }
498 int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
499 if (errCode != E_OK) {
500 LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
501 return errCode;
502 }
503
504 if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
505 return E_OK; // Do nothing more for empty frame
506 }
507
508 if (outResult.IsFragment()) {
509 errCode = ParseCommPhyOptHeader(bytes, length, outResult);
510 if (errCode != E_OK) {
511 LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
512 }
513 } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
514 errCode = ParseCommLayerPayload(bytes, length, outResult);
515 if (errCode != E_OK) {
516 LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
517 }
518 } else {
519 errCode = ParseCommDivergeHeader(bytes, length, outResult);
520 if (errCode != E_OK) {
521 LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
522 }
523 }
524 return errCode;
525 }
526
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)527 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
528 {
529 if (inBuff == nullptr || outResult.IsFragment()) {
530 return -E_INTERNAL_ERROR;
531 }
532 auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
533 if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
534 int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
535 if (errCode != E_OK) {
536 LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
537 return errCode;
538 }
539 } else {
540 int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
541 if (errCode != E_OK) {
542 LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
543 return errCode;
544 }
545 }
546 return E_OK;
547 }
548
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)549 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
550 {
551 static const char *frameTypeStr[] = {
552 "EmptyFrame",
553 "AppLayerFrame",
554 "CommLayerFrame_LabelExchange",
555 "CommLayerFrame_LabelExchangeAck"
556 };
557
558 if (length < sizeof(CommPhyHeader)) {
559 return;
560 }
561 auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
562 uint32_t frameId = NetToHost(phyHeader->frameId);
563 uint8_t pktType = NetToHost(phyHeader->packetType);
564 bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
565 FrameType frameType = GetFrameType(pktType);
566 if (frameType >= FrameType::INVALID_MAX_FRAME_TYPE) {
567 LOGW("[Proto][Display] This is unrecognized frame, pktType=%" PRIu8 ".", pktType);
568 return;
569 }
570 if (isFragment) {
571 if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
572 return;
573 }
574 auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
575 LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ", frameLen=%" PRIu32 ", fragCount=%" PRIu32
576 ", fragNo=%" PRIu32 ".", frameTypeStr[static_cast<int32_t>(frameType)],
577 frameId, NetToHost(phyOpt->frameLen),
578 NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
579 } else {
580 LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ".",
581 frameTypeStr[static_cast<int32_t>(frameType)], frameId);
582 }
583 }
584
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)585 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
586 {
587 if ((length > INT32_MAX) || (length % sizeof(uint64_t) != 0)) {
588 LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight or larget than int32_max.", length);
589 return -E_LENGTH_ERROR;
590 }
591 int count = length / sizeof(uint64_t);
592 auto array = reinterpret_cast<const uint64_t *>(bytes);
593 outSum = 0;
594 for (int i = 0; i < count; i++) {
595 outSum ^= array[i];
596 }
597 return E_OK;
598 }
599
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)600 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
601 {
602 uint32_t messageId = inMsg->GetMessageId();
603 TransformFunc function;
604 if (GetTransformFunc(messageId, function) != E_OK) {
605 LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%" PRIu32 ".", messageId);
606 return -E_NOT_REGISTER;
607 }
608
609 uint32_t serializeLen = function.computeFunc(inMsg);
610 uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
611 // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
612 if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
613 LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%" PRIu32 ", serializeLen=%" PRIu32
614 ", alignedLen=%" PRIu32 ".", messageId, serializeLen, alignedLen);
615 return -E_LENGTH_ERROR;
616 }
617 // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
618 outLength = serializeLen;
619 return E_OK;
620 }
621
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)622 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
623 {
624 auto payloadByteLen = inBuff->GetWritableBytesForPayload();
625 if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
626 LOGE("[Proto][Serialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
627 return -E_LENGTH_ERROR;
628 }
629 uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
630
631 auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
632 messageHdr->version = inMsg->GetVersion();
633 messageHdr->messageType = inMsg->GetMessageType();
634 messageHdr->messageId = inMsg->GetMessageId();
635 messageHdr->sessionId = inMsg->GetSessionId();
636 messageHdr->sequenceId = inMsg->GetSequenceId();
637 messageHdr->errorNo = inMsg->GetErrorNo();
638 messageHdr->dataLen = dataLen;
639 HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
640
641 if (dataLen == 0) {
642 // For zero dataLen, we don't need to serialize data part
643 return E_OK;
644 }
645 // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
646 TransformFunc function;
647 if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
648 LOGE("[Proto][Serialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
649 return -E_NOT_REGISTER;
650 }
651 int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
652 if (result != E_OK) {
653 LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
654 return -E_SERIALIZE_ERROR;
655 }
656 return E_OK;
657 }
658
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)659 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
660 {
661 auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
662 // Check version before parse field
663 if (payloadByteLen.second < sizeof(uint16_t)) {
664 return -E_LENGTH_ERROR;
665 }
666 uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
667 if (!IsSupportMessageVersion(version)) {
668 LOGE("[Proto][DeSerialize] Version=%" PRIu32 " not support.", version);
669 return -E_VERSION_NOT_SUPPORT;
670 }
671
672 if (payloadByteLen.second < sizeof(MessageHeader)) {
673 LOGE("[Proto][DeSerialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
674 return -E_LENGTH_ERROR;
675 }
676 auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
677 MessageHeader messageHdr;
678 HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
679 inMsg->SetVersion(version);
680 inMsg->SetMessageType(messageHdr.messageType);
681 inMsg->SetMessageId(messageHdr.messageId);
682 inMsg->SetSessionId(messageHdr.sessionId);
683 inMsg->SetSequenceId(messageHdr.sequenceId);
684 inMsg->SetErrorNo(messageHdr.errorNo);
685 uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
686 if (dataLen != messageHdr.dataLen) {
687 LOGE("[Proto][DeSerialize] dataLen=%" PRIu32 ", msgDataLen=%" PRIu32 ".", dataLen, messageHdr.dataLen);
688 return -E_LENGTH_ERROR;
689 }
690 // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
691 if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
692 LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%" PRIu32 ".", messageHdr.errorNo);
693 return E_OK;
694 }
695 if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
696 return E_OK;
697 }
698 TransformFunc function;
699 if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
700 LOGE("[Proto][DeSerialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
701 return -E_NOT_REGISTER;
702 }
703 int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
704 if (result != E_OK) {
705 LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
706 return -E_DESERIALIZE_ERROR;
707 }
708 return E_OK;
709 }
710
IsSupportMessageVersion(uint16_t version)711 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
712 {
713 return (version == MSG_VERSION_BASE || version == MSG_VERSION_EXT);
714 }
715
IsFeedbackErrorMessage(uint32_t errorNo)716 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
717 {
718 return (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND);
719 }
720
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)721 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
722 {
723 // At least magic and version should exist
724 if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
725 LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
726 return -E_LENGTH_ERROR;
727 }
728 auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
729 uint16_t magic = NetToHost(*fieldPtr++);
730 uint16_t version = NetToHost(*fieldPtr++);
731
732 if (magic != MAGIC_CODE) {
733 LOGE("[Proto][ParsePhyCheckVer] MagicCode=%" PRIu32 " Error.", magic);
734 return -E_PARSE_FAIL;
735 }
736 if (version != PROTOCOL_VERSION) {
737 LOGE("[Proto][ParsePhyCheckVer] Version=%" PRIu32 " Error.", version);
738 return -E_VERSION_NOT_SUPPORT;
739 }
740 return E_OK;
741 }
742
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)743 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
744 const uint8_t *bytes, uint32_t length)
745 {
746 if (phyHeader.packetLen != length) {
747 LOGE("[Proto][ParsePhyCheck] PacketLen=%" PRIu32 " Mismatch length=%" PRIu32 ".", phyHeader.packetLen, length);
748 return -E_PARSE_FAIL;
749 }
750 if (phyHeader.paddingLen > MAX_PADDING_LEN) {
751 LOGE("[Proto][ParsePhyCheck] PaddingLen=%" PRIu32 " Error.", phyHeader.paddingLen);
752 return -E_PARSE_FAIL;
753 }
754 if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
755 LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
756 return -E_PARSE_FAIL;
757 }
758 uint64_t sumResult = 0;
759 int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
760 if (errCode != E_OK) {
761 LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
762 return -E_SUM_CALCULATE_FAIL;
763 }
764 if (phyHeader.checkSum != sumResult) {
765 LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%" PRIu64 ", sumResult=%" PRIu64 ".",
766 ULL(phyHeader.checkSum), ULL(sumResult));
767 return -E_SUM_MISMATCH;
768 }
769 return E_OK;
770 }
771
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)772 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
773 ParseResult &inResult)
774 {
775 int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
776 if (errCode != E_OK) {
777 LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
778 return errCode;
779 }
780
781 if (length < sizeof(CommPhyHeader)) {
782 LOGE("[Proto][ParsePhy] Length of Bytes Error.");
783 return -E_PARSE_FAIL;
784 }
785 auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
786 CommPhyHeader phyHeader;
787 HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
788 errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
789 if (errCode != E_OK) {
790 LOGE("[Proto][ParsePhy] Check Field Fail.");
791 return errCode;
792 }
793
794 inResult.SetFrameId(phyHeader.frameId);
795 inResult.SetSourceId(phyHeader.sourceId);
796 inResult.SetPacketLen(phyHeader.packetLen);
797 inResult.SetPaddingLen(phyHeader.paddingLen);
798 inResult.SetDbVersion(phyHeader.dbIntVer);
799 if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
800 inResult.SetFragmentFlag(true);
801 } // FragmentFlag default is false
802 FrameType frameType = GetFrameType(phyHeader.packetType);
803 if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
804 LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%" PRIu32 ".", phyHeader.packetType);
805 return -E_FRAME_TYPE_NOT_SUPPORT;
806 }
807 inResult.SetFrameTypeInfo(frameType);
808 inResult.SetSendLabelExchange(IsSendLabelExchange(phyHeader.packetType));
809 return E_OK;
810 }
811
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)812 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
813 {
814 if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
815 LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
816 return -E_LENGTH_ERROR;
817 }
818 auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
819 CommPhyOptHeader phyOptHeader;
820 HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
821
822 // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
823 inResult.SetFrameLen(phyOptHeader.frameLen);
824 inResult.SetFragCount(phyOptHeader.fragCount);
825 inResult.SetFragNo(phyOptHeader.fragNo);
826 return E_OK;
827 }
828
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)829 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
830 {
831 // Check version before parse field
832 if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
833 return -E_LENGTH_ERROR;
834 }
835 uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
836 if (version != PROTOCOL_VERSION) {
837 LOGE("[Proto][ParseDiverge] Version=%" PRIu16 " not support.", version);
838 return -E_VERSION_NOT_SUPPORT;
839 }
840
841 if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
842 LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
843 return -E_PARSE_FAIL;
844 }
845 auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
846 CommDivergeHeader divergeHeader;
847 HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
848 if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
849 inResult.GetPaddingLen() != inResult.GetPacketLen()) {
850 LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
851 return -E_PARSE_FAIL;
852 }
853 inResult.SetPayloadLen(divergeHeader.payLoadLen);
854 inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
855 return E_OK;
856 }
857
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)858 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
859 {
860 if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
861 int errCode = ParseLabelExchangeAck(bytes, length, inResult);
862 if (errCode != E_OK) {
863 LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
864 return errCode;
865 }
866 } else {
867 int errCode = ParseLabelExchange(bytes, length, inResult);
868 if (errCode != E_OK) {
869 LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
870 return errCode;
871 }
872 }
873 return E_OK;
874 }
875
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)876 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
877 {
878 // Check version at very first
879 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
880 return -E_LENGTH_ERROR;
881 }
882 auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
883 uint64_t version = NetToHost(*fieldPtr++);
884 if (version != PROTOCOL_VERSION) {
885 LOGE("[Proto][ParseLabel] Version=%" PRIu64 " not support.", ULL(version));
886 return -E_VERSION_NOT_SUPPORT;
887 }
888
889 // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
890 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
891 LOGE("[Proto][ParseLabel] Length of Bytes Error.");
892 return -E_LENGTH_ERROR;
893 }
894 uint64_t distinctValue = NetToHost(*fieldPtr++);
895 inResult.SetLabelExchangeDistinctValue(distinctValue);
896 uint64_t sequenceId = NetToHost(*fieldPtr++);
897 inResult.SetLabelExchangeSequenceId(sequenceId);
898 uint64_t commLabelCount = NetToHost(*fieldPtr++);
899 if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
900 LOGE("[Proto][ParseLabel] commLabelCount=%" PRIu64 " invalid.", ULL(commLabelCount));
901 return -E_PARSE_FAIL;
902 }
903 // commLabelCount is expected to be not very large
904 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
905 commLabelCount * COMM_LABEL_LENGTH) {
906 LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%" PRIu64, ULL(commLabelCount));
907 return -E_LENGTH_ERROR;
908 }
909
910 // Get each commLabel
911 std::set<LabelType> commLabels;
912 auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
913 for (uint64_t i = 0; i < commLabelCount; i++) {
914 // the length is checked just above
915 LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
916 if (commLabels.count(commLabel) != 0) {
917 LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%.3s.", VEC_TO_STR(commLabel));
918 } else {
919 commLabels.insert(commLabel);
920 }
921 }
922 inResult.SetLatestCommLabels(commLabels);
923 return E_OK;
924 }
925
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)926 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
927 {
928 // Check version at very first
929 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
930 return -E_LENGTH_ERROR;
931 }
932 auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
933 uint64_t version = NetToHost(*fieldPtr++);
934 if (version != PROTOCOL_VERSION) {
935 LOGE("[Proto][ParseLabelAck] Version=%" PRIu64 " not support.", ULL(version));
936 return -E_VERSION_NOT_SUPPORT;
937 }
938
939 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
940 LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
941 return -E_LENGTH_ERROR;
942 }
943 uint64_t distinctValue = NetToHost(*fieldPtr++);
944 inResult.SetLabelExchangeDistinctValue(distinctValue);
945 uint64_t sequenceId = NetToHost(*fieldPtr++);
946 inResult.SetLabelExchangeSequenceId(sequenceId);
947 return E_OK;
948 }
949
950 // Note: framePhyHeader is in network endian
951 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)952 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
953 const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
954 {
955 // It can be guaranteed that fragCount >= 2 and also won't be too large
956 if (fragmentInfo.fragCount < MIN_FRAGMENT_COUNT) {
957 return -E_INVALID_ARGS;
958 }
959 outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
960 uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
961 uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
962 uint16_t fragNo = 0; // Fragment index start from 0
963 uint32_t byteOffset = 0;
964
965 for (auto &entry : outPieces) {
966 // subtract 1 for index
967 uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
968 uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
969 uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
970
971 // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
972 entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
973 if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
974 LOGE("[Proto][FrameFrag] Resize failed for length=%" PRIu32, pieceTotalLen);
975 return -E_OUT_OF_MEMORY;
976 }
977
978 CommPhyHeader pktPhyHeader;
979 HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
980
981 // The sum value need to be recalculated, and the packet is fragmented.
982 // The alignedFragLen is always larger than pieceFragLen
983 FillPhyHeaderLenInfo(pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen, pktPhyHeader);
984 HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
985
986 CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
987 fragmentInfo.fragCount, fragNo};
988 HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
989 int err;
990 FragmentPacket packet;
991 uint8_t *ptrPacket = &(entry.first[0]);
992 if (fragmentInfo.extendHeadSize > 0) {
993 packet = {ptrPacket, fragmentInfo.extendHeadSize};
994 err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
995 if (err != E_OK) {
996 return err;
997 }
998 ptrPacket += fragmentInfo.extendHeadSize;
999 }
1000 packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
1001 err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
1002 pieceFragLen, packet);
1003 entry.second = fragmentInfo.extendHeadSize;
1004 if (err != E_OK) {
1005 LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%" PRIu16 ", fragNo=%" PRIu16, fragmentInfo.fragCount,
1006 fragNo);
1007 return err;
1008 }
1009
1010 fragNo++;
1011 byteOffset += pieceFragLen;
1012 }
1013
1014 return E_OK;
1015 }
1016
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)1017 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
1018 {
1019 if (headLen > outPacket.leftLength) {
1020 LOGE("[Proto][FrameFrag] headLen less than leftLength");
1021 return -E_INVALID_ARGS;
1022 }
1023 errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1024 if (retCode != EOK) {
1025 LOGE("memcpy error:%d", retCode);
1026 return -E_SECUREC_ERROR;
1027 }
1028 return E_OK;
1029 }
1030
1031 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1032 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1033 const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1034 {
1035 if (outPacket.leftLength == 0) {
1036 return -E_INVALID_ARGS;
1037 }
1038 uint8_t *ptrPacket = outPacket.ptrPacket;
1039 uint32_t leftLength = outPacket.leftLength;
1040
1041 // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1042 // So, there will be no redundant check during subtraction
1043 errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1044 if (retCode != EOK) {
1045 return -E_SECUREC_ERROR;
1046 }
1047 ptrPacket += sizeof(CommPhyHeader);
1048 leftLength -= sizeof(CommPhyHeader);
1049
1050 retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1051 if (retCode != EOK) {
1052 return -E_SECUREC_ERROR;
1053 }
1054 ptrPacket += sizeof(CommPhyOptHeader);
1055 leftLength -= sizeof(CommPhyOptHeader);
1056
1057 retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1058 if (retCode != EOK) {
1059 return -E_SECUREC_ERROR;
1060 }
1061
1062 // Calculate sum and set sum field
1063 uint64_t sumResult = 0;
1064 int errCode = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1065 outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1066 if (errCode != E_OK) {
1067 return -E_SUM_CALCULATE_FAIL;
1068 }
1069 auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1070 if (ptrPhyHeader == nullptr) {
1071 return -E_INVALID_ARGS;
1072 }
1073 ptrPhyHeader->checkSum = HostToNet(sumResult);
1074
1075 return E_OK;
1076 }
1077
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1078 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1079 {
1080 if (extendHandle != nullptr) {
1081 DBStatus status = extendHandle->GetHeadDataSize(headSize);
1082 if (status != DBStatus::OK) {
1083 LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1084 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1085 }
1086 if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1087 LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1088 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1089 }
1090 return E_OK;
1091 }
1092 return E_OK;
1093 }
1094
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1095 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1096 uint32_t headSize)
1097 {
1098 if (extendHandle != nullptr && headSize > 0) {
1099 if (buffer == nullptr) {
1100 return -E_INVALID_ARGS;
1101 }
1102 DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1103 buffer->GetSize() + headSize);
1104 if (status != DBStatus::OK) {
1105 LOGI("[Proto][ToSerial] fill head data failed");
1106 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1107 }
1108 }
1109 return E_OK;
1110 }
1111
GetTransformFunc(uint32_t messageId,DistributedDB::TransformFunc & function)1112 int ProtocolProto::GetTransformFunc(uint32_t messageId, DistributedDB::TransformFunc &function)
1113 {
1114 std::shared_lock<std::shared_mutex> autoLock(msgIdMutex_);
1115 const auto &entry = msgIdMapFunc_.find(messageId);
1116 if (entry == msgIdMapFunc_.end()) {
1117 return -E_NOT_REGISTER;
1118 }
1119 function = entry->second;
1120 return E_OK;
1121 }
1122 } // namespace DistributedDB
1123