1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "soft_bus_base_socket.h"
17
18 #include <cinttypes>
19
20 #include "remote_connect_listener_manager.h"
21
22 #define LOG_TAG "USER_AUTH_SA"
23 namespace OHOS {
24 namespace UserIam {
25 namespace UserAuth {
26 using namespace OHOS::DistributedHardware;
// Package name string for user IAM. It is not referenced in the visible part of this file;
// presumably consumed by SoftBus-related code elsewhere — TODO confirm against callers.
const std::string USERIAM_PACKAGE_NAME = "ohos.useriam";
// Delay passed to RelativeTimer when StartReplyTimer arms the reply timer for a request.
static constexpr uint32_t REPLY_TIMER_LEN_MS = 5 * 1000; // 5s
// Timer id value the reply-timer helpers treat as "no timer".
static constexpr uint32_t INVALID_TIMER_ID = 0;
// Guards g_messageSeq in GetMessageSeq.
static std::recursive_mutex g_seqMutex;
// Last sequence number handed out by GetMessageSeq; file-level, so shared by all BaseSocket objects.
static uint32_t g_messageSeq = 0;
32
BaseSocket(const int32_t socketId)33 BaseSocket::BaseSocket(const int32_t socketId)
34 : socketId_(socketId)
35 {
36 IAM_LOGI("create socket id %{public}d.", socketId_);
37 }
38
~BaseSocket()39 BaseSocket::~BaseSocket()
40 {
41 Shutdown(socketId_);
42 IAM_LOGI("close socket id %{public}d.", socketId_);
43 }
44
GetSocketId()45 int32_t BaseSocket::GetSocketId()
46 {
47 return socketId_;
48 }
49
InsertMsgCallback(uint32_t messageSeq,const std::string & connectionName,MsgCallback & callback,uint32_t timerId)50 void BaseSocket::InsertMsgCallback(uint32_t messageSeq, const std::string &connectionName,
51 MsgCallback &callback, uint32_t timerId)
52 {
53 IAM_LOGD("start. messageSeq:%{public}u, timerId:%{public}u", messageSeq, timerId);
54 IF_FALSE_LOGE_AND_RETURN(callback != nullptr);
55
56 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
57 CallbackInfo callbackInfo = {
58 .connectionName = connectionName,
59 .msgCallback = callback,
60 .timerId = timerId,
61 .sendTime = std::chrono::steady_clock::now()
62 };
63 callbackMap_.insert(std::pair<int32_t, CallbackInfo>(messageSeq, callbackInfo));
64 }
65
RemoveMsgCallback(uint32_t messageSeq)66 void BaseSocket::RemoveMsgCallback(uint32_t messageSeq)
67 {
68 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
69 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
70 callbackMap_.erase(messageSeq);
71 }
72
GetConnectionName(uint32_t messageSeq)73 std::string BaseSocket::GetConnectionName(uint32_t messageSeq)
74 {
75 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
76 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
77 std::string connectionName;
78 auto iter = callbackMap_.find(messageSeq);
79 if (iter != callbackMap_.end()) {
80 connectionName = iter->second.connectionName;
81 }
82 return connectionName;
83 }
84
GetMsgCallback(uint32_t messageSeq)85 MsgCallback BaseSocket::GetMsgCallback(uint32_t messageSeq)
86 {
87 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
88 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
89 MsgCallback callback = nullptr;
90 auto iter = callbackMap_.find(messageSeq);
91 if (iter != callbackMap_.end()) {
92 callback = iter->second.msgCallback;
93 }
94 return callback;
95 }
96
PrintTransferDuration(uint32_t messageSeq)97 void BaseSocket::PrintTransferDuration(uint32_t messageSeq)
98 {
99 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
100 auto iter = callbackMap_.find(messageSeq);
101 if (iter == callbackMap_.end()) {
102 IAM_LOGE("message seq not found");
103 return;
104 }
105
106 auto receiveAckTime = std::chrono::steady_clock::now();
107 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(receiveAckTime - iter->second.sendTime);
108 IAM_LOGI("messageSeq:%{public}u MessageTransferDuration:%{public}" PRIu64 " ms", messageSeq,
109 static_cast<uint64_t>(duration.count()));
110 }
111
GetReplyTimer(uint32_t messageSeq)112 uint32_t BaseSocket::GetReplyTimer(uint32_t messageSeq)
113 {
114 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
115 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
116 uint32_t timerId = 0;
117 auto iter = callbackMap_.find(messageSeq);
118 if (iter != callbackMap_.end()) {
119 timerId = iter->second.timerId;
120 }
121 return timerId;
122 }
123
StartReplyTimer(uint32_t messageSeq)124 uint32_t BaseSocket::StartReplyTimer(uint32_t messageSeq)
125 {
126 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
127 uint32_t timerId = GetReplyTimer(messageSeq);
128 if (timerId != INVALID_TIMER_ID) {
129 IAM_LOGI("timer is already start");
130 return timerId;
131 }
132
133 timerId = RelativeTimer::GetInstance().Register(
134 [weakSelf = weak_from_this(), messageSeq, socketId = socketId_] {
135 auto self = weakSelf.lock();
136 if (self == nullptr) {
137 IAM_LOGE("socket %{public}d is released", socketId);
138 return;
139 }
140 self->ReplyTimerTimeOut(messageSeq);
141 },
142 REPLY_TIMER_LEN_MS);
143
144 return timerId;
145 }
146
StopReplyTimer(uint32_t messageSeq)147 void BaseSocket::StopReplyTimer(uint32_t messageSeq)
148 {
149 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
150 uint32_t timerId = GetReplyTimer(messageSeq);
151 if (timerId == INVALID_TIMER_ID) {
152 IAM_LOGI("timer is already stop");
153 return;
154 }
155
156 RelativeTimer::GetInstance().Unregister(timerId);
157 }
158
ReplyTimerTimeOut(uint32_t messageSeq)159 void BaseSocket::ReplyTimerTimeOut(uint32_t messageSeq)
160 {
161 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
162 std::string connectionName = GetConnectionName(messageSeq);
163 if (connectionName.empty()) {
164 IAM_LOGE("GetMsgCallback connectionName fail");
165 return;
166 }
167
168 RemoteConnectListenerManager::GetInstance().OnConnectionDown(connectionName);
169 RemoveMsgCallback(messageSeq);
170 IAM_LOGI("reply timer is timeout, messageSeq:%{public}u", messageSeq);
171 }
172
GetMessageSeq()173 int32_t BaseSocket::GetMessageSeq()
174 {
175 IAM_LOGD("start.");
176 std::lock_guard<std::recursive_mutex> lock(g_seqMutex);
177 g_messageSeq++;
178 return g_messageSeq;
179 }
180
SetDeviceNetworkId(const std::string networkId,std::shared_ptr<Attributes> & attributes)181 ResultCode BaseSocket::SetDeviceNetworkId(const std::string networkId, std::shared_ptr<Attributes> &attributes)
182 {
183 IAM_LOGD("start.");
184 IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
185
186 bool setDeviceNetworkIdRet = attributes->SetStringValue(Attributes::ATTR_COLLECTOR_NETWORK_ID, networkId);
187 if (setDeviceNetworkIdRet == false) {
188 IAM_LOGE("SetStringValue fail");
189 return GENERAL_ERROR;
190 }
191
192 return SUCCESS;
193 }
194
SendRequest(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,MsgCallback & callback)195 ResultCode BaseSocket::SendRequest(const int32_t socketId, const std::string &connectionName,
196 const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
197 MsgCallback &callback)
198 {
199 IAM_LOGD("start.");
200 IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
201 IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
202
203 int32_t messageSeq = GetMessageSeq();
204 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
205 connectionName, srcEndPoint, destEndPoint, attributes);
206 if (softBusMessage == nullptr) {
207 IAM_LOGE("softBusMessage is nullptr");
208 return GENERAL_ERROR;
209 }
210
211 std::shared_ptr<Attributes> request = softBusMessage->CreateMessage(false);
212 if (request == nullptr) {
213 IAM_LOGE("creatMessage fail");
214 return GENERAL_ERROR;
215 }
216
217 std::vector<uint8_t> data = request->Serialize();
218 int ret = SendBytes(socketId, data.data(), data.size());
219 if (ret != SUCCESS) {
220 IAM_LOGE("fail to send message, result= %{public}d", ret);
221 return GENERAL_ERROR;
222 }
223
224 uint32_t timerId = StartReplyTimer(messageSeq);
225 if (timerId == INVALID_TIMER_ID) {
226 IAM_LOGE("create reply timer fail");
227 return GENERAL_ERROR;
228 }
229
230 InsertMsgCallback(messageSeq, connectionName, callback, timerId);
231 IAM_LOGI("SendRequest success.");
232 return SUCCESS;
233 }
234
SendResponse(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,uint32_t messageSeq)235 ResultCode BaseSocket::SendResponse(const int32_t socketId, const std::string &connectionName,
236 const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
237 uint32_t messageSeq)
238 {
239 IAM_LOGD("start.");
240 IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
241 IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
242
243 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
244 connectionName, srcEndPoint, destEndPoint, attributes);
245 if (softBusMessage == nullptr) {
246 IAM_LOGE("softBusMessage is nullptr");
247 return GENERAL_ERROR;
248 }
249
250 std::shared_ptr<Attributes> response = softBusMessage->CreateMessage(true);
251 if (response == nullptr) {
252 IAM_LOGE("creatMessage fail");
253 return GENERAL_ERROR;
254 }
255
256 std::vector<uint8_t> data = response->Serialize();
257 int ret = SendBytes(socketId, data.data(), data.size());
258 if (ret != SUCCESS) {
259 IAM_LOGE("fail to send message, result= %{public}d", ret);
260 return GENERAL_ERROR;
261 }
262
263 IAM_LOGI("SendResponse success.");
264 return SUCCESS;
265 }
266
ParseMessage(const std::string & networkId,void * message,uint32_t messageLen)267 std::shared_ptr<SoftBusMessage> BaseSocket::ParseMessage(const std::string &networkId,
268 void *message, uint32_t messageLen)
269 {
270 IAM_LOGD("start.");
271 IF_FALSE_LOGE_AND_RETURN_VAL(message != nullptr, nullptr);
272 IF_FALSE_LOGE_AND_RETURN_VAL(messageLen != 0, nullptr);
273
274 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(0, "", "", "", nullptr);
275 if (softBusMessage == nullptr) {
276 IAM_LOGE("softBusMessage is nullptr");
277 return nullptr;
278 }
279
280 std::shared_ptr<Attributes> attributes = softBusMessage->ParseMessage(message, messageLen);
281 if (attributes == nullptr) {
282 IAM_LOGE("parseMessage fail");
283 return nullptr;
284 }
285
286 int32_t ret = SetDeviceNetworkId(networkId, attributes);
287 if (ret != SUCCESS) {
288 IAM_LOGE("SetDeviceNetworkId fail");
289 return nullptr;
290 }
291
292 IAM_LOGD("ParseMessage success.");
293 return softBusMessage;
294 }
295
ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage,std::shared_ptr<Attributes> response)296 void BaseSocket::ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage, std::shared_ptr<Attributes> response)
297 {
298 IF_FALSE_LOGE_AND_RETURN(softBusMessage != nullptr);
299 IF_FALSE_LOGE_AND_RETURN(response != nullptr);
300
301 bool setResultCode = response->SetInt32Value(Attributes::ATTR_RESULT_CODE, GENERAL_ERROR);
302 IF_FALSE_LOGE_AND_RETURN(setResultCode);
303
304 uint32_t messageVersion = softBusMessage->GetMessageVersion();
305 if (messageVersion != DEFAULT_MESSAGE_VERSION) {
306 IAM_LOGE("support message version %{public}u, receive message version %{public}u", DEFAULT_MESSAGE_VERSION,
307 messageVersion);
308 std::vector<uint32_t> supportedVersions = { DEFAULT_MESSAGE_VERSION };
309 bool setSupportedVersionsRet = response->SetUint32ArrayValue(Attributes::ATTR_SUPPORTED_MSG_VERSION,
310 supportedVersions);
311 IF_FALSE_LOGE_AND_RETURN(setSupportedVersionsRet);
312 return;
313 }
314
315 std::string connectionName = softBusMessage->GetConnectionName();
316 std::string destEndPoint = softBusMessage->GetDestEndPoint();
317
318 std::shared_ptr<ConnectionListener> connectionListener =
319 RemoteConnectListenerManager::GetInstance().FindListener(connectionName, destEndPoint);
320 if (connectionListener == nullptr) {
321 IAM_LOGE("connectionListener is nullptr");
322 return;
323 }
324
325 auto beginTime = std::chrono::steady_clock::now();
326 connectionListener->OnMessage(connectionName, destEndPoint, softBusMessage->GetAttributes(), response);
327 auto endTime = std::chrono::steady_clock::now();
328 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - beginTime);
329 IAM_LOGI("messageSeq:%{public}u ProcessMessageDuration:%{public}" PRIu64 " ms", softBusMessage->GetMessageSeq(),
330 static_cast<uint64_t>(duration.count()));
331 }
332
ProcDataReceive(const int32_t socketId,std::shared_ptr<SoftBusMessage> & softBusMessage)333 ResultCode BaseSocket::ProcDataReceive(const int32_t socketId, std::shared_ptr<SoftBusMessage> &softBusMessage)
334 {
335 IAM_LOGD("start.");
336 IF_FALSE_LOGE_AND_RETURN_VAL(softBusMessage != nullptr, INVALID_PARAMETERS);
337 IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
338
339 std::shared_ptr<Attributes> request = softBusMessage->GetAttributes();
340 if (request == nullptr) {
341 IAM_LOGE("GetAttributes fail");
342 return GENERAL_ERROR;
343 }
344
345 uint32_t messageSeq = softBusMessage->GetMessageSeq();
346 bool ack = softBusMessage->GetAckFlag();
347 if (ack == true) {
348 PrintTransferDuration(messageSeq);
349 MsgCallback callback = GetMsgCallback(messageSeq);
350 if (callback == nullptr) {
351 IAM_LOGE("GetMsgCallback fail");
352 return GENERAL_ERROR;
353 }
354
355 callback(request);
356 StopReplyTimer(messageSeq);
357 RemoveMsgCallback(messageSeq);
358 } else {
359 std::string connectionName = softBusMessage->GetConnectionName();
360 std::string srcEndPoint = softBusMessage->GetSrcEndPoint();
361 std::string destEndPoint = softBusMessage->GetDestEndPoint();
362
363 std::shared_ptr<Attributes> response = Common::MakeShared<Attributes>();
364 if (response == nullptr) {
365 IAM_LOGE("create fail");
366 return GENERAL_ERROR;
367 }
368
369 ProcessMessage(softBusMessage, response);
370
371 SendResponse(socketId, connectionName, destEndPoint, srcEndPoint, response, messageSeq);
372 }
373
374 IAM_LOGI("ProcDataReceive success.");
375 return SUCCESS;
376 }
377 } // namespace UserAuth
378 } // namespace UserIam
379 } // namespace OHOS