1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tmessenger.h"

#include <algorithm>
#include <charconv>
#include <chrono>
#include <cinttypes>
#include <system_error>
#include <thread>

#include "common.h"
24 
25 namespace OHOS {
// Delay, in seconds, that OnRequest() sleeps before running the query callback and sending the response.
static constexpr uint32_t WAIT_RESP_TIME = 1;
27 
Encode() const28 std::string Request::Encode() const
29 {
30     return std::to_string(static_cast<int32_t>(cmd_));
31 }
32 
Decode(const std::string & data)33 std::shared_ptr<Request> Request::Decode(const std::string &data)
34 {
35     if (data.empty()) {
36         LOGE("the data is empty");
37         return nullptr;
38     }
39 
40     Cmd cmd = static_cast<Cmd>(std::stoi(data));
41     if (cmd < Cmd::QUERY_RESULT || cmd > Cmd::QUERY_RESULT) {
42         LOGE("invalid cmd=%d", static_cast<int32_t>(cmd));
43         return nullptr;
44     }
45     return std::make_shared<Request>(cmd);
46 }
47 
Encode() const48 std::string Response::Encode() const
49 {
50     std::string data = std::to_string(isEncrypt_ ? 1 : 0);
51     return data + SEPARATOR + recvData_;
52 }
53 
Decode(const std::string & data)54 std::shared_ptr<Response> Response::Decode(const std::string &data)
55 {
56     if (data.empty()) {
57         LOGE("the data is empty");
58         return nullptr;
59     }
60 
61     size_t pos = data.find(SEPARATOR);
62     if (pos == std::string::npos) {
63         LOGE("can not find separator in the string data");
64         return nullptr;
65     }
66 
67     int32_t isEncryptVal = static_cast<int32_t>(std::stoi(data.substr(0, pos)));
68     bool isEncrypt = (isEncryptVal == 1);
69     std::string recvData = data.substr(pos + 1);
70 
71     return std::make_shared<Response>(isEncrypt, recvData);
72 }
73 
~Message()74 Message::~Message()
75 {
76     if (msgType_ == MsgType::MSG_SEQ && request != nullptr) {
77         delete request;
78     }
79     if (msgType_ == MsgType::MSG_RSP && response != nullptr) {
80         delete response;
81     }
82 }
83 
Encode() const84 std::string Message::Encode() const
85 {
86     std::string data = std::to_string(static_cast<int32_t>(msgType_));
87     switch (msgType_) {
88         case MsgType::MSG_SEQ:
89             return request == nullptr ? "" : data + SEPARATOR + request->Encode();
90         case MsgType::MSG_RSP:
91             return response == nullptr ? "" : data + SEPARATOR + response->Encode();
92         default:
93             LOGE("invalid msgType=%d", static_cast<int32_t>(msgType_));
94             return "";
95     }
96 }
97 
Decode(const std::string & data)98 std::shared_ptr<Message> Message::Decode(const std::string &data)
99 {
100     size_t pos = data.find(SEPARATOR);
101     if (pos == std::string::npos) {
102         return nullptr;
103     }
104 
105     MsgType msgType = static_cast<MsgType>(std::stoi(data.substr(0, pos)));
106     switch (msgType) {
107         case MsgType::MSG_SEQ: {
108             std::shared_ptr<Request> req = Request::Decode(data.substr(pos + 1));
109             if (req == nullptr) {
110                 return nullptr;
111             }
112             return std::make_shared<Message>(*req);
113         }
114         case MsgType::MSG_RSP: {
115             std::shared_ptr<Response> rsp = Response::Decode(data.substr(pos + 1));
116             if (rsp == nullptr) {
117                 return nullptr;
118             }
119             return std::make_shared<Message>(*rsp);
120         }
121         default:
122             LOGE("invalid msgType=%d", static_cast<int32_t>(msgType));
123             return nullptr;
124     }
125 }
126 
Open(const std::string & pkgName,const std::string & myName,const std::string & peerName,bool isServer)127 int32_t TMessenger::Open(
128     const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer)
129 {
130     isServer_ = isServer;
131     return isServer_ ? StartListen(pkgName, myName) : StartConnect(pkgName, myName, peerName);
132 }
133 
Close()134 void TMessenger::Close()
135 {
136     if (socket_ > 0) {
137         Shutdown(socket_);
138         socket_ = -1;
139     }
140 
141     if (listenSocket_ > 0) {
142         Shutdown(listenSocket_);
143         listenSocket_ = -1;
144     }
145 
146     pkgName_.clear();
147     myName_.clear();
148     peerName_.clear();
149     peerNetworkId_.clear();
150     msgList_.clear();
151 }
152 
StartListen(const std::string & pkgName,const std::string & myName)153 int32_t TMessenger::StartListen(const std::string &pkgName, const std::string &myName)
154 {
155     if (listenSocket_ > 0) {
156         return SOFTBUS_OK;
157     }
158 
159     SocketInfo info = {
160         .pkgName = (char *)(pkgName.c_str()),
161         .name = (char *)(myName.c_str()),
162     };
163     int32_t socket = Socket(info);
164     if (socket <= 0) {
165         LOGE("failed to create socket, ret=%d", socket);
166         return socket;
167     }
168     LOGI("create listen socket=%d", socket);
169 
170     QosTV qosInfo[] = {
171         { .qos = QOS_TYPE_MIN_BW,      .value = 80   },
172         { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000 },
173         { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 },
174     };
175     static ISocketListener listener = {
176         .OnBind = TMessenger::OnBind,
177         .OnMessage = TMessenger::OnMessage,
178         .OnShutdown = TMessenger::OnShutdown,
179     };
180 
181     int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
182     if (ret != SOFTBUS_OK) {
183         LOGE("failed to listen, socket=%d", socket);
184         Shutdown(socket);
185         return ret;
186     }
187     listenSocket_ = socket;
188     pkgName_ = pkgName;
189     myName_ = myName;
190     return SOFTBUS_OK;
191 }
192 
StartConnect(const std::string & pkgName,const std::string & myName,const std::string & peerName)193 int32_t TMessenger::StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName)
194 {
195     if (socket_ > 0) {
196         return SOFTBUS_OK;
197     }
198 
199     SocketInfo info = {
200         .pkgName = const_cast<char *>(pkgName.c_str()),
201         .name = const_cast<char *>(myName.c_str()),
202         .peerName = const_cast<char *>(peerName.c_str()),
203         .peerNetworkId = nullptr,
204         .dataType = DATA_TYPE_MESSAGE,
205     };
206     info.peerNetworkId = OHOS::WaitOnLineAndGetNetWorkId();
207 
208     int32_t socket = Socket(info);
209     if (socket <= 0) {
210         LOGE("failed to create socket, ret=%d", socket);
211         return socket;
212     }
213     LOGI("create bind socket=%d", socket);
214 
215     QosTV qosInfo[] = {
216         { .qos = QOS_TYPE_MIN_BW,      .value = 80   },
217         { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000 },
218         { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 },
219     };
220 
221     static ISocketListener listener = {
222         .OnMessage = OnMessage,
223         .OnShutdown = OnShutdown,
224     };
225 
226     int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
227     if (ret != SOFTBUS_OK) {
228         LOGE("failed to bind, socket=%d, ret=%d", socket, ret);
229         Shutdown(socket);
230         return ret;
231     }
232 
233     pkgName_ = pkgName;
234     myName_ = myName;
235     peerNetworkId_ = info.peerNetworkId;
236     peerName_ = peerName;
237     socket_ = socket;
238     return SOFTBUS_OK;
239 }
240 
OnBind(int32_t socket,PeerSocketInfo info)241 void TMessenger::OnBind(int32_t socket, PeerSocketInfo info)
242 {
243     TMessenger::GetInstance().SetConnectSocket(socket, info);
244 }
245 
OnMessage(int32_t socket,const void * data,uint32_t dataLen)246 void TMessenger::OnMessage(int32_t socket, const void *data, uint32_t dataLen)
247 {
248     std::string result(static_cast<const char *>(data), dataLen);
249     TMessenger::GetInstance().OnMessageRecv(result);
250 }
251 
OnShutdown(int32_t socket,ShutdownReason reason)252 void TMessenger::OnShutdown(int32_t socket, ShutdownReason reason)
253 {
254     TMessenger::GetInstance().CloseSocket(socket);
255 }
256 
SetConnectSocket(int32_t socket,PeerSocketInfo info)257 void TMessenger::SetConnectSocket(int32_t socket, PeerSocketInfo info)
258 {
259     if (socket_ > 0) {
260         return;
261     }
262 
263     socket_ = socket;
264     peerName_ = info.name;
265     peerNetworkId_ = info.networkId;
266 }
267 
OnMessageRecv(const std::string & result)268 void TMessenger::OnMessageRecv(const std::string &result)
269 {
270     std::shared_ptr<Message> msg = Message::Decode(result);
271     if (msg == nullptr) {
272         LOGE("receive invalid message");
273         return;
274     }
275 
276     switch (msg->msgType_) {
277         case Message::MsgType::MSG_SEQ: {
278             OnRequest();
279             break;
280         }
281         case Message::MsgType::MSG_RSP: {
282             std::unique_lock<std::mutex> lock(recvMutex_);
283             msgList_.push_back(msg);
284             lock.unlock();
285             recvCond_.notify_one();
286             break;
287         }
288         default:
289             break;
290     }
291 }
292 
OnRequest()293 void TMessenger::OnRequest()
294 {
295     std::thread t([this] {
296         std::this_thread::sleep_for(std::chrono::seconds(WAIT_RESP_TIME));
297         std::shared_ptr<Response> resp = onQuery_();
298         Message msg { *resp };
299         int ret = Send(msg);
300         if (ret != SOFTBUS_OK) {
301             LOGE("failed to send response");
302         }
303     });
304     t.detach();
305 }
306 
CloseSocket(int32_t socket)307 void TMessenger::CloseSocket(int32_t socket)
308 {
309     if (socket_ == socket) {
310         Shutdown(socket_);
311         socket_ = -1;
312     }
313 }
314 
QueryResult(uint32_t timeout)315 std::shared_ptr<Response> TMessenger::QueryResult(uint32_t timeout)
316 {
317     Request req { Request::Cmd::QUERY_RESULT };
318     Message msg { req };
319     int32_t ret = Send(msg);
320     if (ret != SOFTBUS_OK) {
321         LOGE("failed to query result, ret=%d", ret);
322         return nullptr;
323     }
324 
325     return WaitResponse(timeout);
326 }
327 
Send(const Message & msg)328 int32_t TMessenger::Send(const Message &msg)
329 {
330     std::string data = msg.Encode();
331     if (data.empty()) {
332         LOGE("the data is empty");
333         return SOFTBUS_MEM_ERR;
334     }
335 
336     int32_t ret = SendMessage(socket_, data.c_str(), data.size());
337     if (ret != SOFTBUS_OK) {
338         LOGE("failed to send message, socket=%d, ret=%d", socket_, ret);
339     }
340     return ret;
341 }
342 
WaitResponse(uint32_t timeout)343 std::shared_ptr<Response> TMessenger::WaitResponse(uint32_t timeout)
344 {
345     std::unique_lock<std::mutex> lock(recvMutex_);
346     std::shared_ptr<Response> rsp = nullptr;
347     if (recvCond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
348             rsp = GetMessageFromRecvList(Message::MsgType::MSG_RSP);
349             return rsp != nullptr;
350         })) {
351         return rsp;
352     }
353     LOGE("no result received");
354     return nullptr;
355 }
356 
GetMessageFromRecvList(Message::MsgType type)357 std::shared_ptr<Response> TMessenger::GetMessageFromRecvList(Message::MsgType type)
358 {
359     auto it = std::find_if(msgList_.begin(), msgList_.end(), [type] (const std::shared_ptr<Message> &it) {
360         return it->msgType_ == type;
361     });
362 
363     if (it == msgList_.end() || *it == nullptr) {
364         return nullptr;
365     }
366 
367     const Response *rsp = (*it)->response;
368     if (rsp == nullptr) {
369         msgList_.erase(it);
370         return nullptr;
371     }
372 
373     std::shared_ptr<Response> resp = std::make_shared<Response>(*rsp);
374     msgList_.erase(it);
375     return resp;
376 }
377 
RegisterOnQuery(TMessenger::OnQueryCallback callback)378 void TMessenger::RegisterOnQuery(TMessenger::OnQueryCallback callback)
379 {
380     onQuery_ = callback;
381 }
382 } // namespace OHOS
383