1 /*
2  * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "connection_state_item.h"
17 
18 #include "hilog_tag_wrapper.h"
19 
20 namespace OHOS {
21 namespace AAFwk {
22 /**
23  * @class ConnectedExtension
24  * ConnectedExtension,This class is used to record a connected extension.
25  */
26 class ConnectedExtension : public std::enable_shared_from_this<ConnectedExtension> {
27 public:
CreateConnectedExtension(std::shared_ptr<ConnectionRecord> record)28     static std::shared_ptr<ConnectedExtension> CreateConnectedExtension(std::shared_ptr<ConnectionRecord> record)
29     {
30         if (!record) {
31             return nullptr;
32         }
33 
34         auto targetExtension = record->GetAbilityRecord();
35         if (!targetExtension) {
36             return nullptr;
37         }
38 
39         return std::make_shared<ConnectedExtension>(targetExtension);
40     }
41 
ConnectedExtension()42     ConnectedExtension()
43     {
44         extensionType_ = AppExecFwk::ExtensionAbilityType::UNSPECIFIED;
45     }
46 
ConnectedExtension(std::shared_ptr<AbilityRecord> target)47     explicit ConnectedExtension(std::shared_ptr<AbilityRecord> target)
48     {
49         if (!target) {
50             return;
51         }
52         extensionPid_ = target->GetPid();
53         extensionUid_ = target->GetUid();
54         extensionBundleName_ = target->GetAbilityInfo().bundleName;
55         extensionModuleName_ = target->GetAbilityInfo().moduleName;
56         extensionName_ = target->GetAbilityInfo().name;
57         extensionType_ = target->GetAbilityInfo().extensionAbilityType;
58         if (target->GetAbilityInfo().type == AppExecFwk::AbilityType::SERVICE) {
59             extensionType_ = AppExecFwk::ExtensionAbilityType::SERVICE;
60         } else if (target->GetAbilityInfo().type == AppExecFwk::AbilityType::DATA) {
61             extensionType_ = AppExecFwk::ExtensionAbilityType::DATASHARE;
62         }
63     }
64 
65     virtual ~ConnectedExtension() = default;
66 
AddConnection(sptr<IRemoteObject> connection)67     bool AddConnection(sptr<IRemoteObject> connection)
68     {
69         if (!connection) {
70             return false;
71         }
72 
73         std::lock_guard guard(connectionsMutex_);
74         bool needNotify = connections_.empty();
75         connections_.emplace(connection);
76 
77         return needNotify;
78     }
79 
RemoveConnection(sptr<IRemoteObject> connection)80     bool RemoveConnection(sptr<IRemoteObject> connection)
81     {
82         if (!connection) {
83             return false;
84         }
85         std::lock_guard guard(connectionsMutex_);
86         connections_.erase(connection);
87         return connections_.empty();
88     }
89 
GenerateExtensionInfo(AbilityRuntime::ConnectionData & data)90     void GenerateExtensionInfo(AbilityRuntime::ConnectionData &data)
91     {
92         data.extensionPid = extensionPid_;
93         data.extensionUid = extensionUid_;
94         data.extensionBundleName = extensionBundleName_;
95         data.extensionModuleName = extensionModuleName_;
96         data.extensionName = extensionName_;
97         data.extensionType = extensionType_;
98     }
99 
100 private:
101     int32_t extensionPid_ = 0;
102     int32_t extensionUid_ = 0;
103     std::string extensionBundleName_;
104     std::string extensionModuleName_;
105     std::string extensionName_;
106     AppExecFwk::ExtensionAbilityType extensionType_;
107 
108     std::mutex connectionsMutex_;
109     std::set<sptr<IRemoteObject>> connections_; // remote object of IAbilityConnection
110 };
111 
112 /**
113  * @class ConnectedDataAbility
114  * ConnectedDataAbility,This class is used to record a connected data ability.
115  */
116 class ConnectedDataAbility : public std::enable_shared_from_this<ConnectedDataAbility> {
117 public:
CreateConnectedDataAbility(const std::shared_ptr<DataAbilityRecord> & record)118     static std::shared_ptr<ConnectedDataAbility> CreateConnectedDataAbility(
119         const std::shared_ptr<DataAbilityRecord> &record)
120     {
121         if (!record) {
122             return nullptr;
123         }
124 
125         auto targetAbility = record->GetAbilityRecord();
126         if (!targetAbility) {
127             return nullptr;
128         }
129 
130         return std::make_shared<ConnectedDataAbility>(targetAbility);
131     }
132 
ConnectedDataAbility()133     ConnectedDataAbility() {}
134 
ConnectedDataAbility(const std::shared_ptr<AbilityRecord> & target)135     explicit ConnectedDataAbility(const std::shared_ptr<AbilityRecord> &target)
136     {
137         if (!target) {
138             return;
139         }
140 
141         dataAbilityPid_ = target->GetPid();
142         dataAbilityUid_ = target->GetUid();
143         bundleName_ = target->GetAbilityInfo().bundleName;
144         moduleName_ = target->GetAbilityInfo().moduleName;
145         abilityName_ = target->GetAbilityInfo().name;
146     }
147 
148     virtual ~ConnectedDataAbility() = default;
149 
AddCaller(const DataAbilityCaller & caller)150     bool AddCaller(const DataAbilityCaller &caller)
151     {
152         if (!caller.isNotHap && !caller.callerToken) {
153             return false;
154         }
155 
156         bool needNotify = callers_.empty();
157         auto it = find_if(callers_.begin(), callers_.end(), [&caller](const std::shared_ptr<CallerInfo> &info) {
158             if (caller.isNotHap) {
159                 return info && info->IsNotHap() && info->GetCallerPid() == caller.callerPid;
160             } else {
161                 return info && info->GetCallerToken() == caller.callerToken;
162             }
163         });
164         if (it == callers_.end()) {
165             callers_.emplace_back(std::make_shared<CallerInfo>(caller.isNotHap, caller.callerPid, caller.callerToken));
166         }
167 
168         return needNotify;
169     }
170 
RemoveCaller(const DataAbilityCaller & caller)171     bool RemoveCaller(const DataAbilityCaller &caller)
172     {
173         if (!caller.isNotHap && !caller.callerToken) {
174             return false;
175         }
176 
177         auto it = find_if(callers_.begin(), callers_.end(), [&caller](const std::shared_ptr<CallerInfo> &info) {
178             if (caller.isNotHap) {
179                 return info && info->IsNotHap() && info->GetCallerPid() == caller.callerPid;
180             } else {
181                 return info && info->GetCallerToken() == caller.callerToken;
182             }
183         });
184         if (it != callers_.end()) {
185             callers_.erase(it);
186         }
187 
188         return callers_.empty();
189     }
190 
GenerateExtensionInfo(AbilityRuntime::ConnectionData & data)191     void GenerateExtensionInfo(AbilityRuntime::ConnectionData &data)
192     {
193         data.extensionPid = dataAbilityPid_;
194         data.extensionUid = dataAbilityUid_;
195         data.extensionBundleName = bundleName_;
196         data.extensionModuleName = moduleName_;
197         data.extensionName = abilityName_;
198         data.extensionType = AppExecFwk::ExtensionAbilityType::DATASHARE;
199     }
200 
201 private:
202     class CallerInfo : public std::enable_shared_from_this<CallerInfo> {
203     public:
CallerInfo(bool isNotHap,int32_t callerPid,const sptr<IRemoteObject> & callerToken)204         CallerInfo(bool isNotHap, int32_t callerPid, const sptr<IRemoteObject> &callerToken)
205             : isNotHap_(isNotHap), callerPid_(callerPid), callerToken_(callerToken) {}
206 
IsNotHap() const207         bool IsNotHap() const
208         {
209             return isNotHap_;
210         }
211 
GetCallerPid() const212         int32_t GetCallerPid() const
213         {
214             return callerPid_;
215         }
216 
GetCallerToken() const217         sptr<IRemoteObject> GetCallerToken() const
218         {
219             return callerToken_;
220         }
221 
222     private:
223         bool isNotHap_ = false;
224         int32_t callerPid_ = 0;
225         sptr<IRemoteObject> callerToken_ = nullptr;
226     };
227 
228     int32_t dataAbilityPid_ = 0;
229     int32_t dataAbilityUid_ = 0;
230     std::string bundleName_;
231     std::string moduleName_;
232     std::string abilityName_;
233     std::list<std::shared_ptr<CallerInfo>> callers_; // caller infos of this data ability.
234 };
235 
ConnectionStateItem(int32_t callerUid,int32_t callerPid,const std::string & callerName)236 ConnectionStateItem::ConnectionStateItem(int32_t callerUid, int32_t callerPid, const std::string &callerName)
237     : callerUid_(callerUid), callerPid_(callerPid), callerName_(callerName)
238 {
239 }
240 
~ConnectionStateItem()241 ConnectionStateItem::~ConnectionStateItem()
242 {}
243 
CreateConnectionStateItem(const std::shared_ptr<ConnectionRecord> & record)244 std::shared_ptr<ConnectionStateItem> ConnectionStateItem::CreateConnectionStateItem(
245     const std::shared_ptr<ConnectionRecord> &record)
246 {
247     if (!record) {
248         return nullptr;
249     }
250 
251     return std::make_shared<ConnectionStateItem>(record->GetCallerUid(),
252         record->GetCallerPid(), record->GetCallerName());
253 }
254 
CreateConnectionStateItem(const DataAbilityCaller & dataCaller)255 std::shared_ptr<ConnectionStateItem> ConnectionStateItem::CreateConnectionStateItem(
256     const DataAbilityCaller &dataCaller)
257 {
258     return std::make_shared<ConnectionStateItem>(dataCaller.callerUid,
259         dataCaller.callerPid, dataCaller.callerName);
260 }
261 
AddConnection(std::shared_ptr<ConnectionRecord> record,AbilityRuntime::ConnectionData & data)262 bool ConnectionStateItem::AddConnection(std::shared_ptr<ConnectionRecord> record,
263     AbilityRuntime::ConnectionData &data)
264 {
265     if (!record) {
266         TAG_LOGE(AAFwkTag::CONNECTION, "invalid connection record");
267         return false;
268     }
269 
270     auto token = record->GetTargetToken();
271     if (!token) {
272         TAG_LOGE(AAFwkTag::CONNECTION, "invalid token");
273         return false;
274     }
275 
276     sptr<IRemoteObject> connectionObj = record->GetConnection();
277     if (!connectionObj) {
278         TAG_LOGE(AAFwkTag::CONNECTION, "no connection callback");
279         return false;
280     }
281 
282     std::shared_ptr<ConnectedExtension> connectedExtension = nullptr;
283     auto it = connectionMap_.find(token);
284     if (it == connectionMap_.end()) {
285         connectedExtension = ConnectedExtension::CreateConnectedExtension(record);
286         if (connectedExtension) {
287             connectionMap_[token] = connectedExtension;
288         }
289     } else {
290         connectedExtension = it->second;
291     }
292 
293     if (!connectedExtension) {
294         TAG_LOGE(AAFwkTag::CONNECTION, "invalid connectedExtension");
295         return false;
296     }
297 
298     bool needNotify = connectedExtension->AddConnection(connectionObj);
299     if (needNotify) {
300         GenerateConnectionData(connectedExtension, data);
301     }
302 
303     return needNotify;
304 }
305 
RemoveConnection(std::shared_ptr<ConnectionRecord> record,AbilityRuntime::ConnectionData & data)306 bool ConnectionStateItem::RemoveConnection(std::shared_ptr<ConnectionRecord> record,
307     AbilityRuntime::ConnectionData &data)
308 {
309     if (!record) {
310         TAG_LOGE(AAFwkTag::CONNECTION, "invalid connection record");
311         return false;
312     }
313 
314     auto token = record->GetTargetToken();
315     if (!token) {
316         TAG_LOGE(AAFwkTag::CONNECTION, "invalid token");
317         return false;
318     }
319 
320     sptr<IRemoteObject> connectionObj = record->GetConnection();
321     if (!connectionObj) {
322         TAG_LOGE(AAFwkTag::CONNECTION, "no connection callback");
323         return false;
324     }
325 
326     auto it = connectionMap_.find(token);
327     if (it == connectionMap_.end()) {
328         TAG_LOGE(AAFwkTag::CONNECTION, "no such connectedExtension");
329         return false;
330     }
331 
332     auto connectedExtension = it->second;
333     if (!connectedExtension) {
334         TAG_LOGE(AAFwkTag::CONNECTION, "no such connectedExtension");
335         return false;
336     }
337 
338     bool needNotify = connectedExtension->RemoveConnection(connectionObj);
339     if (needNotify) {
340         connectionMap_.erase(it);
341         GenerateConnectionData(connectedExtension, data);
342     }
343 
344     return needNotify;
345 }
346 
AddDataAbilityConnection(const DataAbilityCaller & caller,const std::shared_ptr<DataAbilityRecord> & dataAbility,AbilityRuntime::ConnectionData & data)347 bool ConnectionStateItem::AddDataAbilityConnection(const DataAbilityCaller &caller,
348     const std::shared_ptr<DataAbilityRecord> &dataAbility, AbilityRuntime::ConnectionData &data)
349 {
350     if (!dataAbility) {
351         TAG_LOGE(AAFwkTag::CONNECTION, "invalid dataAbility");
352         return false;
353     }
354 
355     auto token = dataAbility->GetToken();
356     if (!token) {
357         TAG_LOGE(AAFwkTag::CONNECTION, "invalid dataAbility token");
358         return false;
359     }
360 
361     std::shared_ptr<ConnectedDataAbility> connectedAbility = nullptr;
362     auto it = dataAbilityMap_.find(token);
363     if (it == dataAbilityMap_.end()) {
364         connectedAbility = ConnectedDataAbility::CreateConnectedDataAbility(dataAbility);
365         if (connectedAbility) {
366             dataAbilityMap_[token] = connectedAbility;
367         }
368     } else {
369         connectedAbility = it->second;
370     }
371 
372     if (!connectedAbility) {
373         TAG_LOGE(AAFwkTag::CONNECTION, "invalid connectedAbility");
374         return false;
375     }
376 
377     bool needNotify = connectedAbility->AddCaller(caller);
378     if (needNotify) {
379         GenerateConnectionData(connectedAbility, data);
380     }
381 
382     return needNotify;
383 }
384 
RemoveDataAbilityConnection(const DataAbilityCaller & caller,const std::shared_ptr<DataAbilityRecord> & dataAbility,AbilityRuntime::ConnectionData & data)385 bool ConnectionStateItem::RemoveDataAbilityConnection(const DataAbilityCaller &caller,
386     const std::shared_ptr<DataAbilityRecord> &dataAbility, AbilityRuntime::ConnectionData &data)
387 {
388     if (!dataAbility) {
389         TAG_LOGE(AAFwkTag::CONNECTION, "invalid data ability record");
390         return false;
391     }
392 
393     auto token = dataAbility->GetToken();
394     if (!token) {
395         TAG_LOGE(AAFwkTag::CONNECTION, "invalid data ability token");
396         return false;
397     }
398 
399     auto it = dataAbilityMap_.find(token);
400     if (it == dataAbilityMap_.end()) {
401         TAG_LOGE(AAFwkTag::CONNECTION, "no such connected data ability");
402         return false;
403     }
404 
405     auto connectedDataAbility = it->second;
406     if (!connectedDataAbility) {
407         TAG_LOGE(AAFwkTag::CONNECTION, "no such connectedDataAbility");
408         return false;
409     }
410 
411     bool needNotify = connectedDataAbility->RemoveCaller(caller);
412     if (needNotify) {
413         dataAbilityMap_.erase(it);
414         GenerateConnectionData(connectedDataAbility, data);
415     }
416 
417     return needNotify;
418 }
419 
HandleDataAbilityDied(const sptr<IRemoteObject> & token,AbilityRuntime::ConnectionData & data)420 bool ConnectionStateItem::HandleDataAbilityDied(const sptr<IRemoteObject> &token,
421     AbilityRuntime::ConnectionData &data)
422 {
423     if (!token) {
424         return false;
425     }
426 
427     auto it = dataAbilityMap_.find(token);
428     if (it == dataAbilityMap_.end()) {
429         TAG_LOGE(AAFwkTag::CONNECTION, "no such data ability");
430         return false;
431     }
432 
433     auto connectedDataAbility = it->second;
434     if (!connectedDataAbility) {
435         TAG_LOGE(AAFwkTag::CONNECTION, "no connectedDataAbility");
436         return false;
437     }
438 
439     dataAbilityMap_.erase(it);
440     GenerateConnectionData(connectedDataAbility, data);
441     return true;
442 }
443 
IsEmpty() const444 bool ConnectionStateItem::IsEmpty() const
445 {
446     return connectionMap_.empty() && dataAbilityMap_.empty();
447 }
448 
GenerateAllConnectionData(std::vector<AbilityRuntime::ConnectionData> & datas)449 void ConnectionStateItem::GenerateAllConnectionData(std::vector<AbilityRuntime::ConnectionData> &datas)
450 {
451     AbilityRuntime::ConnectionData data;
452     for (auto it = connectionMap_.begin(); it != connectionMap_.end(); ++it) {
453         GenerateConnectionData(it->second, data);
454         datas.emplace_back(data);
455     }
456 
457     for (auto it = dataAbilityMap_.begin(); it != dataAbilityMap_.end(); ++it) {
458         GenerateConnectionData(it->second, data);
459         datas.emplace_back(data);
460     }
461 }
462 
GenerateConnectionData(const std::shared_ptr<ConnectedExtension> & connectedExtension,AbilityRuntime::ConnectionData & data)463 void ConnectionStateItem::GenerateConnectionData(
464     const std::shared_ptr<ConnectedExtension> &connectedExtension, AbilityRuntime::ConnectionData &data)
465 {
466     if (connectedExtension) {
467         connectedExtension->GenerateExtensionInfo(data);
468     }
469     data.callerUid = callerUid_;
470     data.callerPid = callerPid_;
471     data.callerName = callerName_;
472 }
473 
GenerateConnectionData(const std::shared_ptr<ConnectedDataAbility> & connectedDataAbility,AbilityRuntime::ConnectionData & data)474 void ConnectionStateItem::GenerateConnectionData(const std::shared_ptr<ConnectedDataAbility> &connectedDataAbility,
475     AbilityRuntime::ConnectionData &data)
476 {
477     if (connectedDataAbility) {
478         connectedDataAbility->GenerateExtensionInfo(data);
479     }
480     data.callerUid = callerUid_;
481     data.callerPid = callerPid_;
482     data.callerName = callerName_;
483 }
484 }  // namespace AAFwk
485 }  // namespace OHOS
486