1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "soft_bus_server_socket.h"
17 
18 #include "remote_connect_listener_manager.h"
19 
20 #define LOG_TAG "USER_AUTH_SA"
21 namespace OHOS {
22 namespace UserIam {
23 namespace UserAuth {
ServerSocket(const int32_t socketId)24 ServerSocket::ServerSocket(const int32_t socketId)
25     : BaseSocket(socketId)
26 {
27     IAM_LOGI("server socket id is %{public}d.", socketId);
28 }
29 
SendMessage(const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,MsgCallback & callback)30 ResultCode ServerSocket::SendMessage(const std::string &connectionName, const std::string &srcEndPoint,
31     const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes, MsgCallback &callback)
32 {
33     IAM_LOGI("start.");
34     int32_t socketId = GetSocketIdByClientConnectionName(connectionName);
35     if (socketId == INVALID_SOCKET_ID) {
36         IAM_LOGE("socket id is invalid");
37         return GENERAL_ERROR;
38     }
39 
40     return SendRequest(socketId, connectionName, srcEndPoint, destEndPoint, attributes, callback);
41 }
42 
OnBind(int32_t socketId,PeerSocketInfo info)43 void ServerSocket::OnBind(int32_t socketId, PeerSocketInfo info)
44 {
45     IAM_LOGI("start, socket id is %{public}d", socketId);
46     if (socketId <= INVALID_SOCKET_ID) {
47         IAM_LOGE("socket id invalid.");
48         return;
49     }
50 
51     std::string peerNetworkId(info.networkId);
52     AddServerSocket(socketId, peerNetworkId);
53 }
54 
OnShutdown(int32_t socketId,ShutdownReason reason)55 void ServerSocket::OnShutdown(int32_t socketId, ShutdownReason reason)
56 {
57     IAM_LOGI("start, socket id is %{public}d", socketId);
58     std::string connectionName = GetClientConnectionName(socketId);
59     if (!connectionName.empty()) {
60         RemoteConnectListenerManager::GetInstance().OnConnectionDown(connectionName);
61     }
62     DeleteServerSocket(socketId);
63     DeleteClientConnection(socketId);
64 }
65 
OnBytes(int32_t socketId,const void * data,uint32_t dataLen)66 void ServerSocket::OnBytes(int32_t socketId, const void *data, uint32_t dataLen)
67 {
68     IAM_LOGI("start, socket id is %{public}d", socketId);
69     std::string networkId = GetNetworkIdBySocketId(socketId);
70     if (networkId.empty()) {
71         IAM_LOGE("networkId id is null, socketId:%{public}d.", socketId);
72         return;
73     }
74 
75     std::shared_ptr<SoftBusMessage> softBusMessage = ParseMessage(networkId, const_cast<void *>(data), dataLen);
76     if (softBusMessage == nullptr) {
77         IAM_LOGE("serverSocket parse message fail.");
78         return;
79     }
80 
81     bool ack = softBusMessage->GetAckFlag();
82     std::string connectionName = softBusMessage->GetConnectionName();
83     if (ack == false && !connectionName.empty()) {
84         AddClientConnection(socketId, connectionName);
85     }
86 
87     ResultCode ret = ProcDataReceive(socketId, softBusMessage);
88     if (ret != SUCCESS) {
89         IAM_LOGE("HandleDataReceive fail, socketId:%{public}d.", socketId);
90         return;
91     }
92 }
93 
OnQos(int32_t socketId,QoSEvent eventId,const QosTV * qos,uint32_t qosCount)94 void ServerSocket::OnQos(int32_t socketId, QoSEvent eventId, const QosTV *qos, uint32_t qosCount)
95 {
96     IAM_LOGI("start, socket id is %{public}d", socketId);
97 }
98 
AddServerSocket(const int32_t socketId,const std::string & networkId)99 void ServerSocket::AddServerSocket(const int32_t socketId, const std::string &networkId)
100 {
101     IAM_LOGI("start, socketId %{public}d.", socketId);
102     IF_FALSE_LOGE_AND_RETURN(socketId != INVALID_SOCKET_ID);
103 
104     std::lock_guard<std::recursive_mutex> lock(socketMutex_);
105     auto iter = serverSocketBindMap_.find(socketId);
106     if (iter == serverSocketBindMap_.end()) {
107         serverSocketBindMap_.insert(std::pair<int32_t, std::string>(socketId, networkId));
108     } else {
109         iter->second = networkId;
110     }
111 }
112 
DeleteServerSocket(const int32_t socketId)113 void ServerSocket::DeleteServerSocket(const int32_t socketId)
114 {
115     IAM_LOGI("start, socketId %{public}d.", socketId);
116     IF_FALSE_LOGE_AND_RETURN(socketId != INVALID_SOCKET_ID);
117 
118     std::lock_guard<std::recursive_mutex> lock(socketMutex_);
119     auto iter = serverSocketBindMap_.find(socketId);
120     if (iter != serverSocketBindMap_.end()) {
121         serverSocketBindMap_.erase(iter);
122     }
123 }
124 
GetNetworkIdBySocketId(int32_t socketId)125 std::string ServerSocket::GetNetworkIdBySocketId(int32_t socketId)
126 {
127     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, "");
128 
129     std::lock_guard<std::recursive_mutex> lock(socketMutex_);
130     std::string networkId;
131     auto iter = serverSocketBindMap_.find(socketId);
132     if (iter != serverSocketBindMap_.end()) {
133         networkId = iter->second;
134     }
135     return networkId;
136 }
137 
AddClientConnection(const int32_t socketId,const std::string & connectionName)138 void ServerSocket::AddClientConnection(const int32_t socketId, const std::string &connectionName)
139 {
140     IAM_LOGI("add socketId %{public}d connectionName %{public}s.", socketId, connectionName.c_str());
141     IF_FALSE_LOGE_AND_RETURN(socketId != INVALID_SOCKET_ID);
142 
143     std::lock_guard<std::recursive_mutex> lock(connectionMutex_);
144     auto iter = clientConnectionMap_.find(socketId);
145     if (iter == clientConnectionMap_.end()) {
146         clientConnectionMap_.insert(std::pair<int32_t, std::string>(socketId, connectionName));
147     }
148 }
149 
DeleteClientConnection(const int32_t socketId)150 void ServerSocket::DeleteClientConnection(const int32_t socketId)
151 {
152     IAM_LOGI("start, socketId %{public}d.", socketId);
153     IF_FALSE_LOGE_AND_RETURN(socketId != INVALID_SOCKET_ID);
154 
155     std::lock_guard<std::recursive_mutex> lock(connectionMutex_);
156     auto iter = clientConnectionMap_.find(socketId);
157     if (iter != clientConnectionMap_.end()) {
158         std::string connectionName = iter->second;
159         IAM_LOGI("delete socketId %{public}d connectionName %{public}s.", socketId, connectionName.c_str());
160         clientConnectionMap_.erase(iter);
161     }
162 }
163 
GetClientConnectionName(const int32_t socketId)164 std::string ServerSocket::GetClientConnectionName(const int32_t socketId)
165 {
166     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, "");
167 
168     std::lock_guard<std::recursive_mutex> lock(connectionMutex_);
169     std::string ConnectionName;
170     auto iter = clientConnectionMap_.find(socketId);
171     if (iter != clientConnectionMap_.end()) {
172         ConnectionName = iter->second;
173     }
174     return ConnectionName;
175 }
176 
GetSocketIdByClientConnectionName(const std::string & connectionName)177 int32_t ServerSocket::GetSocketIdByClientConnectionName(const std::string &connectionName)
178 {
179     std::lock_guard<std::recursive_mutex> lock(connectionMutex_);
180     int32_t socketId = INVALID_SOCKET_ID;
181     for (const auto &iter : clientConnectionMap_) {
182         if (iter.second == connectionName) {
183             socketId = iter.first;
184             break;
185         }
186     }
187 
188     return socketId;
189 }
190 
GetConnectionName()191 std::string ServerSocket::GetConnectionName()
192 {
193     return "";
194 }
195 
GetNetworkId()196 std::string ServerSocket::GetNetworkId()
197 {
198     return "";
199 }
200 } // namespace UserAuth
201 } // namespace UserIam
202 } // namespace OHOS