1 /*
2  * Copyright (C) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "databus_socket_listener.h"

#include <charconv>
#include <system_error>

#include "dbinder_databus_invoker.h"
#include "ipc_debug.h"
#include "ipc_process_skeleton.h"
#include "ipc_thread_skeleton.h"
#include "log_tags.h"
#include "softbus_error_code.h"
24 
25 namespace OHOS {
26 static constexpr OHOS::HiviewDFX::HiLogLabel LABEL = { LOG_CORE, LOG_ID_RPC_REMOTE_LISTENER, "DatabusSocketListener" };
27 
DBinderSocketInfo(const std::string & ownName,const std::string & peerName,const std::string & networkId)28 DBinderSocketInfo::DBinderSocketInfo(const std::string &ownName, const std::string &peerName,
29     const std::string &networkId) : ownName_(ownName), peerName_(peerName), networkId_(networkId)
30 {}
31 
GetOwnName() const32 std::string DBinderSocketInfo::GetOwnName() const
33 {
34     return ownName_;
35 }
36 
GetPeerName() const37 std::string DBinderSocketInfo::GetPeerName() const
38 {
39     return peerName_;
40 }
41 
GetNetworkId() const42 std::string DBinderSocketInfo::GetNetworkId() const
43 {
44     return networkId_;
45 }
46 
DatabusSocketListener()47 DatabusSocketListener::DatabusSocketListener()
48 {
49     serverListener_.OnBind = DatabusSocketListener::ServerOnBind;
50     serverListener_.OnShutdown = DatabusSocketListener::ServerOnShutdown;
51     serverListener_.OnBytes = DatabusSocketListener::OnBytesReceived;
52     serverListener_.OnMessage = DatabusSocketListener::OnBytesReceived;
53 
54     clientListener_.OnBind = DatabusSocketListener::ClientOnBind;
55     clientListener_.OnShutdown = DatabusSocketListener::ClientOnShutdown;
56     clientListener_.OnBytes = DatabusSocketListener::OnBytesReceived;
57     clientListener_.OnMessage = DatabusSocketListener::OnBytesReceived;
58 }
59 
~DatabusSocketListener()60 DatabusSocketListener::~DatabusSocketListener() {}
61 
ServerOnBind(int32_t socket,PeerSocketInfo info)62 void DatabusSocketListener::ServerOnBind(int32_t socket, PeerSocketInfo info)
63 {
64     ZLOGI(LABEL, "socketId:%{public}d, deviceId:%{public}s, peerName:%{public}s",
65         socket, IPCProcessSkeleton::ConvertToSecureString(info.networkId).c_str(), info.name);
66 
67     std::string networkId = info.networkId;
68     std::string peerName = info.name;
69     std::string str = peerName.substr(DBINDER_SOCKET_NAME_PREFIX.length());
70     std::string::size_type pos = str.find("_");
71     std::string peerUid = str.substr(0, pos);
72     std::string peerPid = str.substr(pos + 1);
73 
74     DBinderDatabusInvoker *invoker =
75         reinterpret_cast<DBinderDatabusInvoker *>(IPCThreadSkeleton::GetRemoteInvoker(IRemoteObject::IF_PROT_DATABUS));
76     if (invoker == nullptr) {
77         ZLOGE(LABEL, "fail to get invoker");
78         return;
79     }
80 
81     invoker->OnReceiveNewConnection(socket, std::stoi(peerPid), std::stoi(peerUid), peerName, networkId);
82 }
83 
ServerOnShutdown(int32_t socket,ShutdownReason reason)84 void DatabusSocketListener::ServerOnShutdown(int32_t socket, ShutdownReason reason)
85 {
86     ZLOGI(LABEL, "socketId:%{public}d, ShutdownReason:%{public}d", socket, reason);
87     DBinderDatabusInvoker *invoker =
88         reinterpret_cast<DBinderDatabusInvoker *>(IPCThreadSkeleton::GetRemoteInvoker(IRemoteObject::IF_PROT_DATABUS));
89     if (invoker == nullptr) {
90         ZLOGE(LABEL, "fail to get invoker");
91         return;
92     }
93     invoker->OnDatabusSessionServerSideClosed(socket);
94 }
95 
ClientOnBind(int32_t socket,PeerSocketInfo info)96 void DatabusSocketListener::ClientOnBind(int32_t socket, PeerSocketInfo info)
97 {
98     return;
99 }
100 
ClientOnShutdown(int32_t socket,ShutdownReason reason)101 void DatabusSocketListener::ClientOnShutdown(int32_t socket, ShutdownReason reason)
102 {
103     ZLOGI(LABEL, "socketId:%{public}d, ShutdownReason:%{public}d", socket, reason);
104     DBinderDatabusInvoker *invoker =
105         reinterpret_cast<DBinderDatabusInvoker *>(IPCThreadSkeleton::GetRemoteInvoker(IRemoteObject::IF_PROT_DATABUS));
106     if (invoker == nullptr) {
107         ZLOGE(LABEL, "fail to get invoker");
108         return;
109     }
110 
111     DBinderSocketInfo socketInfo;
112     {
113         std::lock_guard<std::mutex> lockGuard(socketInfoMutex_);
114         for (auto it = socketInfoMap_.begin(); it != socketInfoMap_.end(); it++) {
115             if (it->second == socket) {
116                 socketInfo = it->first;
117                 ZLOGI(LOG_LABEL, "erase socketId:%{public}d ", it->second);
118                 socketInfoMap_.erase(it);
119                 break;
120             }
121         }
122     }
123     EraseDeviceLock(socketInfo);
124     invoker->OnDatabusSessionClientSideClosed(socket);
125 }
126 
OnBytesReceived(int32_t socket,const void * data,uint32_t dataLen)127 void DatabusSocketListener::OnBytesReceived(int32_t socket, const void *data, uint32_t dataLen)
128 {
129     ZLOGI(LABEL, "socketId:%{public}d len:%{public}u", socket, dataLen);
130     DBinderDatabusInvoker *invoker =
131         reinterpret_cast<DBinderDatabusInvoker *>(IPCThreadSkeleton::GetRemoteInvoker(IRemoteObject::IF_PROT_DATABUS));
132     if (invoker == nullptr) {
133         ZLOGE(LABEL, "fail to get invoker");
134         return;
135     }
136 
137     invoker->OnMessageAvailable(socket, static_cast<const char*>(data), dataLen);
138 }
139 
StartServerListener(const std::string & ownName)140 int32_t DatabusSocketListener::StartServerListener(const std::string &ownName)
141 {
142     std::string pkgName = DBINDER_PKG_NAME + "_" + std::to_string(getpid());
143 
144     SocketInfo serverSocketInfo = {
145         .name = const_cast<char*>(ownName.c_str()),
146         .pkgName = const_cast<char*>(pkgName.c_str()),
147         .dataType = TransDataType::DATA_TYPE_BYTES,
148     };
149     int32_t socketId = DBinderSoftbusClient::GetInstance().Socket(serverSocketInfo);
150     if (socketId <= 0) {
151         ZLOGE(LABEL, "create socket server error, socket is invalid");
152         return SOCKET_ID_INVALID;
153     }
154     int32_t ret = DBinderSoftbusClient::GetInstance().Listen(socketId, QOS_TV, QOS_COUNT, &serverListener_);
155     if (ret != SOFTBUS_OK && ret != SOFTBUS_TRANS_SOCKET_IN_USE) {
156         ZLOGE(LABEL, "Listen failed, ret:%{public}d", ret);
157         DBinderSoftbusClient::GetInstance().Shutdown(socketId);
158         return SOCKET_ID_INVALID;
159     }
160     ZLOGI(LABEL, "Listen ok, socketId:%{public}d, ownName:%{public}s", socketId, ownName.c_str());
161     return socketId;
162 }
163 
QueryOrNewInfoMutex(DBinderSocketInfo socketInfo)164 std::shared_ptr<std::mutex> DatabusSocketListener::QueryOrNewInfoMutex(DBinderSocketInfo socketInfo)
165 {
166     std::lock_guard<std::mutex> lockGuard(deviceMutex_);
167     auto it = infoMutexMap_.find(socketInfo);
168     if (it != infoMutexMap_.end()) {
169         return it->second;
170     }
171     std::shared_ptr<std::mutex> infoMutex = std::make_shared<std::mutex>();
172     if (infoMutex == nullptr) {
173         ZLOGE(LOG_LABEL, "failed to create mutex, ownName:%{public}s, peerName:%{public}s, networkId:%{public}s",
174             socketInfo.GetOwnName().c_str(), socketInfo.GetPeerName().c_str(),
175             IPCProcessSkeleton::ConvertToSecureString(socketInfo.GetNetworkId()).c_str());
176         return nullptr;
177     }
178     infoMutexMap_[socketInfo] = infoMutex;
179     return infoMutex;
180 }
181 
CreateClientSocket(const std::string & ownName,const std::string & peerName,const std::string & networkId)182 int32_t DatabusSocketListener::CreateClientSocket(const std::string &ownName, const std::string &peerName,
183     const std::string &networkId)
184 {
185     DBinderSocketInfo info(ownName, peerName, networkId);
186     std::shared_ptr<std::mutex> infoMutex = QueryOrNewInfoMutex(info);
187     if (infoMutex == nullptr) {
188         return SOCKET_ID_INVALID;
189     }
190     std::lock_guard<std::mutex> lockUnique(*infoMutex);
191 
192     {
193         std::lock_guard<std::mutex> lockGuard(socketInfoMutex_);
194         auto it = socketInfoMap_.find(info);
195         if (it != socketInfoMap_.end()) {
196             return it->second;
197         }
198     }
199 
200     std::string pkgName = std::string(DBINDER_PKG_NAME) + "_" + std::to_string(getpid());
201     SocketInfo socketInfo = {
202         .name =  const_cast<char*>(ownName.c_str()),
203         .peerName = const_cast<char*>(peerName.c_str()),
204         .peerNetworkId = const_cast<char*>(networkId.c_str()),
205         .pkgName = const_cast<char*>(pkgName.c_str()),
206         .dataType = TransDataType::DATA_TYPE_BYTES,
207     };
208     int32_t socketId = DBinderSoftbusClient::GetInstance().Socket(socketInfo);
209     if (socketId <= 0) {
210         ZLOGE(LABEL, "create socket error, socket is invalid");
211         return SOCKET_ID_INVALID;
212     }
213     int32_t ret = DBinderSoftbusClient::GetInstance().Bind(socketId, QOS_TV, QOS_COUNT, &clientListener_);
214     if (ret != SOFTBUS_OK && ret != SOFTBUS_TRANS_SOCKET_IN_USE) {
215         ZLOGE(LABEL, "Bind failed, ret:%{public}d, socketId:%{public}d,"
216             "ownName:%{public}s, peerName:%{public}s, peerNetworkId:%{public}s",
217             ret, socketId, ownName.c_str(), peerName.c_str(),
218             IPCProcessSkeleton::ConvertToSecureString(networkId).c_str());
219         DBinderSoftbusClient::GetInstance().Shutdown(socketId);
220         EraseDeviceLock(info);
221         return SOCKET_ID_INVALID;
222     }
223     ZLOGI(LABEL, "Bind succ, ownName:%{public}s peer:%{public}s deviceId:%{public}s "
224         "socketId:%{public}d", ownName.c_str(), peerName.c_str(),
225         IPCProcessSkeleton::ConvertToSecureString(networkId).c_str(), socketId);
226     {
227         std::lock_guard<std::mutex> lockGuard(socketInfoMutex_);
228         socketInfoMap_[info] = socketId;
229     }
230     return socketId;
231 }
232 
ShutdownSocket(int32_t socketId)233 void DatabusSocketListener::ShutdownSocket(int32_t socketId)
234 {
235     DBinderSocketInfo socketInfo;
236     {
237         std::lock_guard<std::mutex> lockGuard(socketInfoMutex_);
238         for (auto it = socketInfoMap_.begin(); it != socketInfoMap_.end(); it++) {
239             if (it->second == socketId) {
240                 ZLOGI(LOG_LABEL, "Shutdown socketId:%{public}d ", it->second);
241                 DBinderSoftbusClient::GetInstance().Shutdown(it->second);
242                 socketInfo = it->first;
243                 it = socketInfoMap_.erase(it);
244                 break;
245             }
246         }
247     }
248     EraseDeviceLock(socketInfo);
249 }
250 
EraseDeviceLock(DBinderSocketInfo info)251 void DatabusSocketListener::EraseDeviceLock(DBinderSocketInfo info)
252 {
253     std::lock_guard<std::mutex> lockGuard(deviceMutex_);
254     auto it = infoMutexMap_.find(info);
255     if (it != infoMutexMap_.end()) {
256         infoMutexMap_.erase(it);
257     }
258 }
259 
RemoveSessionName(void)260 void DatabusSocketListener::RemoveSessionName(void)
261 {
262     IPCProcessSkeleton *current = IPCProcessSkeleton::GetCurrent();
263     if (current == nullptr) {
264         ZLOGE(LABEL, "get current is null");
265         return;
266     }
267     sptr<IRemoteObject> object = current->GetSAMgrObject();
268     if (object == nullptr) {
269         ZLOGE(LABEL, "get object is null");
270         return;
271     }
272 
273     IPCObjectProxy *samgr = reinterpret_cast<IPCObjectProxy *>(object.GetRefPtr());
274     const std::string sessionName = current->GetDatabusName();
275     samgr->RemoveSessionName(sessionName);
276     ZLOGI(LABEL, "%{public}s", sessionName.c_str());
277 }
278 } // namespace OHOS
279