1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "soft_bus_base_socket.h"
17 
18 #include <cinttypes>
19 
20 #include "remote_connect_listener_manager.h"
21 
22 #define LOG_TAG "USER_AUTH_SA"
23 namespace OHOS {
24 namespace UserIam {
25 namespace UserAuth {
26 using namespace OHOS::DistributedHardware;
// Presumably the UserIAM package name; not referenced anywhere else in this file.
const std::string USERIAM_PACKAGE_NAME = "ohos.useriam";
// Delay handed to RelativeTimer in StartReplyTimer before ReplyTimerTimeOut runs.
static constexpr uint32_t REPLY_TIMER_LEN_MS = 5 * 1000; // 5s
// Timer id meaning "no reply timer recorded"; GetReplyTimer yields 0 when a seq is unknown.
static constexpr uint32_t INVALID_TIMER_ID = 0;
// Serializes access to g_messageSeq across all BaseSocket instances (see GetMessageSeq).
static std::recursive_mutex g_seqMutex;
// File-wide request sequence counter, incremented by GetMessageSeq.
static uint32_t g_messageSeq = 0;
32 
BaseSocket(const int32_t socketId)33 BaseSocket::BaseSocket(const int32_t socketId)
34     : socketId_(socketId)
35 {
36     IAM_LOGI("create socket id %{public}d.", socketId_);
37 }
38 
~BaseSocket()39 BaseSocket::~BaseSocket()
40 {
41     Shutdown(socketId_);
42     IAM_LOGI("close socket id %{public}d.", socketId_);
43 }
44 
GetSocketId()45 int32_t BaseSocket::GetSocketId()
46 {
47     return socketId_;
48 }
49 
InsertMsgCallback(uint32_t messageSeq,const std::string & connectionName,MsgCallback & callback,uint32_t timerId)50 void BaseSocket::InsertMsgCallback(uint32_t messageSeq, const std::string &connectionName,
51     MsgCallback &callback, uint32_t timerId)
52 {
53     IAM_LOGD("start. messageSeq:%{public}u, timerId:%{public}u", messageSeq, timerId);
54     IF_FALSE_LOGE_AND_RETURN(callback != nullptr);
55 
56     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
57     CallbackInfo callbackInfo = {
58         .connectionName = connectionName,
59         .msgCallback = callback,
60         .timerId = timerId,
61         .sendTime = std::chrono::steady_clock::now()
62     };
63     callbackMap_.insert(std::pair<int32_t, CallbackInfo>(messageSeq, callbackInfo));
64 }
65 
RemoveMsgCallback(uint32_t messageSeq)66 void BaseSocket::RemoveMsgCallback(uint32_t messageSeq)
67 {
68     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
69     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
70     callbackMap_.erase(messageSeq);
71 }
72 
GetConnectionName(uint32_t messageSeq)73 std::string BaseSocket::GetConnectionName(uint32_t messageSeq)
74 {
75     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
76     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
77     std::string connectionName;
78     auto iter = callbackMap_.find(messageSeq);
79     if (iter != callbackMap_.end()) {
80         connectionName = iter->second.connectionName;
81     }
82     return connectionName;
83 }
84 
GetMsgCallback(uint32_t messageSeq)85 MsgCallback BaseSocket::GetMsgCallback(uint32_t messageSeq)
86 {
87     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
88     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
89     MsgCallback callback = nullptr;
90     auto iter = callbackMap_.find(messageSeq);
91     if (iter != callbackMap_.end()) {
92         callback = iter->second.msgCallback;
93     }
94     return callback;
95 }
96 
PrintTransferDuration(uint32_t messageSeq)97 void BaseSocket::PrintTransferDuration(uint32_t messageSeq)
98 {
99     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
100     auto iter = callbackMap_.find(messageSeq);
101     if (iter == callbackMap_.end()) {
102         IAM_LOGE("message seq not found");
103         return;
104     }
105 
106     auto receiveAckTime = std::chrono::steady_clock::now();
107     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(receiveAckTime - iter->second.sendTime);
108     IAM_LOGI("messageSeq:%{public}u MessageTransferDuration:%{public}" PRIu64 " ms", messageSeq,
109         static_cast<uint64_t>(duration.count()));
110 }
111 
GetReplyTimer(uint32_t messageSeq)112 uint32_t BaseSocket::GetReplyTimer(uint32_t messageSeq)
113 {
114     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
115     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
116     uint32_t timerId = 0;
117     auto iter = callbackMap_.find(messageSeq);
118     if (iter != callbackMap_.end()) {
119         timerId = iter->second.timerId;
120     }
121     return timerId;
122 }
123 
StartReplyTimer(uint32_t messageSeq)124 uint32_t BaseSocket::StartReplyTimer(uint32_t messageSeq)
125 {
126     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
127     uint32_t timerId = GetReplyTimer(messageSeq);
128     if (timerId != INVALID_TIMER_ID) {
129         IAM_LOGI("timer is already start");
130         return timerId;
131     }
132 
133     timerId = RelativeTimer::GetInstance().Register(
134         [weakSelf = weak_from_this(), messageSeq, socketId = socketId_] {
135             auto self = weakSelf.lock();
136             if (self == nullptr) {
137                 IAM_LOGE("socket %{public}d is released", socketId);
138                 return;
139             }
140             self->ReplyTimerTimeOut(messageSeq);
141         },
142         REPLY_TIMER_LEN_MS);
143 
144     return timerId;
145 }
146 
StopReplyTimer(uint32_t messageSeq)147 void BaseSocket::StopReplyTimer(uint32_t messageSeq)
148 {
149     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
150     uint32_t timerId = GetReplyTimer(messageSeq);
151     if (timerId == INVALID_TIMER_ID) {
152         IAM_LOGI("timer is already stop");
153         return;
154     }
155 
156     RelativeTimer::GetInstance().Unregister(timerId);
157 }
158 
ReplyTimerTimeOut(uint32_t messageSeq)159 void BaseSocket::ReplyTimerTimeOut(uint32_t messageSeq)
160 {
161     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
162     std::string connectionName = GetConnectionName(messageSeq);
163     if (connectionName.empty()) {
164         IAM_LOGE("GetMsgCallback connectionName fail");
165         return;
166     }
167 
168     RemoteConnectListenerManager::GetInstance().OnConnectionDown(connectionName);
169     RemoveMsgCallback(messageSeq);
170     IAM_LOGI("reply timer is timeout, messageSeq:%{public}u", messageSeq);
171 }
172 
GetMessageSeq()173 int32_t BaseSocket::GetMessageSeq()
174 {
175     IAM_LOGD("start.");
176     std::lock_guard<std::recursive_mutex> lock(g_seqMutex);
177     g_messageSeq++;
178     return g_messageSeq;
179 }
180 
SetDeviceNetworkId(const std::string networkId,std::shared_ptr<Attributes> & attributes)181 ResultCode BaseSocket::SetDeviceNetworkId(const std::string networkId, std::shared_ptr<Attributes> &attributes)
182 {
183     IAM_LOGD("start.");
184     IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
185 
186     bool setDeviceNetworkIdRet = attributes->SetStringValue(Attributes::ATTR_COLLECTOR_NETWORK_ID, networkId);
187     if (setDeviceNetworkIdRet == false) {
188         IAM_LOGE("SetStringValue fail");
189         return GENERAL_ERROR;
190     }
191 
192     return SUCCESS;
193 }
194 
SendRequest(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,MsgCallback & callback)195 ResultCode BaseSocket::SendRequest(const int32_t socketId, const std::string &connectionName,
196     const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
197     MsgCallback &callback)
198 {
199     IAM_LOGD("start.");
200     IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
201     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
202 
203     int32_t messageSeq = GetMessageSeq();
204     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
205         connectionName, srcEndPoint, destEndPoint, attributes);
206     if (softBusMessage == nullptr) {
207         IAM_LOGE("softBusMessage is nullptr");
208         return GENERAL_ERROR;
209     }
210 
211     std::shared_ptr<Attributes> request = softBusMessage->CreateMessage(false);
212     if (request == nullptr) {
213         IAM_LOGE("creatMessage fail");
214         return GENERAL_ERROR;
215     }
216 
217     std::vector<uint8_t> data = request->Serialize();
218     int ret = SendBytes(socketId, data.data(), data.size());
219     if (ret != SUCCESS) {
220         IAM_LOGE("fail to send message, result= %{public}d", ret);
221         return GENERAL_ERROR;
222     }
223 
224     uint32_t timerId = StartReplyTimer(messageSeq);
225     if (timerId == INVALID_TIMER_ID) {
226         IAM_LOGE("create reply timer fail");
227         return GENERAL_ERROR;
228     }
229 
230     InsertMsgCallback(messageSeq, connectionName, callback, timerId);
231     IAM_LOGI("SendRequest success.");
232     return SUCCESS;
233 }
234 
SendResponse(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,uint32_t messageSeq)235 ResultCode BaseSocket::SendResponse(const int32_t socketId, const std::string &connectionName,
236     const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
237     uint32_t messageSeq)
238 {
239     IAM_LOGD("start.");
240     IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
241     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
242 
243     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
244         connectionName, srcEndPoint, destEndPoint, attributes);
245     if (softBusMessage == nullptr) {
246         IAM_LOGE("softBusMessage is nullptr");
247         return GENERAL_ERROR;
248     }
249 
250     std::shared_ptr<Attributes> response = softBusMessage->CreateMessage(true);
251     if (response == nullptr) {
252         IAM_LOGE("creatMessage fail");
253         return GENERAL_ERROR;
254     }
255 
256     std::vector<uint8_t> data = response->Serialize();
257     int ret = SendBytes(socketId, data.data(), data.size());
258     if (ret != SUCCESS) {
259         IAM_LOGE("fail to send message, result= %{public}d", ret);
260         return GENERAL_ERROR;
261     }
262 
263     IAM_LOGI("SendResponse success.");
264     return SUCCESS;
265 }
266 
ParseMessage(const std::string & networkId,void * message,uint32_t messageLen)267 std::shared_ptr<SoftBusMessage> BaseSocket::ParseMessage(const std::string &networkId,
268     void *message, uint32_t messageLen)
269 {
270     IAM_LOGD("start.");
271     IF_FALSE_LOGE_AND_RETURN_VAL(message != nullptr, nullptr);
272     IF_FALSE_LOGE_AND_RETURN_VAL(messageLen != 0, nullptr);
273 
274     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(0, "", "", "", nullptr);
275     if (softBusMessage == nullptr) {
276         IAM_LOGE("softBusMessage is nullptr");
277         return nullptr;
278     }
279 
280     std::shared_ptr<Attributes> attributes = softBusMessage->ParseMessage(message, messageLen);
281     if (attributes == nullptr) {
282         IAM_LOGE("parseMessage fail");
283         return nullptr;
284     }
285 
286     int32_t ret = SetDeviceNetworkId(networkId, attributes);
287     if (ret != SUCCESS) {
288         IAM_LOGE("SetDeviceNetworkId fail");
289         return nullptr;
290     }
291 
292     IAM_LOGD("ParseMessage success.");
293     return softBusMessage;
294 }
295 
ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage,std::shared_ptr<Attributes> response)296 void BaseSocket::ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage, std::shared_ptr<Attributes> response)
297 {
298     IF_FALSE_LOGE_AND_RETURN(softBusMessage != nullptr);
299     IF_FALSE_LOGE_AND_RETURN(response != nullptr);
300 
301     bool setResultCode = response->SetInt32Value(Attributes::ATTR_RESULT_CODE, GENERAL_ERROR);
302     IF_FALSE_LOGE_AND_RETURN(setResultCode);
303 
304     uint32_t messageVersion = softBusMessage->GetMessageVersion();
305     if (messageVersion != DEFAULT_MESSAGE_VERSION) {
306         IAM_LOGE("support message version %{public}u, receive message version %{public}u", DEFAULT_MESSAGE_VERSION,
307             messageVersion);
308         std::vector<uint32_t> supportedVersions = { DEFAULT_MESSAGE_VERSION };
309         bool setSupportedVersionsRet = response->SetUint32ArrayValue(Attributes::ATTR_SUPPORTED_MSG_VERSION,
310             supportedVersions);
311         IF_FALSE_LOGE_AND_RETURN(setSupportedVersionsRet);
312         return;
313     }
314 
315     std::string connectionName = softBusMessage->GetConnectionName();
316     std::string destEndPoint = softBusMessage->GetDestEndPoint();
317 
318     std::shared_ptr<ConnectionListener> connectionListener =
319         RemoteConnectListenerManager::GetInstance().FindListener(connectionName, destEndPoint);
320     if (connectionListener == nullptr) {
321         IAM_LOGE("connectionListener is nullptr");
322         return;
323     }
324 
325     auto beginTime = std::chrono::steady_clock::now();
326     connectionListener->OnMessage(connectionName, destEndPoint, softBusMessage->GetAttributes(), response);
327     auto endTime = std::chrono::steady_clock::now();
328     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - beginTime);
329     IAM_LOGI("messageSeq:%{public}u ProcessMessageDuration:%{public}" PRIu64 " ms", softBusMessage->GetMessageSeq(),
330         static_cast<uint64_t>(duration.count()));
331 }
332 
ProcDataReceive(const int32_t socketId,std::shared_ptr<SoftBusMessage> & softBusMessage)333 ResultCode BaseSocket::ProcDataReceive(const int32_t socketId, std::shared_ptr<SoftBusMessage> &softBusMessage)
334 {
335     IAM_LOGD("start.");
336     IF_FALSE_LOGE_AND_RETURN_VAL(softBusMessage != nullptr, INVALID_PARAMETERS);
337     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
338 
339     std::shared_ptr<Attributes> request = softBusMessage->GetAttributes();
340     if (request == nullptr) {
341         IAM_LOGE("GetAttributes fail");
342         return GENERAL_ERROR;
343     }
344 
345     uint32_t messageSeq = softBusMessage->GetMessageSeq();
346     bool ack = softBusMessage->GetAckFlag();
347     if (ack == true) {
348         PrintTransferDuration(messageSeq);
349         MsgCallback callback = GetMsgCallback(messageSeq);
350         if (callback == nullptr) {
351             IAM_LOGE("GetMsgCallback fail");
352             return GENERAL_ERROR;
353         }
354 
355         callback(request);
356         StopReplyTimer(messageSeq);
357         RemoveMsgCallback(messageSeq);
358     } else {
359         std::string connectionName = softBusMessage->GetConnectionName();
360         std::string srcEndPoint = softBusMessage->GetSrcEndPoint();
361         std::string destEndPoint = softBusMessage->GetDestEndPoint();
362 
363         std::shared_ptr<Attributes> response = Common::MakeShared<Attributes>();
364         if (response == nullptr) {
365             IAM_LOGE("create fail");
366             return GENERAL_ERROR;
367         }
368 
369         ProcessMessage(softBusMessage, response);
370 
371         SendResponse(socketId, connectionName, destEndPoint, srcEndPoint, response, messageSeq);
372     }
373 
374     IAM_LOGI("ProcDataReceive success.");
375     return SUCCESS;
376 }
377 } // namespace UserAuth
378 } // namespace UserIam
379 } // namespace OHOS