1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "remote_auth_context.h"

#include <utility>

#include "iam_check.h"
#include "iam_logger.h"
#include "iam_para2str.h"
#include "iam_ptr.h"

#include "device_manager_util.h"
#include "relative_timer.h"
#include "remote_msg_util.h"
#include "resource_node_utils.h"
#include "thread_handler.h"
#include "thread_handler_manager.h"
29
30 #define LOG_TAG "USER_AUTH_SA"
31
32 namespace OHOS {
33 namespace UserIam {
34 namespace UserAuth {
namespace {
// Delay, in milliseconds, passed to RelativeTimer when a remote authentication starts (3 minutes).
constexpr uint32_t TIME_OUT_MS = 3 * 60 * 1000;
} // namespace
38 class RemoteAuthContextMessageCallback : public ConnectionListener, public NoCopyable {
39 public:
RemoteAuthContextMessageCallback(std::weak_ptr<BaseContext> callbackWeakBase,RemoteAuthContext * callback)40 RemoteAuthContextMessageCallback(std::weak_ptr<BaseContext> callbackWeakBase, RemoteAuthContext *callback)
41 : callbackWeakBase_(callbackWeakBase),
42 callback_(callback),
43 threadHandler_(ThreadHandler::GetSingleThreadInstance())
44 {
45 }
46
47 ~RemoteAuthContextMessageCallback() = default;
48
OnMessage(const std::string & connectionName,const std::string & srcEndPoint,const std::shared_ptr<Attributes> & request,std::shared_ptr<Attributes> & reply)49 void OnMessage(const std::string &connectionName, const std::string &srcEndPoint,
50 const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply) override
51 {
52 IF_FALSE_LOGE_AND_RETURN(request != nullptr);
53 IF_FALSE_LOGE_AND_RETURN(reply != nullptr);
54
55 IAM_LOGI("connectionName: %{public}s, srcEndPoint: %{public}s", connectionName.c_str(), srcEndPoint.c_str());
56 }
57
OnConnectStatus(const std::string & connectionName,ConnectStatus connectStatus)58 void OnConnectStatus(const std::string &connectionName, ConnectStatus connectStatus) override
59 {
60 IAM_LOGI("connectionName: %{public}s, connectStatus %{public}d", connectionName.c_str(), connectStatus);
61
62 IF_FALSE_LOGE_AND_RETURN(threadHandler_ != nullptr);
63 threadHandler_->PostTask(
64 [connectionName, connectStatus, callbackWeakBase = callbackWeakBase_, callback = callback_, this]() {
65 IAM_LOGI("OnConnectStatus process begin");
66 auto callbackSharedBase = callbackWeakBase.lock();
67 IF_FALSE_LOGE_AND_RETURN(callbackSharedBase != nullptr);
68
69 IF_FALSE_LOGE_AND_RETURN(callback != nullptr);
70 callback->OnConnectStatus(connectionName, connectStatus);
71 IAM_LOGI("OnConnectStatus process success");
72 });
73 IAM_LOGI("task posted");
74 }
75
76 private:
77 std::weak_ptr<BaseContext> callbackWeakBase_;
78 RemoteAuthContext *callback_ = nullptr;
79 std::shared_ptr<ThreadHandler> threadHandler_ = nullptr;
80 };
81
RemoteAuthContext(uint64_t contextId,std::shared_ptr<Authentication> auth,RemoteAuthContextParam & param,std::shared_ptr<ContextCallback> callback)82 RemoteAuthContext::RemoteAuthContext(uint64_t contextId, std::shared_ptr<Authentication> auth,
83 RemoteAuthContextParam ¶m, std::shared_ptr<ContextCallback> callback)
84 : SimpleAuthContext("RemoteAuthContext", contextId, auth, callback),
85 authType_(param.authType),
86 connectionName_(param.connectionName),
87 collectorNetworkId_(param.collectorNetworkId),
88 executorInfoMsg_(param.executorInfoMsg)
89 {
90 endPointName_ = REMOTE_AUTH_CONTEXT_ENDPOINT_NAME;
91 needSetupConnection_ = (executorInfoMsg_.size() == 0);
92 if (needSetupConnection_) {
93 ThreadHandlerManager::GetInstance().CreateThreadHandler(connectionName_);
94 }
95 }
96
~RemoteAuthContext()97 RemoteAuthContext::~RemoteAuthContext()
98 {
99 std::lock_guard<std::recursive_mutex> lock(mutex_);
100 if (cancelTimerId_.has_value()) {
101 RelativeTimer::GetInstance().Unregister(cancelTimerId_.value());
102 }
103 RemoteConnectionManager::GetInstance().UnregisterConnectionListener(connectionName_, endPointName_);
104 if (needSetupConnection_) {
105 RemoteConnectionManager::GetInstance().CloseConnection(connectionName_);
106 ThreadHandlerManager::GetInstance().DestroyThreadHandler(connectionName_);
107 }
108 IAM_LOGI("%{public}s destroy", GetDescription());
109 }
110
GetContextType() const111 ContextType RemoteAuthContext::GetContextType() const
112 {
113 return REMOTE_AUTH_CONTEXT;
114 }
115
SetExecutorInfoMsg(std::vector<uint8_t> msg)116 void RemoteAuthContext::SetExecutorInfoMsg(std::vector<uint8_t> msg)
117 {
118 std::lock_guard<std::recursive_mutex> lock(mutex_);
119
120 executorInfoMsg_ = msg;
121 IAM_LOGI("%{public}s executorInfoMsg_ size is %{public}zu", GetDescription(), executorInfoMsg_.size());
122 }
123
OnStart()124 bool RemoteAuthContext::OnStart()
125 {
126 std::lock_guard<std::recursive_mutex> lock(mutex_);
127 IAM_LOGI("%{public}s start", GetDescription());
128
129 cancelTimerId_ = RelativeTimer::GetInstance().Register(
130 [weakThis = weak_from_this(), this]() {
131 auto sharedThis = weakThis.lock();
132 IF_FALSE_LOGE_AND_RETURN(sharedThis != nullptr);
133 OnTimeOut();
134 },
135 TIME_OUT_MS);
136
137 if (needSetupConnection_) {
138 IAM_LOGI("%{public}s SetupConnection", GetDescription());
139 return SetupConnection();
140 }
141
142 IAM_LOGI("%{public}s StartAuth", GetDescription());
143 return StartAuth();
144 }
145
StartAuth()146 bool RemoteAuthContext::StartAuth()
147 {
148 std::lock_guard<std::recursive_mutex> lock(mutex_);
149 IAM_LOGI("%{public}s start remote auth", GetDescription());
150
151 IF_FALSE_LOGE_AND_RETURN_VAL(executorInfoMsg_.size() > 0, false);
152
153 std::vector<ExecutorInfo> executorInfos;
154 bool decodeRet = RemoteMsgUtil::DecodeQueryExecutorInfoReply(Attributes(executorInfoMsg_), executorInfos);
155 IF_FALSE_LOGE_AND_RETURN_VAL(decodeRet, false);
156 IF_FALSE_LOGE_AND_RETURN_VAL(executorInfos.size() > 0, false);
157
158 remoteExecutorProxy_ = Common::MakeShared<RemoteExecutorProxy>(connectionName_, executorInfos[0]);
159 IF_FALSE_LOGE_AND_RETURN_VAL(remoteExecutorProxy_ != nullptr, false);
160
161 ResultCode startExecutorRet = remoteExecutorProxy_->Start();
162 IF_FALSE_LOGE_AND_RETURN_VAL(startExecutorRet == SUCCESS, false);
163
164 std::string collectorUdid;
165 bool getCollectorUdidRet = DeviceManagerUtil::GetInstance().GetUdidByNetworkId(collectorNetworkId_, collectorUdid);
166 IF_FALSE_LOGE_AND_RETURN_VAL(getCollectorUdidRet, false);
167
168 IF_FALSE_LOGE_AND_RETURN_VAL(auth_ != nullptr, false);
169 auth_->SetCollectorUdid(collectorUdid);
170
171 bool startAuthRet = SimpleAuthContext::OnStart();
172 IF_FALSE_LOGE_AND_RETURN_VAL(startAuthRet, false);
173 IF_FALSE_LOGE_AND_RETURN_VAL(scheduleList_.size() == 1, false);
174 IF_FALSE_LOGE_AND_RETURN_VAL(scheduleList_[0] != nullptr, false);
175
176 IAM_LOGI("%{public}s start remote auth success, connectionName:%{public}s, scheduleId:%{public}s",
177 GetDescription(), connectionName_.c_str(), GET_MASKED_STRING(scheduleList_[0]->GetScheduleId()).c_str());
178 return true;
179 }
180
StartAuthDelayed()181 void RemoteAuthContext::StartAuthDelayed()
182 {
183 std::lock_guard<std::recursive_mutex> lock(mutex_);
184 IF_FALSE_LOGE_AND_RETURN(callback_ != nullptr);
185 IAM_LOGI("%{public}s start", GetDescription());
186
187 bool ret = StartAuth();
188 if (!ret) {
189 IAM_LOGE("%{public}s StartAuth failed, latest error %{public}d", GetDescription(), GetLatestError());
190 Attributes attr;
191 callback_->OnResult(GetLatestError(), attr);
192 return;
193 }
194 IAM_LOGI("%{public}s success", GetDescription());
195 }
196
SendQueryExecutorInfoMsg()197 bool RemoteAuthContext::SendQueryExecutorInfoMsg()
198 {
199 std::lock_guard<std::recursive_mutex> lock(mutex_);
200 IAM_LOGI("%{public}s start", GetDescription());
201
202 std::shared_ptr<Attributes> request = Common::MakeShared<Attributes>();
203 IF_FALSE_LOGE_AND_RETURN_VAL(request != nullptr, false);
204
205 bool setMsgTypeRet = request->SetInt32Value(Attributes::ATTR_MSG_TYPE, QUERY_EXECUTOR_INFO);
206 IF_FALSE_LOGE_AND_RETURN_VAL(setMsgTypeRet, false);
207
208 std::vector<int32_t> authTypes = { authType_ };
209 bool setAuthTypesRet = request->SetInt32ArrayValue(Attributes::ATTR_AUTH_TYPES, authTypes);
210 IF_FALSE_LOGE_AND_RETURN_VAL(setAuthTypesRet, false);
211
212 bool setExecutorRoleRet = request->SetInt32Value(Attributes::ATTR_EXECUTOR_ROLE, COLLECTOR);
213 IF_FALSE_LOGE_AND_RETURN_VAL(setExecutorRoleRet, false);
214
215 std::string localUdid;
216 bool getLocalUdidRet = DeviceManagerUtil::GetInstance().GetLocalDeviceUdid(localUdid);
217 IF_FALSE_LOGE_AND_RETURN_VAL(getLocalUdidRet, false);
218
219 MsgCallback msgCallback = [weakThis = weak_from_this(), this](const std::shared_ptr<Attributes> &reply) {
220 IF_FALSE_LOGE_AND_RETURN(reply != nullptr);
221
222 auto sharedThis = weakThis.lock();
223 IF_FALSE_LOGE_AND_RETURN(sharedThis != nullptr);
224
225 int32_t resultCode;
226 bool getResultCodeRet = reply->GetInt32Value(Attributes::ATTR_RESULT_CODE, resultCode);
227 IF_FALSE_LOGE_AND_RETURN(getResultCodeRet);
228
229 if (resultCode != SUCCESS) {
230 IAM_LOGE("%{public}s query executor info failed", GetDescription());
231 Attributes attr;
232 callback_->OnResult(GENERAL_ERROR, attr);
233 return;
234 }
235
236 SetExecutorInfoMsg(reply->Serialize());
237
238 auto handler = ThreadHandler::GetSingleThreadInstance();
239 IF_FALSE_LOGE_AND_RETURN(handler != nullptr);
240 handler->PostTask([weakThis = weak_from_this(), this]() {
241 auto sharedThis = weakThis.lock();
242 IF_FALSE_LOGE_AND_RETURN(sharedThis != nullptr);
243 StartAuthDelayed();
244 });
245 IAM_LOGI("%{public}s query executor info success", GetDescription());
246 };
247
248 ResultCode sendMsgRet = RemoteConnectionManager::GetInstance().SendMessage(connectionName_, endPointName_,
249 REMOTE_SERVICE_ENDPOINT_NAME, request, msgCallback);
250 IF_FALSE_LOGE_AND_RETURN_VAL(sendMsgRet == SUCCESS, false);
251
252 IAM_LOGI("%{public}s success", GetDescription());
253 return true;
254 }
255
SetupConnection()256 bool RemoteAuthContext::SetupConnection()
257 {
258 std::lock_guard<std::recursive_mutex> lock(mutex_);
259 IAM_LOGI("%{public}s start", GetDescription());
260
261 std::shared_ptr<RemoteAuthContextMessageCallback> callback =
262 Common::MakeShared<RemoteAuthContextMessageCallback>(shared_from_this(), this);
263 IF_FALSE_LOGE_AND_RETURN_VAL(callback != nullptr, false);
264
265 ResultCode registerResult =
266 RemoteConnectionManager::GetInstance().RegisterConnectionListener(connectionName_, endPointName_, callback);
267 IF_FALSE_LOGE_AND_RETURN_VAL(registerResult == SUCCESS, false);
268
269 ResultCode connectResult =
270 RemoteConnectionManager::GetInstance().OpenConnection(connectionName_, collectorNetworkId_, GetTokenId());
271 IF_FALSE_LOGE_AND_RETURN_VAL(connectResult == SUCCESS, false);
272
273 IAM_LOGI("%{public}s success", GetDescription());
274 return true;
275 }
276
OnConnectStatus(const std::string & connectionName,ConnectStatus connectStatus)277 void RemoteAuthContext::OnConnectStatus(const std::string &connectionName, ConnectStatus connectStatus)
278 {
279 std::lock_guard<std::recursive_mutex> lock(mutex_);
280
281 IF_FALSE_LOGE_AND_RETURN(connectionName_ == connectionName);
282 IF_FALSE_LOGE_AND_RETURN(callback_ != nullptr);
283
284 Attributes attr;
285 if (connectStatus == ConnectStatus::DISCONNECTED) {
286 IAM_LOGI("%{public}s connection is disconnected", GetDescription());
287 callback_->OnResult(ResultCode::GENERAL_ERROR, attr);
288 return;
289 } else {
290 IAM_LOGI("%{public}s connection is connected", GetDescription());
291 bool sendMsgRet = SendQueryExecutorInfoMsg();
292 if (!sendMsgRet) {
293 IAM_LOGE("%{public}s SendQueryExecutorInfoMsg failed", GetDescription());
294 callback_->OnResult(GENERAL_ERROR, attr);
295 return;
296 }
297 IAM_LOGI("%{public}s connection is connected processed", GetDescription());
298 }
299 }
300
OnTimeOut()301 void RemoteAuthContext::OnTimeOut()
302 {
303 IAM_LOGI("%{public}s timeout", GetDescription());
304 IF_FALSE_LOGE_AND_RETURN(callback_ != nullptr);
305
306 Attributes attr;
307 callback_->OnResult(TIMEOUT, attr);
308 }
309 } // namespace UserAuth
310 } // namespace UserIam
311 } // namespace OHOS
312