1 /*
2  * Copyright (C) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "socket.h"
17 #include <sys/ioctl.h>
18 #include <sys/socket.h>
19 #include <unistd.h>
20 
21 #include "adapter_config.h"
22 #include "log.h"
23 #include "packet.h"
24 #include "power_manager.h"
25 #include "profile_service_manager.h"
26 #include "securec.h"
27 
28 #include "socket_def.h"
29 #include "socket_listener.h"
30 #include "socket_service.h"
31 #include "socket_util.h"
32 
33 namespace OHOS {
34 namespace bluetooth {
35 static int g_arrayServiceId[SOCK_MAX_SERVICE_ID] = {0};
36 std::vector<Socket *> Socket::g_allServerSockets;
37 std::recursive_mutex Socket::g_socketMutex;
38 
39 struct Socket::impl {
40     class DataTransportObserverImplement;
41     std::unique_ptr<DataTransportObserver> transportObserver_ {};
42     void OnConnectIncomingNative(Socket &socket, RawAddress addr, uint8_t port);
43     void OnConnectedNative(Socket &socket, DataTransport *transport, uint16_t sendMTU, uint16_t recvMTU);
44     void OnDisconnectedNative(Socket &socket, DataTransport *transport);
45     void OnDisconnectSuccessNative(Socket &socket, DataTransport *transport);
46     void OnDataAvailableNative(Socket &socket, DataTransport *transport);
47     void OnTransportErrorNative(Socket &socket, DataTransport *transport, int errType);
48     void SockRfcConnectFail(Socket &socket, DataTransport *transport);
49     void SockRfcDisconnectFail(Socket &socket, DataTransport *transport);
50     void SockRfcFcOn(Socket &socket, DataTransport *transport);
51     static int GetMaxConnectionDevicesNum();
52 };
53 
54 class Socket::impl::DataTransportObserverImplement : public DataTransportObserver {
55 public:
OnConnectIncoming(const RawAddress & addr,uint16_t port)56     void OnConnectIncoming(const RawAddress &addr, uint16_t port) override
57     {
58         SocketService *socketService =
59             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
60         if (socketService != nullptr) {
61             socketService->GetDispatcher()->PostTask(
62                 std::bind(&impl::OnConnectIncomingNative, socket_.pimpl.get(), std::ref(socket_), addr, port));
63         }
64     }
65 
OnIncomingDisconnected(const RawAddress & addr)66     void OnIncomingDisconnected(const RawAddress &addr) override
67     {
68         LOG_INFO("[sock]%{public}s", __func__);
69     }
70 
OnConnected(DataTransport * transport,uint16_t sendMTU,uint16_t recvMTU)71     void OnConnected(DataTransport *transport, uint16_t sendMTU, uint16_t recvMTU) override
72     {
73         LOG_INFO("[sock]%{public}s", __func__);
74 
75         SocketService *socketService =
76             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
77         if (socketService != nullptr) {
78             socketService->GetDispatcher()->PostTask(std::bind(&impl::OnConnectedNative, socket_.pimpl.get(),
79                                                                std::ref(socket_), transport, sendMTU, recvMTU));
80         }
81     }
82 
OnDisconnected(DataTransport * transport)83     void OnDisconnected(DataTransport *transport) override
84     {
85         LOG_INFO("[sock]%{public}s", __func__);
86 
87         SocketService *socketService =
88             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
89         if (socketService != nullptr) {
90             socketService->GetDispatcher()->PostTask(
91                 std::bind(&impl::OnDisconnectedNative, socket_.pimpl.get(), std::ref(socket_), transport));
92         }
93     }
94 
OnDisconnectSuccess(DataTransport * transport)95     void OnDisconnectSuccess(DataTransport *transport) override
96     {
97         LOG_INFO("[sock]%{public}s", __func__);
98 
99         SocketService *socketService =
100             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
101         if (socketService != nullptr) {
102             socketService->GetDispatcher()->PostTask(
103                 std::bind(&impl::OnDisconnectSuccessNative, socket_.pimpl.get(), std::ref(socket_), transport));
104         }
105     }
106 
OnDataAvailable(DataTransport * transport)107     void OnDataAvailable(DataTransport *transport) override
108     {
109         LOG_INFO("[sock]%{public}s", __func__);
110 
111         SocketService *socketService =
112             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
113         if (socketService != nullptr) {
114             socketService->GetDispatcher()->PostTask(
115                 std::bind(&impl::OnDataAvailableNative, socket_.pimpl.get(), std::ref(socket_), transport));
116         }
117     }
118 
OnDataAvailable(DataTransport * transport,Packet * pkt)119     void OnDataAvailable(DataTransport *transport, Packet *pkt) override
120     {
121         LOG_INFO("[sock]%{public}s", __func__);
122     }
123 
OnTransportError(DataTransport * transport,int errType)124     void OnTransportError(DataTransport *transport, int errType) override
125     {
126         LOG_INFO("[sock]%{public}s", __func__);
127 
128         SocketService *socketService =
129             static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
130         if (socketService != nullptr) {
131             socketService->GetDispatcher()->PostTask(
132                 std::bind(&impl::OnTransportErrorNative, socket_.pimpl.get(), std::ref(socket_), transport, errType));
133         }
134     }
135 
DataTransportObserverImplement(Socket & socket)136     DataTransportObserverImplement(Socket &socket) : socket_(socket)
137     {}
138 
139     ~DataTransportObserverImplement() = default;
140 
141 private:
142     Socket &socket_;
143 };
144 
OnConnectIncomingNative(Socket & socket,RawAddress addr,uint8_t port)145 void Socket::impl::OnConnectIncomingNative(Socket &socket, RawAddress addr, uint8_t port)
146 {
147     LOG_INFO("[sock]%{public}s", __func__);
148 
149     socket.maxConnectedNum_ = GetMaxConnectionDevicesNum();
150     addr.ConvertToUint8(socket.remoteAddr_.addr);
151     socket.remoteAddr_.type = BT_PUBLIC_DEVICE_ADDRESS;
152     if (socket.clientNumber_ < socket.maxConnectedNum_) {
153         socket.sockTransport_->AcceptConnection(addr, port);
154     } else {
155         socket.sockTransport_->RejectConnection(addr, port);
156     }
157 }
158 
OnConnectedNative(Socket & socket,DataTransport * transport,uint16_t sendMTU,uint16_t recvMTU)159 void Socket::impl::OnConnectedNative(Socket &socket, DataTransport *transport, uint16_t sendMTU, uint16_t recvMTU)
160 {
161     LOG_INFO("[sock]%{public}s", __func__);
162     IPowerManager::GetInstance().StatusUpdate(
163         RequestStatus::CONNECT_ON, PROFILE_NAME_SPP, RawAddress::ConvertToString(socket.remoteAddr_.addr));
164     SocketConnectInfo connectInfo;
165     (void)memset_s(&connectInfo, sizeof(connectInfo), 0, sizeof(connectInfo));
166     (void)memcpy_s(connectInfo.addr, sizeof(connectInfo.addr), socket.remoteAddr_.addr,
167         sizeof(socket.remoteAddr_.addr));
168     connectInfo.status = true;
169     connectInfo.txMtu = sendMTU;
170     connectInfo.rxMtu = recvMTU;
171     if (socket.IsServer()) {
172         socket.clientNumber_++;
173         int newFd = socket.AddSocketInternal(socket.remoteAddr_, transport, sendMTU, recvMTU);
174         Socket::SendAppConnectInfo(socket.transportFd_, newFd, connectInfo);
175     } else {
176         socket.state_ = CONNECTED;
177         socket.sendMTU_ = sendMTU;
178         socket.recvMTU_ = recvMTU;
179         Socket::SendAppConnectInfo(socket.transportFd_, -1, connectInfo);
180         LOG_INFO("[sock]%{public}s app fd:%{public}d client connect successfully", __func__, socket.upperlayerFd_);
181         std::lock_guard<std::recursive_mutex> lck(Socket::g_socketMutex);
182         g_allServerSockets.push_back(&socket);
183     }
184 }
185 
OnDisconnectedNative(Socket & socket,DataTransport * transport)186 void Socket::impl::OnDisconnectedNative(Socket &socket, DataTransport *transport)
187 {
188     LOG_INFO("[sock]%{public}s", __func__);
189 
190     socket.ProcessDisconnection(socket, transport);
191 }
192 
OnDisconnectSuccessNative(Socket & socket,DataTransport * transport)193 void Socket::impl::OnDisconnectSuccessNative(Socket &socket, DataTransport *transport)
194 {
195     LOG_INFO("[sock]%{public}s", __func__);
196 
197     socket.ProcessDisconnection(socket, transport);
198 }
199 
OnDataAvailableNative(Socket & socket,DataTransport * transport)200 void Socket::impl::OnDataAvailableNative(Socket &socket, DataTransport *transport)
201 {
202     LOG_INFO("[sock]%{public}s", __func__);
203 
204     Packet *pkt = nullptr;
205     uint8_t *pData = nullptr;
206     Buffer *buf = nullptr;
207 
208     Socket *socketTmp = nullptr;
209     if (socket.IsServer()) {
210         if (socket.socketMap_.find(transport) != socket.socketMap_.end()) {
211             socketTmp = socket.socketMap_.at(transport).get();
212         } else {
213             LOG_ERROR("[sock]%{public}s socket does not exist", __func__);
214             return;
215         }
216     } else {
217         socketTmp = &socket;
218     }
219 
220     if (!socketTmp->isCanRead_) {
221         LOG_DEBUG("[sock]%{public}s app can not receive data", __func__);
222         return;
223     }
224 
225     if (socketTmp->isNewSocket_) {
226         if (socketTmp->newSockTransport_ != nullptr) {
227             socketTmp->newSockTransport_->Read(&pkt);
228         } else {
229             LOG_DEBUG("[sock]%{public}s newSockTransport is null", __func__);
230             return;
231         }
232     } else {
233         socketTmp->sockTransport_->Read(&pkt);
234     }
235 
236     if (pkt == nullptr) {
237         LOG_ERROR("[sock]%{public}s pkt is null", __func__);
238         return;
239     }
240 
241     size_t len = PacketPayloadSize(pkt);
242     buf = PacketContinuousPayload(pkt);
243     if (buf == nullptr) {
244         LOG_ERROR("[sock]%{public}s pkt buf is null", __func__);
245         return;
246     }
247     pData = (uint8_t *)BufferPtr(buf);
248 
249     socketTmp->WriteDataToAPP(pData, len);
250 
251     if (pkt != nullptr) {
252         PacketFree(pkt);
253     }
254 }
255 
OnTransportErrorNative(Socket & socket,DataTransport * transport,int errType)256 void Socket::impl::OnTransportErrorNative(Socket &socket, DataTransport *transport, int errType)
257 {
258     LOG_INFO("[sock]%{public}s errType:%{public}d", __func__, errType);
259 
260     switch (errType) {
261         case RFCOMM_CONNECT_FAIL:
262             SockRfcConnectFail(socket, transport);
263             break;
264         case RFCOMM_DISCONNECT_FAIL:
265             SockRfcDisconnectFail(socket, transport);
266             break;
267         case RFCOMM_EV_FC_ON:
268             SockRfcFcOn(socket, transport);
269             break;
270         default:
271             break;
272     }
273 }
274 
SockRfcConnectFail(Socket & socket,DataTransport * transport)275 void Socket::impl::SockRfcConnectFail(Socket &socket, DataTransport *transport)
276 {
277     LOG_INFO("[sock]%{public}s", __func__);
278     SocketConnectInfo connectInfo;
279     (void)memset_s(&connectInfo, sizeof(connectInfo), 0, sizeof(connectInfo));
280     (void)memcpy_s(connectInfo.addr, sizeof(connectInfo.addr), socket.remoteAddr_.addr,
281         sizeof(socket.remoteAddr_.addr));
282     connectInfo.status = false;
283     connectInfo.txMtu = 0;
284     connectInfo.rxMtu = 0;
285     if (socket.IsServer()) {
286         if (socket.socketMap_.find(transport) != socket.socketMap_.end()) {
287             Socket *serverSocket = nullptr;
288             serverSocket = socket.socketMap_.at(transport).get();
289             Socket::SendAppConnectInfo(serverSocket->transportFd_, -1, connectInfo);
290         }
291     } else {
292         Socket::SendAppConnectInfo(socket.transportFd_, -1, connectInfo);
293     }
294     socket.ProcessDisconnection(socket, transport);
295 }
296 
SockRfcDisconnectFail(Socket & socket,DataTransport * transport)297 void Socket::impl::SockRfcDisconnectFail(Socket &socket, DataTransport *transport)
298 {
299     LOG_INFO("[sock]%{public}s", __func__);
300 
301     if (socket.IsServer()) {
302         if (socket.socketMap_.find(transport) != socket.socketMap_.end()) {
303             LOG_DEBUG("SockRfcDisconnectFail closefd : fd:%{public}d",
304                 socket.socketMap_.at(transport).get()->transportFd_);
305             close(socket.socketMap_.at(transport).get()->transportFd_);
306         } else {
307             LOG_ERROR("[sock]socket does not exist");
308         }
309     } else {
310         LOG_DEBUG("SockRfcDisconnectFail closefd : fd:%{public}d", socket.transportFd_);
311         close(socket.transportFd_);
312     }
313 }
314 
SockRfcFcOn(Socket & socket,DataTransport * transport)315 void Socket::impl::SockRfcFcOn(Socket &socket, DataTransport *transport)
316 {
317     LOG_INFO("[sock]%{public}s", __func__);
318 
319     Socket *socketTmp = nullptr;
320     if (socket.IsServer()) {
321         if (socket.socketMap_.find(transport) != socket.socketMap_.end()) {
322             socketTmp = socket.socketMap_.at(transport).get();
323         } else {
324             LOG_ERROR("socket does not exist");
325         }
326     } else {
327         socketTmp = &socket;
328     }
329     if (socketTmp == nullptr) {
330         return;
331     }
332     std::lock_guard<std::recursive_mutex> lk(socketTmp->writeMutex_);
333     if (socketTmp->sendBufLen_ > 0) {
334         Packet *wPkt = PacketMalloc(0, 0, socketTmp->sendBufLen_);
335         Buffer *wPayloadBuf = PacketContinuousPayload(wPkt);
336         void *buffer = BufferPtr(wPayloadBuf);
337         (void)memcpy_s(buffer, socketTmp->sendBufLen_, socketTmp->sendDataBuf_, socketTmp->sendBufLen_);
338         if (wPayloadBuf == nullptr) {
339             if (wPkt != nullptr) {
340                 PacketFree(wPkt);
341             }
342             return;
343         }
344         int ret = socketTmp->TransportWrite(wPkt);
345         if (ret < 0) {
346             LOG_ERROR("%{public}s stack write failed", __func__);
347         } else {
348             (void)memset_s(socketTmp->sendDataBuf_, socketTmp->sendBufLen_, 0x00, socketTmp->sendBufLen_);
349             socketTmp->sendBufLen_ = 0;
350             socketTmp->isCanWrite_ = true;
351             socketTmp->WriteData();
352         }
353         if (wPkt != nullptr) {
354             PacketFree(wPkt);
355         }
356     } else {
357         socketTmp->isCanWrite_ = true;
358         socketTmp->WriteData();
359     }
360 }
361 
GetMaxConnectionDevicesNum()362 int Socket::impl::GetMaxConnectionDevicesNum()
363 {
364     int number = SOCK_MAX_CLIENT;
365     if (!AdapterConfig::GetInstance()->GetValue(SECTION_SOCKET_SERVICE, PROPERTY_MAX_CONNECTED_DEVICES, number)) {
366         LOG_DEBUG("[sock]%{public}s: It's failed to get the max connection number", __FUNCTION__);
367     }
368     return number;
369 }
370 
Socket()371 Socket::Socket() : pimpl(nullptr)
372 {
373     state_ = SocketState::INIT;
374     pimpl = std::make_unique<Socket::impl>();
375     this->pimpl->transportObserver_ = std::make_unique<Socket::impl::DataTransportObserverImplement>(*this);
376 }
377 
~Socket()378 Socket::~Socket()
379 {}
380 
Connect(const std::string & addr,const Uuid & uuid,int securityFlag,int & sockfd)381 int Socket::Connect(const std::string &addr, const Uuid &uuid, int securityFlag, int &sockfd)
382 {
383     LOG_INFO("[sock]%{public}s", __func__);
384 
385     sockfd = SOCK_INVALID_FD;
386     int socketPair[2] = {SOCK_INVALID_FD, SOCK_INVALID_FD};
387 
388     if (socketpair(AF_UNIX, SOCK_STREAM, 0, socketPair) == -1) {
389         LOG_ERROR("[sock]%{public}s: create rfcomm socket pair failed", __FUNCTION__);
390         return -1;
391     }
392 
393     SetRemoteAddr(addr);
394 
395     sdpClient_ = std::make_unique<SocketSdpClient>();
396     int ret = sdpClient_->StartDiscovery(addr, uuid, this);
397     if (ret != BT_SUCCESS) {
398         LOG_ERROR("[sock]%{public}s: Discovery SPP Service Fail!", __FUNCTION__);
399     }
400 
401     upperlayerFd_ = socketPair[0];
402     transportFd_ = socketPair[1];
403 
404     LOG_INFO("[sock]%{public}s appFd:%{public}d fd:%{public}d", __func__, upperlayerFd_, transportFd_);
405 
406     sockfd = upperlayerFd_;
407     upperlayerFd_ = SOCK_INVALID_FD;
408     securityFlag_ = securityFlag;
409     SocketThread::GetInstance().AddSocket(transportFd_, 0, *this);
410     return ret;
411 }
412 
Listen(const std::string & name,const Uuid & uuid,int securityFlag,int & sockfd)413 int Socket::Listen(const std::string &name, const Uuid &uuid, int securityFlag, int &sockfd)
414 {
415     LOG_INFO("[sock]%{public}s", __func__);
416 
417     isServer_ = true;
418     sockfd = SOCK_INVALID_FD;
419     int socketPair[2] = {SOCK_INVALID_FD, SOCK_INVALID_FD};
420     if (socketpair(AF_UNIX, SOCK_STREAM, 0, socketPair) == -1) {
421         LOG_ERROR("[sock]%{public}s: create listen socket failed", __FUNCTION__);
422         return -1;
423     }
424     upperlayerFd_ = socketPair[0];
425     transportFd_ = socketPair[1];
426     LOG_INFO("[sock]%{public}s appFd:%{public}d fd:%{public}d", __func__, upperlayerFd_, transportFd_);
427     sockfd = upperlayerFd_;
428     upperlayerFd_ = SOCK_INVALID_FD;
429     securityFlag_ = securityFlag;
430     state_ = LISTEN;
431 
432     SocketThread::GetInstance().AddSocket(transportFd_, 0, *this);
433 
434     scn_ = RFCOMM_AssignServerNum();
435 
436     sdpServer_ = std::make_unique<SocketSdpServer>();
437     int ret = sdpServer_->RegisterSdpService(name, uuid, scn_);
438     if (ret != BT_SUCCESS) {
439         LOG_ERROR("[sock]%{public}s: Discovery SPP Service Fail!", __FUNCTION__);
440     }
441 
442     if (!SendAppConnectScn(transportFd_, scn_)) {
443         LOG_ERROR("send scn failed");
444         CloseSocketFd();
445         return -1;
446     }
447 
448     serviceId_ = AssignServiceId();
449     LOG_INFO("[sock]%{public}s securityFlag:%{public}d serviceId_:%{public}d", __func__, securityFlag_, serviceId_);
450     socketGapServer_ = std::make_unique<SocketGapServer>();
451     socketGapServer_->RegisterServiceSecurity(scn_, securityFlag_, serviceId_);
452 
453     sockTransport_ = std::move(transportFactory_->CreateRfcommTransport(
454         nullptr, scn_, SOCK_DEF_RFC_MTU, *this->pimpl->transportObserver_.get(), *GetDispatchter()));
455     sockTransport_->RegisterServer();
456 
457     std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
458     g_allServerSockets.push_back(this);
459     return ret;
460 }
461 
ReceiveSdpResult(uint8_t scn)462 int Socket::ReceiveSdpResult(uint8_t scn)
463 {
464     LOG_INFO("[sock]%{public}s", __func__);
465 
466     serviceId_ = AssignServiceId();
467     scn_ = scn;
468     LOG_INFO("[sock]%{public}s securityFlag:%{public}d serviceId_:%{public}d scn:%hhu",
469         __func__, securityFlag_, serviceId_, scn_);
470     if (scn_ > SOCK_MAX_SERVER) {
471         LOG_INFO("[sock]%{public}s scn invalid", __func__);
472         return -1;
473     }
474 
475     socketGapClient_ = std::make_unique<SocketGapClient>();
476     socketGapClient_->RegisterServiceSecurity(remoteAddr_, scn_, securityFlag_, serviceId_);
477     RawAddress rawAddr = RawAddress::ConvertToString(remoteAddr_.addr);
478     sockTransport_ = std::move(transportFactory_->CreateRfcommTransport(
479         &rawAddr, scn_, SOCK_DEF_RFC_MTU, *this->pimpl->transportObserver_.get(), *GetDispatchter()));
480 
481     if (!SendAppConnectScn(transportFd_, scn_)) {
482         LOG_ERROR("send scn failed");
483         CloseSocketFd();
484         return -1;
485     }
486 
487     switch (state_) {
488         case INIT:
489             if (sockTransport_->Connect() < 0) {
490                 LOG_ERROR("[sock]create rfcomm channel failed");
491                 SocketThread::GetInstance().DeleteSocket(*this);
492                 CloseSocketFd();
493                 return -1;
494             }
495             state_ = CONNECTING;
496             break;
497         default:
498             LOG_ERROR("[sock]create rfcomm channel failed");
499             break;
500     }
501     return 0;
502 }
503 
AddSocketInternal(BtAddr addr,DataTransport * transport,uint16_t sendMTU,uint16_t recvMTU)504 int Socket::AddSocketInternal(BtAddr addr, DataTransport *transport, uint16_t sendMTU, uint16_t recvMTU)
505 {
506     LOG_INFO("[sock]%{public}s", __func__);
507 
508     std::unique_ptr<Socket> acceptSocket = std::make_unique<Socket>();
509     int socketPair[2] = {SOCK_INVALID_FD, SOCK_INVALID_FD};
510     if (socketpair(AF_UNIX, SOCK_STREAM, 0, socketPair) == -1) {
511         LOG_ERROR("[sock]create accept socket failed");
512     }
513     LOG_DEBUG("AddSocketInternal : fd:%{public}d, fd:%{public}d", socketPair[0], socketPair[1]);
514     acceptSocket->upperlayerFd_ = socketPair[0];
515     acceptSocket->transportFd_ = socketPair[1];
516     acceptSocket->remoteAddr_ = addr;
517     acceptSocket->isNewSocket_ = true;
518     acceptSocket->isServer_ = true;
519     acceptSocket->state_ = SocketState::CONNECTED;
520     acceptSocket->sendMTU_ = sendMTU;
521     acceptSocket->recvMTU_ = recvMTU;
522     acceptSocket->newSockTransport_ = transport;
523     mutex_.lock();
524     auto it = socketMap_.emplace(transport, std::move(acceptSocket));
525     mutex_.unlock();
526 
527     SocketThread::GetInstance().AddSocket(
528         it.first->second.get()->transportFd_, 0, *(it.first->second.get()));
529 
530     utility::Message msg(SOCKET_ACCEPT_NEW);
531     msg.arg1_ = socketPair[0];
532     msg.arg2_ = it.first->second.get();
533     SocketService *socketService =
534         static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
535     if (socketService != nullptr) {
536             socketService->ProcessMessage(msg);
537     }
538 
539     std::lock_guard<std::recursive_mutex> lck(Socket::g_socketMutex);
540     g_allServerSockets.push_back(it.first->second.get());
541 
542     if (socketPair[0] < 0) {
543         LOG_ERROR("[sock]create accept socket failed");
544     }
545     return socketPair[0];
546 }
547 
SendAppConnectScn(int fd,int scn)548 bool Socket::SendAppConnectScn(int fd, int scn)
549 {
550     return SocketUtil::SocketSendData(fd, reinterpret_cast<const uint8_t *>(&scn), sizeof(scn));
551 }
552 
SendAppConnectInfo(int fd,int acceptFd,const SocketConnectInfo & connectInfo)553 bool Socket::SendAppConnectInfo(int fd, int acceptFd, const SocketConnectInfo &connectInfo)
554 {
555     LOG_INFO("[sock]%{public}s", __func__);
556     LOG_INFO("[sock]%{public}s size:%{public}zu", __func__, sizeof(connectInfo));
557     if (acceptFd == -1) {
558         return SocketUtil::SocketSendData(fd, reinterpret_cast<const uint8_t *>(&connectInfo), sizeof(connectInfo));
559     } else {
560         return SocketUtil::SocketSendFd(fd, reinterpret_cast<const uint8_t *>(&connectInfo),
561                                         sizeof(connectInfo), acceptFd);
562     }
563 }
564 
ProcessDisconnection(Socket & socket,DataTransport * transport)565 void Socket::ProcessDisconnection(Socket &socket, DataTransport *transport)
566 {
567     LOG_INFO("[sock]%{public}s", __func__);
568 
569     IPowerManager::GetInstance().StatusUpdate(
570         RequestStatus::CONNECT_OFF, PROFILE_NAME_SPP, RawAddress::ConvertToString(socket.remoteAddr_.addr));
571 
572     if (socket.IsServer()) {
573         if (socket.socketMap_.find(transport) != socket.socketMap_.end()) {
574             socket.clientNumber_--;
575             Socket *serverSocket = nullptr;
576             serverSocket = socket.socketMap_.at(transport).get();
577             serverSocket->state_ = DISCONNECTED;
578             serverSocket->newSockTransport_ = nullptr;
579             SocketThread::GetInstance().DeleteSocket(*serverSocket);
580             serverSocket->CloseSocketFd();
581             std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
582             Socket::EraseSocket(*serverSocket);
583             socket.socketMap_.erase(transport);
584             socket.NotifyServiceDeleteSocket(*serverSocket);
585             if (transport != nullptr) {
586                 delete transport;
587             }
588         } else {
589             LOG_ERROR("[sock]socket does not exist");
590         }
591     } else {
592         socket.state_ = DISCONNECTED;
593         socketGapClient_->UnregisterSecurity(remoteAddr_, scn_, serviceId_);
594         FreeServiceId(serviceId_);
595         SocketThread::GetInstance().DeleteSocket(*this);
596         socket.CloseSocketFd();
597         std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
598         Socket::EraseSocket(socket);
599         Socket::NotifyServiceDeleteSocket(socket);
600     }
601 }
602 
SetRemoteAddr(std::string addr)603 void Socket::SetRemoteAddr(std::string addr)
604 {
605     LOG_INFO("[sock]%{public}s", __func__);
606 
607     RawAddress rawAddr(addr);
608     rawAddr.ConvertToUint8(remoteAddr_.addr);
609     remoteAddr_.type = BT_PUBLIC_DEVICE_ADDRESS;
610 }
611 
CloseSocket(bool isDisable)612 void Socket::CloseSocket(bool isDisable)
613 {
614     LOG_INFO("[sock]%{public}s", __func__);
615 
616     CloseSocketFd();
617 
618     if (isServer_ && (!isNewSocket_)) {
619         RFCOMM_FreeServerNum(scn_);
620         sdpServer_->UnregisterSdpService();
621         socketGapServer_->UnregisterSecurity(remoteAddr_, scn_, serviceId_);
622         FreeServiceId(serviceId_);
623         if (isDisable) {
624             sockTransport_->RemoveServer(true);
625         } else {
626             sockTransport_->RemoveServer(false);
627         }
628         state_ = CLOSED;
629         NotifyServiceDeleteSocket(*this);
630         std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
631         EraseSocket(*this);
632         return;
633     }
634 
635     if (state_ == CONNECTED || state_ == CONNECTING) {
636         LOG_INFO("[sock]%{public}s close connection", __func__);
637         if (isServer_) {
638             socketGapServer_->UnregisterSecurity(remoteAddr_, scn_, serviceId_);
639         } else {
640             socketGapClient_->UnregisterSecurity(remoteAddr_, scn_, serviceId_);
641         }
642         FreeServiceId(serviceId_);
643 
644         if (isServer_) {
645             if (newSockTransport_ != nullptr) {
646                 newSockTransport_->Disconnect();
647             } else {
648                 LOG_ERROR("[sock]%{public}s newSockTransport is null", __func__);
649             }
650         } else {
651             if (sockTransport_ != nullptr) {
652                 sockTransport_->Disconnect();
653             } else {
654                 LOG_ERROR("[sock]%{public}s client sockTransport is null", __func__);
655             }
656         }
657     } else if (state_ == INIT || state_ == DISCONNECTED) {
658         LOG_INFO("[sock]%{public}s close no connection", __func__);
659         NotifyServiceDeleteSocket(*this);
660     }
661 }
662 
OnSocketReadReady(Socket & sock)663 void Socket::OnSocketReadReady(Socket &sock)
664 {
665     std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
666     std::vector<Socket *>::iterator it = find(g_allServerSockets.begin(), g_allServerSockets.end(), &sock);
667     if (it == g_allServerSockets.end()) {
668         LOG_DEBUG("[sock]%{public}s socket does not exist", __func__);
669         return;
670     }
671 
672     std::lock_guard<std::recursive_mutex> lck(sock.writeMutex_);
673     if (sock.isCanWrite_) {
674         LOG_INFO("[sock]%{public}s socket write data", __func__);
675         sock.WriteData();
676     }
677 }
678 
OnSocketWriteReady(Socket & sock)679 void Socket::OnSocketWriteReady(Socket &sock)
680 {
681     LOG_INFO("[sock]%{public}s", __func__);
682 
683     std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
684     std::vector<Socket *>::iterator it = find(g_allServerSockets.begin(), g_allServerSockets.end(), &sock);
685     if (it == g_allServerSockets.end()) {
686         LOG_DEBUG("[sock]%{public}s socket does not exist", __func__);
687         return;
688     }
689 
690     SocketService *socketService =
691         static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
692     if (socketService != nullptr) {
693         socketService->GetDispatcher()->PostTask(std::bind(&Socket::OnSocketWriteReadyNative, &sock, std::ref(sock)));
694     }
695 }
696 
OnSocketWriteReadyNative(Socket & sock)697 void Socket::OnSocketWriteReadyNative(Socket &sock)
698 {
699     LOG_INFO("[sock]%{public}s", __func__);
700 
701     sock.isCanRead_ = true;
702 
703     if (sock.recvBufLen_ == 0) {
704         LOG_INFO("[sock]%{public}s recvbuf has been send", __func__);
705         sock.ReadData();
706         return;
707     }
708 
709     SocketSendRet sendRet = SendDataToApp(sock.transportFd_, sock.recvDataBuf_, sock.recvBufLen_);
710     switch (sendRet) {
711         case SOCKET_SEND_NONE:
712         case SOCKET_SEND_PARTIAL:
713             sock.isCanRead_ = false;
714             SocketThread::GetInstance().AddSocket(
715                 sock.transportFd_, 1, sock);
716             break;
717         case SOCKET_SEND_ERROR:
718             sock.isCanRead_ = false;
719             LOG_INFO("[sock]%{public}s close socket", __func__);
720             SocketThread::GetInstance().DeleteSocket(sock);
721             sock.CloseSocket(false);
722             break;
723         case SOCKET_SEND_ALL:
724             sock.isCanRead_ = true;
725             (void)memset_s(sock.recvDataBuf_, SOCK_DEF_RFC_MTU, 0, SOCK_DEF_RFC_MTU);
726             sock.ReadData();
727             LOG_INFO("[sock]%{public}s send data success", __func__);
728             break;
729         default:
730             break;
731     }
732 }
733 
ReadData()734 void Socket::ReadData()
735 {
736     Packet *pkt = nullptr;
737     uint8_t *pData = nullptr;
738     Buffer *buf = nullptr;
739 
740     while (true) {
741         if (!this->isCanRead_) {
742             LOG_DEBUG("[sock]%{public}s can not read.", __func__);
743             return;
744         }
745 
746         if (this->isNewSocket_) {
747             if (this->newSockTransport_ == nullptr) {
748                 LOG_DEBUG("[sock]%{public}s newSockTransport is null", __func__);
749                 return;
750             }
751             if (this->newSockTransport_->Read(&pkt) != 0) {
752                 break;
753             }
754         } else {
755             if (this->sockTransport_->Read(&pkt) != 0) {
756                 break;
757             }
758         }
759 
760         if (pkt == nullptr) {
761             LOG_ERROR("[sock]%{public}s pkt is null", __func__);
762             return;
763         }
764 
765         size_t len = PacketPayloadSize(pkt);
766         if (len == 0) {
767             break;
768         }
769         buf = PacketContinuousPayload(pkt);
770         if (buf != nullptr) {
771             pData = (uint8_t *)BufferPtr(buf);
772         }
773         if (pData == nullptr) {
774             return;
775         }
776 
777         this->WriteDataToAPP(pData, len);
778 
779         if (pkt != nullptr) {
780             PacketFree(pkt);
781             pkt = nullptr;
782         }
783     }
784 }
785 
WriteDataToAPP(const uint8_t * buffer,size_t len)786 void Socket::WriteDataToAPP(const uint8_t *buffer, size_t len)
787 {
788     LOG_INFO("[sock]%{public}s", __func__);
789 
790     SocketSendRet sendRet = SendDataToApp(this->transportFd_, buffer, len);
791     switch (sendRet) {
792         case SOCKET_SEND_NONE:
793         case SOCKET_SEND_PARTIAL:
794             LOG_INFO("[sock]%{public}s SOCKET_SEND_PARTIAL", __func__);
795             this->isCanRead_ = false;
796             (void)memcpy_s(this->recvDataBuf_, SOCK_DEF_RFC_MTU, buffer, len);
797             this->recvBufLen_ = len;
798             SocketThread::GetInstance().AddSocket(this->transportFd_, 1, *this);
799             break;
800         case SOCKET_SEND_ERROR:
801             this->isCanRead_ = false;
802             LOG_INFO("[sock]%{public}s send data error", __func__);
803             SocketThread::GetInstance().DeleteSocket(*this);
804             this->CloseSocket(false);
805             break;
806         case SOCKET_SEND_ALL:
807             this->isCanRead_ = true;
808             LOG_INFO("[sock]%{public}s send data success", __func__);
809             break;
810         default:
811             break;
812     }
813 }
814 
WriteData()815 void Socket::WriteData()
816 {
817     LOG_INFO("[sock]%{public}s", __func__);
818 
819     int totalSize = 0;
820 
821     {
822         std::lock_guard<std::mutex> lock(fdMutex_);
823         if (ioctl(this->transportFd_, FIONREAD, &totalSize) != 0) {
824             LOG_ERROR("[sock]%{public}s ioctl read fd error", __func__);
825             return;
826         }
827     }
828 
829     if (totalSize == 0) {
830         LOG_DEBUG("[sock]%{public}s recv buffer has no data", __func__);
831         return;
832     }
833 
834     LOG_INFO("[sock]%{public}s totalSize:%{public}d", __func__, totalSize);
835 
836     while (totalSize > 0) {
837         if (this->isCanWrite_) {
838             int mallocSize = (totalSize > this->sendMTU_) ? this->sendMTU_ : totalSize;
839 
840             Packet *wPkt = PacketMalloc(0, 0, mallocSize);
841             if (wPkt == nullptr) {
842                 LOG_INFO("[sock]pkt is null");
843                 return;
844             }
845             Buffer *wPayloadBuf = PacketContinuousPayload(wPkt);
846             void *buffer = BufferPtr(wPayloadBuf);
847 
848             int wbytes = read(this->transportFd_, buffer, mallocSize);
849             LOG_INFO("[sock]%{public}s wbytes:%{public}d", __func__, wbytes);
850             if (wbytes <= 0) {
851                 LOG_DEBUG("[sock]%{public}s socket fd exception", __func__);
852                 PacketFree(wPkt);
853                 return;
854             }
855             int ret = TransportWrite(wPkt);
856             if (ret < 0) {
857                 LOG_DEBUG("[sock]%{public}s stack write failed", __func__);
858                 (void)memcpy_s(this->sendDataBuf_, wbytes, buffer, wbytes);
859                 this->sendBufLen_ = wbytes;
860                 this->isCanWrite_ = false;
861                 PacketFree(wPkt);
862                 return;
863             }
864             totalSize -= wbytes;
865             PacketFree(wPkt);
866         } else {
867             return;
868         }
869     }
870 }
871 
TransportWrite(Packet * subPkt)872 int Socket::TransportWrite(Packet *subPkt)
873 {
874     LOG_INFO("[sock]%{public}s", __func__);
875 
876     RawAddress rawAddr = RawAddress::ConvertToString(this->remoteAddr_.addr);
877     IPowerManager::GetInstance().StatusUpdate(RequestStatus::BUSY, PROFILE_NAME_SPP, rawAddr);
878 
879     int ret = 0;
880     if (this->isNewSocket_) {
881         if (this->newSockTransport_ == nullptr) {
882             LOG_DEBUG("[sock]%{public}s newSockTransport is nullptr", __func__);
883         } else {
884             ret = this->newSockTransport_->Write(subPkt);
885         }
886     } else {
887         ret = this->sockTransport_->Write(subPkt);
888     }
889     IPowerManager::GetInstance().StatusUpdate(RequestStatus::IDLE, PROFILE_NAME_SPP, rawAddr);
890     return ret;
891 }
892 
OnSocketException(Socket & sock)893 void Socket::OnSocketException(Socket &sock)
894 {
895     LOG_INFO("[sock]%{public}s", __func__);
896 
897     SocketService *socketService =
898         static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
899     if (socketService != nullptr) {
900         socketService->GetDispatcher()->PostTask(std::bind(&Socket::OnSocketExceptionNative, &sock, std::ref(sock)));
901     }
902 }
903 
OnSocketExceptionNative(Socket & sock)904 void Socket::OnSocketExceptionNative(Socket &sock)
905 {
906     LOG_INFO("[sock]%{public}s", __func__);
907 
908     std::lock_guard<std::recursive_mutex> lk(Socket::g_socketMutex);
909     std::vector<Socket *>::iterator it;
910 
911     LOG_INFO("[sock]%{public}s size:%{public}zu", __func__, g_allServerSockets.size());
912 
913     for (it = g_allServerSockets.begin(); it != g_allServerSockets.end(); ++it) {
914         if (*it == &sock) {
915             sock.CloseSocket(false);
916             break;
917         }
918     }
919 }
920 
AssignServiceId()921 GAP_Service Socket::AssignServiceId()
922 {
923     int serviceId = 0;
924     for (int i = 0; i < SOCK_MAX_SERVICE_ID; i++) {
925         if (g_arrayServiceId[i] == 0) {
926             g_arrayServiceId[i] = SPP_ID_START + i;
927             serviceId = g_arrayServiceId[i];
928             break;
929         }
930     }
931     return (GAP_Service)serviceId;
932 }
933 
FreeServiceId(GAP_Service serviceId)934 void Socket::FreeServiceId(GAP_Service serviceId)
935 {
936     if (serviceId >= SPP_ID_START) {
937         g_arrayServiceId[serviceId - SPP_ID_START] = 0;
938     }
939 }
940 
SendDataToApp(int fd,const uint8_t * buf,size_t len)941 SocketSendRet Socket::SendDataToApp(int fd, const uint8_t *buf, size_t len)
942 {
943     LOG_INFO("[sock]%{public}s", __func__);
944 
945 #ifdef DARWIN_PLATFORM
946     auto sendRet = send(fd, buf, len, MSG_DONTWAIT);
947 #else
948     auto sendRet = send(fd, buf, len, MSG_NOSIGNAL);
949 #endif
950     if (sendRet < 0) {
951         if ((errno == EAGAIN || errno == EWOULDBLOCK)) {
952             return SOCKET_SEND_NONE;
953         }
954         return SOCKET_SEND_ERROR;
955     }
956 
957     if (sendRet == 0) {
958         return SOCKET_SEND_ERROR;
959     }
960 
961     if (sendRet == ssize_t(len)) {
962         return SOCKET_SEND_ALL;
963     }
964 
965     return SOCKET_SEND_PARTIAL;
966 }
967 
NotifyServiceDeleteSocket(Socket & sock)968 void Socket::NotifyServiceDeleteSocket(Socket &sock)
969 {
970     LOG_INFO("[sock]%{public}s", __func__);
971 
972     utility::Message msg(SOCKET_CLOSE);
973     msg.arg2_ = &sock;
974     SocketService *socketService =
975         static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
976     if (socketService != nullptr) {
977         socketService->ProcessMessage(msg);
978     }
979 }
980 
EraseSocket(Socket & socket)981 void Socket::EraseSocket(Socket &socket)
982 {
983     LOG_INFO("[sock]%{public}s", __func__);
984 
985     std::vector<Socket *>::iterator it;
986     LOG_INFO("[sock]%{public}s size:%{public}zu", __func__, g_allServerSockets.size());
987 
988     for (it = g_allServerSockets.begin(); it != g_allServerSockets.end(); ++it) {
989         if (*it == &socket) {
990             g_allServerSockets.erase(it);
991             break;
992         }
993     }
994 }
995 
RemoveServerSocket()996 void Socket::RemoveServerSocket()
997 {
998     LOG_INFO("[sock]%{public}s", __func__);
999 
1000     CloseSocketFd();
1001     EraseSocket(*this);
1002     sockTransport_->RemoveServer(true);
1003 }
1004 
CloseSocketFd()1005 void Socket::CloseSocketFd()
1006 {
1007     LOG_INFO("[sock]%{public}s", __func__);
1008     if (this->transportFd_ != SOCK_INVALID_FD) {
1009         LOG_DEBUG("closefd : transportFd_:%{public}d", this->transportFd_);
1010         shutdown(this->transportFd_, SHUT_RDWR);
1011         close(this->transportFd_);
1012         std::lock_guard<std::mutex> lock(this->fdMutex_);
1013         this->transportFd_ = SOCK_INVALID_FD;
1014     }
1015 
1016     if (this->upperlayerFd_ != SOCK_INVALID_FD) {
1017         LOG_DEBUG("closefd : upperlayerFd_:%{public}d", this->upperlayerFd_);
1018         shutdown(this->upperlayerFd_, SHUT_RDWR);
1019         close(this->upperlayerFd_);
1020         this->upperlayerFd_ = SOCK_INVALID_FD;
1021     }
1022 }
1023 
GetDispatchter()1024 utility::Dispatcher *Socket::GetDispatchter()
1025 {
1026     LOG_INFO("[sock]%{public}s", __func__);
1027     SocketService *socketService =
1028         static_cast<SocketService *>(IProfileManager::GetInstance()->GetProfileService(PROFILE_NAME_SPP));
1029     if (socketService == nullptr) {
1030         return nullptr;
1031     }
1032     return socketService->GetDispatcher();
1033 }
1034 
ClearUpAllSocket()1035 void Socket::ClearUpAllSocket()
1036 {
1037     LOG_INFO("[sock]%{public}s", __func__);
1038 
1039     LOG_INFO("[sock]%{public}s size:%{public}zu", __func__, g_allServerSockets.size());
1040     if (g_allServerSockets.size() > 0) {
1041         g_allServerSockets.clear();
1042     }
1043 }
1044 }  // namespace bluetooth
1045 }  // namespace OHOS
1046