1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "tmessenger.h"
17
18 #include <algorithm>
19 #include <chrono>
20 #include <cinttypes>
21 #include <thread>
22
23 #include "common.h"
24
25 namespace OHOS {
// Delay, in seconds, that OnRequest() sleeps before it answers a query.
static constexpr uint32_t WAIT_RESP_TIME { 1 };
27
Encode() const28 std::string Request::Encode() const
29 {
30 return std::to_string(static_cast<int32_t>(cmd_));
31 }
32
Decode(const std::string & data)33 std::shared_ptr<Request> Request::Decode(const std::string &data)
34 {
35 if (data.empty()) {
36 LOGE("the data is empty");
37 return nullptr;
38 }
39
40 Cmd cmd = static_cast<Cmd>(std::stoi(data));
41 if (cmd < Cmd::QUERY_RESULT || cmd > Cmd::QUERY_RESULT) {
42 LOGE("invalid cmd=%d", static_cast<int32_t>(cmd));
43 return nullptr;
44 }
45 return std::make_shared<Request>(cmd);
46 }
47
Encode() const48 std::string Response::Encode() const
49 {
50 std::string data = std::to_string(isEncrypt_ ? 1 : 0);
51 return data + SEPARATOR + recvData_;
52 }
53
Decode(const std::string & data)54 std::shared_ptr<Response> Response::Decode(const std::string &data)
55 {
56 if (data.empty()) {
57 LOGE("the data is empty");
58 return nullptr;
59 }
60
61 size_t pos = data.find(SEPARATOR);
62 if (pos == std::string::npos) {
63 LOGE("can not find separator in the string data");
64 return nullptr;
65 }
66
67 int32_t isEncryptVal = static_cast<int32_t>(std::stoi(data.substr(0, pos)));
68 bool isEncrypt = (isEncryptVal == 1);
69 std::string recvData = data.substr(pos + 1);
70
71 return std::make_shared<Response>(isEncrypt, recvData);
72 }
73
~Message()74 Message::~Message()
75 {
76 if (msgType_ == MsgType::MSG_SEQ && request != nullptr) {
77 delete request;
78 }
79 if (msgType_ == MsgType::MSG_RSP && response != nullptr) {
80 delete response;
81 }
82 }
83
Encode() const84 std::string Message::Encode() const
85 {
86 std::string data = std::to_string(static_cast<int32_t>(msgType_));
87 switch (msgType_) {
88 case MsgType::MSG_SEQ:
89 return request == nullptr ? "" : data + SEPARATOR + request->Encode();
90 case MsgType::MSG_RSP:
91 return response == nullptr ? "" : data + SEPARATOR + response->Encode();
92 default:
93 LOGE("invalid msgType=%d", static_cast<int32_t>(msgType_));
94 return "";
95 }
96 }
97
Decode(const std::string & data)98 std::shared_ptr<Message> Message::Decode(const std::string &data)
99 {
100 size_t pos = data.find(SEPARATOR);
101 if (pos == std::string::npos) {
102 return nullptr;
103 }
104
105 MsgType msgType = static_cast<MsgType>(std::stoi(data.substr(0, pos)));
106 switch (msgType) {
107 case MsgType::MSG_SEQ: {
108 std::shared_ptr<Request> req = Request::Decode(data.substr(pos + 1));
109 if (req == nullptr) {
110 return nullptr;
111 }
112 return std::make_shared<Message>(*req);
113 }
114 case MsgType::MSG_RSP: {
115 std::shared_ptr<Response> rsp = Response::Decode(data.substr(pos + 1));
116 if (rsp == nullptr) {
117 return nullptr;
118 }
119 return std::make_shared<Message>(*rsp);
120 }
121 default:
122 LOGE("invalid msgType=%d", static_cast<int32_t>(msgType));
123 return nullptr;
124 }
125 }
126
Open(const std::string & pkgName,const std::string & myName,const std::string & peerName,bool isServer)127 int32_t TMessenger::Open(
128 const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer)
129 {
130 isServer_ = isServer;
131 return isServer_ ? StartListen(pkgName, myName) : StartConnect(pkgName, myName, peerName);
132 }
133
Close()134 void TMessenger::Close()
135 {
136 if (socket_ > 0) {
137 Shutdown(socket_);
138 socket_ = -1;
139 }
140
141 if (listenSocket_ > 0) {
142 Shutdown(listenSocket_);
143 listenSocket_ = -1;
144 }
145
146 pkgName_.clear();
147 myName_.clear();
148 peerName_.clear();
149 peerNetworkId_.clear();
150 msgList_.clear();
151 }
152
StartListen(const std::string & pkgName,const std::string & myName)153 int32_t TMessenger::StartListen(const std::string &pkgName, const std::string &myName)
154 {
155 if (listenSocket_ > 0) {
156 return SOFTBUS_OK;
157 }
158
159 SocketInfo info = {
160 .pkgName = (char *)(pkgName.c_str()),
161 .name = (char *)(myName.c_str()),
162 };
163 int32_t socket = Socket(info);
164 if (socket <= 0) {
165 LOGE("failed to create socket, ret=%d", socket);
166 return socket;
167 }
168 LOGI("create listen socket=%d", socket);
169
170 QosTV qosInfo[] = {
171 { .qos = QOS_TYPE_MIN_BW, .value = 80 },
172 { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000 },
173 { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 },
174 };
175 static ISocketListener listener = {
176 .OnBind = TMessenger::OnBind,
177 .OnMessage = TMessenger::OnMessage,
178 .OnShutdown = TMessenger::OnShutdown,
179 };
180
181 int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
182 if (ret != SOFTBUS_OK) {
183 LOGE("failed to listen, socket=%d", socket);
184 Shutdown(socket);
185 return ret;
186 }
187 listenSocket_ = socket;
188 pkgName_ = pkgName;
189 myName_ = myName;
190 return SOFTBUS_OK;
191 }
192
StartConnect(const std::string & pkgName,const std::string & myName,const std::string & peerName)193 int32_t TMessenger::StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName)
194 {
195 if (socket_ > 0) {
196 return SOFTBUS_OK;
197 }
198
199 SocketInfo info = {
200 .pkgName = const_cast<char *>(pkgName.c_str()),
201 .name = const_cast<char *>(myName.c_str()),
202 .peerName = const_cast<char *>(peerName.c_str()),
203 .peerNetworkId = nullptr,
204 .dataType = DATA_TYPE_MESSAGE,
205 };
206 info.peerNetworkId = OHOS::WaitOnLineAndGetNetWorkId();
207
208 int32_t socket = Socket(info);
209 if (socket <= 0) {
210 LOGE("failed to create socket, ret=%d", socket);
211 return socket;
212 }
213 LOGI("create bind socket=%d", socket);
214
215 QosTV qosInfo[] = {
216 { .qos = QOS_TYPE_MIN_BW, .value = 80 },
217 { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000 },
218 { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 },
219 };
220
221 static ISocketListener listener = {
222 .OnMessage = OnMessage,
223 .OnShutdown = OnShutdown,
224 };
225
226 int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
227 if (ret != SOFTBUS_OK) {
228 LOGE("failed to bind, socket=%d, ret=%d", socket, ret);
229 Shutdown(socket);
230 return ret;
231 }
232
233 pkgName_ = pkgName;
234 myName_ = myName;
235 peerNetworkId_ = info.peerNetworkId;
236 peerName_ = peerName;
237 socket_ = socket;
238 return SOFTBUS_OK;
239 }
240
OnBind(int32_t socket,PeerSocketInfo info)241 void TMessenger::OnBind(int32_t socket, PeerSocketInfo info)
242 {
243 TMessenger::GetInstance().SetConnectSocket(socket, info);
244 }
245
OnMessage(int32_t socket,const void * data,uint32_t dataLen)246 void TMessenger::OnMessage(int32_t socket, const void *data, uint32_t dataLen)
247 {
248 std::string result(static_cast<const char *>(data), dataLen);
249 TMessenger::GetInstance().OnMessageRecv(result);
250 }
251
OnShutdown(int32_t socket,ShutdownReason reason)252 void TMessenger::OnShutdown(int32_t socket, ShutdownReason reason)
253 {
254 TMessenger::GetInstance().CloseSocket(socket);
255 }
256
SetConnectSocket(int32_t socket,PeerSocketInfo info)257 void TMessenger::SetConnectSocket(int32_t socket, PeerSocketInfo info)
258 {
259 if (socket_ > 0) {
260 return;
261 }
262
263 socket_ = socket;
264 peerName_ = info.name;
265 peerNetworkId_ = info.networkId;
266 }
267
OnMessageRecv(const std::string & result)268 void TMessenger::OnMessageRecv(const std::string &result)
269 {
270 std::shared_ptr<Message> msg = Message::Decode(result);
271 if (msg == nullptr) {
272 LOGE("receive invalid message");
273 return;
274 }
275
276 switch (msg->msgType_) {
277 case Message::MsgType::MSG_SEQ: {
278 OnRequest();
279 break;
280 }
281 case Message::MsgType::MSG_RSP: {
282 std::unique_lock<std::mutex> lock(recvMutex_);
283 msgList_.push_back(msg);
284 lock.unlock();
285 recvCond_.notify_one();
286 break;
287 }
288 default:
289 break;
290 }
291 }
292
OnRequest()293 void TMessenger::OnRequest()
294 {
295 std::thread t([this] {
296 std::this_thread::sleep_for(std::chrono::seconds(WAIT_RESP_TIME));
297 std::shared_ptr<Response> resp = onQuery_();
298 Message msg { *resp };
299 int ret = Send(msg);
300 if (ret != SOFTBUS_OK) {
301 LOGE("failed to send response");
302 }
303 });
304 t.detach();
305 }
306
CloseSocket(int32_t socket)307 void TMessenger::CloseSocket(int32_t socket)
308 {
309 if (socket_ == socket) {
310 Shutdown(socket_);
311 socket_ = -1;
312 }
313 }
314
QueryResult(uint32_t timeout)315 std::shared_ptr<Response> TMessenger::QueryResult(uint32_t timeout)
316 {
317 Request req { Request::Cmd::QUERY_RESULT };
318 Message msg { req };
319 int32_t ret = Send(msg);
320 if (ret != SOFTBUS_OK) {
321 LOGE("failed to query result, ret=%d", ret);
322 return nullptr;
323 }
324
325 return WaitResponse(timeout);
326 }
327
Send(const Message & msg)328 int32_t TMessenger::Send(const Message &msg)
329 {
330 std::string data = msg.Encode();
331 if (data.empty()) {
332 LOGE("the data is empty");
333 return SOFTBUS_MEM_ERR;
334 }
335
336 int32_t ret = SendMessage(socket_, data.c_str(), data.size());
337 if (ret != SOFTBUS_OK) {
338 LOGE("failed to send message, socket=%d, ret=%d", socket_, ret);
339 }
340 return ret;
341 }
342
WaitResponse(uint32_t timeout)343 std::shared_ptr<Response> TMessenger::WaitResponse(uint32_t timeout)
344 {
345 std::unique_lock<std::mutex> lock(recvMutex_);
346 std::shared_ptr<Response> rsp = nullptr;
347 if (recvCond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
348 rsp = GetMessageFromRecvList(Message::MsgType::MSG_RSP);
349 return rsp != nullptr;
350 })) {
351 return rsp;
352 }
353 LOGE("no result received");
354 return nullptr;
355 }
356
GetMessageFromRecvList(Message::MsgType type)357 std::shared_ptr<Response> TMessenger::GetMessageFromRecvList(Message::MsgType type)
358 {
359 auto it = std::find_if(msgList_.begin(), msgList_.end(), [type] (const std::shared_ptr<Message> &it) {
360 return it->msgType_ == type;
361 });
362
363 if (it == msgList_.end() || *it == nullptr) {
364 return nullptr;
365 }
366
367 const Response *rsp = (*it)->response;
368 if (rsp == nullptr) {
369 msgList_.erase(it);
370 return nullptr;
371 }
372
373 std::shared_ptr<Response> resp = std::make_shared<Response>(*rsp);
374 msgList_.erase(it);
375 return resp;
376 }
377
RegisterOnQuery(TMessenger::OnQueryCallback callback)378 void TMessenger::RegisterOnQuery(TMessenger::OnQueryCallback callback)
379 {
380 onQuery_ = callback;
381 }
382 } // namespace OHOS
383