1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <securec.h>
#include <cstring>
#include <string>

#include "hilog/log.h"
#include "netmgr_ext_log_wrapper.h"
#include "netfirewall_db_helper.h"
22
23 using namespace OHOS::NativeRdb;
namespace {
// Column names used by this helper that are not provided by the shared table definitions.
const std::string DATABASE_ID("id");
const std::string RULE_ID("ruleId");
const std::string DOMAIN_NUM("domainNum");
const std::string FUZZY_NUM("fuzzyDomainNum");
// Fragments used to assemble "SELECT SUM(<column>) FROM <table>" statements.
const std::string SQL_SUM("SELECT SUM(");
const std::string SQL_FROM(") FROM ");
} // namespace
32
33 namespace OHOS {
34 namespace NetManagerStandard {
NetFirewallDbHelper()35 NetFirewallDbHelper::NetFirewallDbHelper()
36 {
37 firewallDatabase_ = NetFirewallDataBase::GetInstance();
38 }
39
~NetFirewallDbHelper()40 NetFirewallDbHelper::~NetFirewallDbHelper()
41 {
42 firewallDatabase_ = nullptr;
43 }
44
GetInstance()45 NetFirewallDbHelper &NetFirewallDbHelper::GetInstance()
46 {
47 static NetFirewallDbHelper instance;
48 return instance;
49 }
50
DomainListToBlob(const std::vector<NetFirewallDomainParam> & vec,std::vector<uint8_t> & blob,uint32_t & fuzzyNum)51 bool NetFirewallDbHelper::DomainListToBlob(const std::vector<NetFirewallDomainParam> &vec, std::vector<uint8_t> &blob,
52 uint32_t &fuzzyNum)
53 {
54 blob.clear();
55 for (const auto ¶m : vec) {
56 if (param.isWildcard) {
57 fuzzyNum++;
58 }
59 // 1 put isWildcard
60 blob.emplace_back(param.isWildcard ? 1 : 0);
61 // 2 for those with a string type, calculate the string size
62 uint16_t size = (uint16_t)(param.domain.length());
63 uint8_t *sizePtr = (uint8_t *)&size;
64 blob.emplace_back(sizePtr[0]);
65 blob.emplace_back(sizePtr[1]);
66 // 3 store string
67 std::vector<uint8_t> domain(param.domain.begin(), param.domain.end());
68 blob.insert(blob.end(), domain.begin(), domain.end());
69 }
70 return blob.size() > 0;
71 }
72
BlobToDomainList(const std::vector<uint8_t> & blob,std::vector<NetFirewallDomainParam> & vec)73 bool NetFirewallDbHelper::BlobToDomainList(const std::vector<uint8_t> &blob, std::vector<NetFirewallDomainParam> &vec)
74 {
75 vec.clear();
76 size_t blobSize = blob.size();
77 if (blobSize < 1) {
78 return false;
79 }
80
81 size_t i = 0;
82 size_t lenSize = sizeof(uint16_t);
83 while (i < blobSize) {
84 NetFirewallDomainParam param;
85 // 1 get isWildcard
86 param.isWildcard = blob[i] ? true : false;
87 // 2 get size
88 i++;
89 if (i >= blobSize || (blobSize - i) < lenSize) {
90 return true;
91 }
92 const uint8_t *sizePtr = &blob[i];
93 uint16_t size = *((uint16_t *)sizePtr);
94 int index = i + lenSize;
95 if (index >= blobSize || (blobSize - index) < size) {
96 return true;
97 }
98 // 3 get string
99 auto it = blob.begin() + index;
100 param.domain = std::string(it, it + size);
101 vec.emplace_back(param);
102 i += size + lenSize;
103 }
104
105 return vec.size() > 0;
106 }
107
ListToBlob(const std::vector<T> & vec,std::vector<uint8_t> & blob)108 template <typename T> void NetFirewallDbHelper::ListToBlob(const std::vector<T> &vec, std::vector<uint8_t> &blob)
109 {
110 blob.clear();
111 size_t size = sizeof(T);
112 for (const auto ¶m : vec) {
113 const uint8_t *data = reinterpret_cast<const uint8_t *>(¶m);
114 std::vector<uint8_t> item(data, data + size);
115 // 1 store each object
116 blob.insert(blob.end(), item.begin(), item.end());
117 }
118 }
119
BlobToList(const std::vector<uint8_t> & blob,std::vector<T> & vec)120 template <typename T> void NetFirewallDbHelper::BlobToList(const std::vector<uint8_t> &blob, std::vector<T> &vec)
121 {
122 vec.clear();
123 size_t blobSize = blob.size();
124 if (blobSize < 1) {
125 return;
126 }
127
128 size_t i = 0;
129 size_t size = sizeof(T);
130 while (i < blobSize) {
131 if ((blobSize - i) < size) {
132 return;
133 }
134 T value;
135 memset_s(&value, size, 0, size);
136 memcpy_s(&value, size, &blob[i], size);
137 vec.emplace_back(value);
138 i += size;
139 }
140 }
141
FillValuesOfFirewallRule(ValuesBucket & values,const NetFirewallRule & rule)142 int32_t NetFirewallDbHelper::FillValuesOfFirewallRule(ValuesBucket &values, const NetFirewallRule &rule)
143 {
144 values.Clear();
145
146 values.PutInt(NET_FIREWALL_USER_ID, rule.userId);
147 values.PutString(NET_FIREWALL_RULE_NAME, rule.ruleName);
148 values.PutString(NET_FIREWALL_RULE_DESC, rule.ruleDescription);
149 values.PutInt(NET_FIREWALL_RULE_DIR, static_cast<int32_t>(rule.ruleDirection));
150 values.PutInt(NET_FIREWALL_RULE_ACTION, static_cast<int32_t>(rule.ruleAction));
151 values.PutInt(NET_FIREWALL_RULE_TYPE, static_cast<int32_t>(rule.ruleType));
152 values.PutInt(NET_FIREWALL_IS_ENABLED, rule.isEnabled);
153 values.PutInt(NET_FIREWALL_APP_ID, rule.appUid);
154 std::vector<uint8_t> blob;
155 std::vector<DataBaseIp> dbIPs;
156 std::vector<DataBasePort> dbPorts;
157 switch (rule.ruleType) {
158 case NetFirewallRuleType::RULE_IP: {
159 values.PutInt(NET_FIREWALL_PROTOCOL, static_cast<int32_t>(rule.protocol));
160 FirewallIpToDbIp(rule.localIps, dbIPs);
161 ListToBlob(dbIPs, blob);
162 values.PutBlob(NET_FIREWALL_LOCAL_IP, blob);
163
164 FirewallIpToDbIp(rule.remoteIps, dbIPs);
165 ListToBlob(dbIPs, blob);
166 values.PutBlob(NET_FIREWALL_REMOTE_IP, blob);
167
168 FirewallPortToDbPort(rule.localPorts, dbPorts);
169 ListToBlob(dbPorts, blob);
170 values.PutBlob(NET_FIREWALL_LOCAL_PORT, blob);
171
172 FirewallPortToDbPort(rule.remotePorts, dbPorts);
173 ListToBlob(dbPorts, blob);
174 values.PutBlob(NET_FIREWALL_REMOTE_PORT, blob);
175 break;
176 }
177 case NetFirewallRuleType::RULE_DNS: {
178 values.PutString(NET_FIREWALL_DNS_PRIMARY, rule.dns.primaryDns);
179 values.PutString(NET_FIREWALL_DNS_STANDY, rule.dns.standbyDns);
180 break;
181 }
182 case NetFirewallRuleType::RULE_DOMAIN: {
183 values.PutInt(DOMAIN_NUM, rule.domains.size());
184 uint32_t fuzzyNum = 0;
185 DomainListToBlob(rule.domains, blob, fuzzyNum);
186 values.PutInt(FUZZY_NUM, fuzzyNum);
187 values.PutBlob(NET_FIREWALL_RULE_DOMAIN, blob);
188 break;
189 }
190 default:
191 break;
192 }
193 return FIREWALL_OK;
194 }
195
196
AddFirewallRule(NativeRdb::ValuesBucket & values,const NetFirewallRule & rule)197 int32_t NetFirewallDbHelper::AddFirewallRule(NativeRdb::ValuesBucket &values, const NetFirewallRule &rule)
198 {
199 FillValuesOfFirewallRule(values, rule);
200 return firewallDatabase_->Insert(values, FIREWALL_TABLE_NAME);
201 }
202
AddFirewallRuleRecord(const NetFirewallRule & rule)203 int32_t NetFirewallDbHelper::AddFirewallRuleRecord(const NetFirewallRule &rule)
204 {
205 std::lock_guard<std::mutex> guard(databaseMutex_);
206 ValuesBucket values;
207 int32_t ret = AddFirewallRule(values, rule);
208 if (ret < FIREWALL_OK) {
209 NETMGR_EXT_LOG_E("AddFirewallRule Insert error: %{public}d", ret);
210 (void)firewallDatabase_->RollBack();
211 }
212 return ret;
213 }
214
CheckIfNeedUpdateEx(const std::string & tableName,bool & isUpdate,int32_t ruleId,NetFirewallRule & oldRule)215 int32_t NetFirewallDbHelper::CheckIfNeedUpdateEx(const std::string &tableName, bool &isUpdate, int32_t ruleId,
216 NetFirewallRule &oldRule)
217 {
218 std::vector<std::string> columns;
219 RdbPredicates rdbPredicates(tableName);
220 rdbPredicates.BeginWrap()->EqualTo(RULE_ID, std::to_string(ruleId))->EndWrap();
221 auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
222 if (resultSet == nullptr) {
223 NETMGR_EXT_LOG_E("Query error");
224 return FIREWALL_RDB_EXECUTE_FAILTURE;
225 }
226 int32_t rowCount = 0;
227 if (resultSet->GetRowCount(rowCount) != E_OK) {
228 NETMGR_EXT_LOG_E("GetRowCount error");
229 return FIREWALL_RDB_EXECUTE_FAILTURE;
230 }
231 std::vector<NetFirewallRule> rules;
232 GetResultRightRecordEx(resultSet, rules);
233 isUpdate = rowCount > 0 && !rules.empty();
234 if (!rules.empty()) {
235 oldRule.ruleId = rules[0].ruleId;
236 oldRule.userId = rules[0].userId;
237 oldRule.ruleType = rules[0].ruleType;
238 oldRule.isEnabled = rules[0].isEnabled;
239 }
240 return FIREWALL_OK;
241 }
242
UpdateFirewallRuleRecord(const NetFirewallRule & rule)243 int32_t NetFirewallDbHelper::UpdateFirewallRuleRecord(const NetFirewallRule &rule)
244 {
245 std::lock_guard<std::mutex> guard(databaseMutex_);
246
247 ValuesBucket values;
248 FillValuesOfFirewallRule(values, rule);
249 int32_t changedRows = 0;
250 int32_t ret = firewallDatabase_->Update(FIREWALL_TABLE_NAME, changedRows, values, "ruleId = ?",
251 std::vector<std::string> { std::to_string(rule.ruleId) });
252 if (ret < FIREWALL_OK) {
253 NETMGR_EXT_LOG_E("Update error: %{public}d", ret);
254 (void)firewallDatabase_->RollBack();
255 }
256 return ret;
257 }
258
GetParamRuleInfoFormResultSet(std::string & columnName,int32_t index,NetFirewallRuleInfo & table)259 void NetFirewallDbHelper::GetParamRuleInfoFormResultSet(std::string &columnName, int32_t index,
260 NetFirewallRuleInfo &table)
261 {
262 if (columnName == NET_FIREWALL_PROTOCOL) {
263 table.protocolIndex = index;
264 return;
265 }
266 if (columnName == NET_FIREWALL_LOCAL_IP) {
267 table.localIpsIndex = index;
268 return;
269 }
270 if (columnName == NET_FIREWALL_REMOTE_IP) {
271 table.remoteIpsIndex = index;
272 return;
273 }
274 if (columnName == NET_FIREWALL_LOCAL_PORT) {
275 table.localPortsIndex = index;
276 return;
277 }
278 if (columnName == NET_FIREWALL_REMOTE_PORT) {
279 table.remotePortsIndex = index;
280 return;
281 }
282 if (columnName == NET_FIREWALL_RULE_DOMAIN) {
283 table.domainsIndex = index;
284 return;
285 }
286 if (columnName == NET_FIREWALL_DNS_PRIMARY) {
287 table.primaryDnsIndex = index;
288 return;
289 }
290 if (columnName == NET_FIREWALL_DNS_STANDY) {
291 table.standbyDnsIndex = index;
292 }
293 }
294
GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,NetFirewallRuleInfo & table)295 int32_t NetFirewallDbHelper::GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
296 NetFirewallRuleInfo &table)
297 {
298 std::vector<std::string> columnNames;
299 if (resultSet->GetRowCount(table.rowCount) != E_OK || resultSet->GetAllColumnNames(columnNames) != E_OK) {
300 NETMGR_EXT_LOG_E("get table info failed");
301 return FIREWALL_RDB_EXECUTE_FAILTURE;
302 }
303 int32_t columnNamesCount = static_cast<int32_t>(columnNames.size());
304 for (int32_t i = 0; i < columnNamesCount; i++) {
305 std::string &columnName = columnNames.at(i);
306 if (columnName == RULE_ID) {
307 table.ruleIdIndex = i;
308 continue;
309 }
310 if (columnName == NET_FIREWALL_USER_ID) {
311 table.userIdIndex = i;
312 continue;
313 }
314 if (columnName == NET_FIREWALL_RULE_NAME) {
315 table.ruleNameIndex = i;
316 continue;
317 }
318 if (columnName == NET_FIREWALL_RULE_DESC) {
319 table.ruleDescriptionIndex = i;
320 continue;
321 }
322 if (columnName == NET_FIREWALL_RULE_DIR) {
323 table.ruleDirectionIndex = i;
324 continue;
325 }
326 if (columnName == NET_FIREWALL_RULE_ACTION) {
327 table.ruleActionIndex = i;
328 continue;
329 }
330 if (columnName == NET_FIREWALL_RULE_TYPE) {
331 table.ruleTypeIndex = i;
332 continue;
333 }
334 if (columnName == NET_FIREWALL_IS_ENABLED) {
335 table.isEnabledIndex = i;
336 continue;
337 }
338 if (columnName == NET_FIREWALL_APP_ID) {
339 table.appUidIndex = i;
340 continue;
341 }
342 GetParamRuleInfoFormResultSet(columnName, i, table);
343 }
344 return FIREWALL_OK;
345 }
346
GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,NetInterceptRecordInfo & table)347 int32_t NetFirewallDbHelper::GetResultSetTableInfo(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
348 NetInterceptRecordInfo &table)
349 {
350 int32_t rowCount = 0;
351 std::vector<std::string> columnNames;
352 if (resultSet->GetRowCount(rowCount) != E_OK || resultSet->GetAllColumnNames(columnNames) != E_OK) {
353 NETMGR_EXT_LOG_E("get table info failed");
354 return FIREWALL_RDB_EXECUTE_FAILTURE;
355 }
356 int32_t columnNamesCount = static_cast<int32_t>(columnNames.size());
357 for (int32_t i = 0; i < columnNamesCount; i++) {
358 std::string &columnName = columnNames.at(i);
359 if (columnName == NET_FIREWALL_RECORD_TIME) {
360 table.timeIndex = i;
361 continue;
362 }
363 if (columnName == NET_FIREWALL_RECORD_LOCAL_IP) {
364 table.localIpIndex = i;
365 continue;
366 }
367 if (columnName == NET_FIREWALL_RECORD_REMOTE_IP) {
368 table.remoteIpIndex = i;
369 continue;
370 }
371 if (columnName == NET_FIREWALL_RECORD_LOCAL_PORT) {
372 table.localPortIndex = i;
373 continue;
374 }
375 if (columnName == NET_FIREWALL_RECORD_REMOTE_PORT) {
376 table.remotePortIndex = i;
377 continue;
378 }
379 if (columnName == NET_FIREWALL_RECORD_PROTOCOL) {
380 table.protocolIndex = i;
381 continue;
382 }
383 if (columnName == NET_FIREWALL_RECORD_UID) {
384 table.appUidIndex = i;
385 continue;
386 }
387 if (columnName == NET_FIREWALL_DOMAIN) {
388 table.domainIndex = i;
389 }
390 }
391 table.rowCount = rowCount;
392 return FIREWALL_OK;
393 }
394
GetRuleDataFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,const NetFirewallRuleInfo & table,NetFirewallRule & info)395 void NetFirewallDbHelper::GetRuleDataFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
396 const NetFirewallRuleInfo &table, NetFirewallRule &info)
397 {
398 resultSet->GetInt(table.userIdIndex, info.userId);
399 resultSet->GetString(table.ruleNameIndex, info.ruleName);
400 resultSet->GetString(table.ruleDescriptionIndex, info.ruleDescription);
401 int ruleDirection = 0;
402 if (resultSet->GetInt(table.ruleDirectionIndex, ruleDirection) == E_OK) {
403 info.ruleDirection = static_cast<NetFirewallRuleDirection>(ruleDirection);
404 }
405 int ruleAction = 0;
406 if (resultSet->GetInt(table.ruleActionIndex, ruleAction) == E_OK) {
407 info.ruleAction = static_cast<FirewallRuleAction>(ruleAction);
408 }
409 int ruleType = 0;
410 if (resultSet->GetInt(table.ruleTypeIndex, ruleType) == E_OK) {
411 info.ruleType = static_cast<NetFirewallRuleType>(ruleType);
412 }
413 int isEnabled = 0;
414 if (resultSet->GetInt(table.isEnabledIndex, isEnabled) == E_OK) {
415 info.isEnabled = static_cast<bool>(isEnabled);
416 }
417 resultSet->GetInt(table.appUidIndex, info.appUid);
418 GetRuleListParamFromResultSet(resultSet, table, info);
419 }
420
GetRuleListParamFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,const NetFirewallRuleInfo & table,NetFirewallRule & info)421 void NetFirewallDbHelper::GetRuleListParamFromResultSet(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
422 const NetFirewallRuleInfo &table, NetFirewallRule &info)
423 {
424 std::vector<uint8_t> value;
425 std::vector<DataBaseIp> dbIPs;
426 std::vector<DataBasePort> dbPorts;
427 switch (info.ruleType) {
428 case NetFirewallRuleType::RULE_IP: {
429 int protocol = 0;
430 if (resultSet->GetInt(table.protocolIndex, protocol) == E_OK) {
431 info.protocol = static_cast<NetworkProtocol>(protocol);
432 }
433 resultSet->GetBlob(table.localIpsIndex, value);
434 BlobToList(value, dbIPs);
435 DbIpToFirewallIp(dbIPs, info.localIps);
436 value.clear();
437 resultSet->GetBlob(table.remoteIpsIndex, value);
438 BlobToList(value, dbIPs);
439 DbIpToFirewallIp(dbIPs, info.remoteIps);
440 value.clear();
441 resultSet->GetBlob(table.localPortsIndex, value);
442 BlobToList(value, dbPorts);
443 DbPortToFirewallPort(dbPorts, info.localPorts);
444 value.clear();
445 resultSet->GetBlob(table.remotePortsIndex, value);
446 BlobToList(value, dbPorts);
447 DbPortToFirewallPort(dbPorts, info.remotePorts);
448 break;
449 }
450 case NetFirewallRuleType::RULE_DNS: {
451 resultSet->GetString(table.primaryDnsIndex, info.dns.primaryDns);
452 resultSet->GetString(table.standbyDnsIndex, info.dns.standbyDns);
453 break;
454 }
455
456 case NetFirewallRuleType::RULE_DOMAIN: {
457 resultSet->GetBlob(table.domainsIndex, value);
458 BlobToDomainList(value, info.domains);
459 break;
460 }
461 default:
462 break;
463 }
464 }
465
GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,std::vector<NetFirewallRule> & rules)466 int32_t NetFirewallDbHelper::GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
467 std::vector<NetFirewallRule> &rules)
468 {
469 NetFirewallRuleInfo table;
470 int32_t ret = GetResultSetTableInfo(resultSet, table);
471 if (ret < FIREWALL_OK) {
472 NETMGR_EXT_LOG_E("GetResultSetTableInfo failed");
473 return ret;
474 }
475
476 bool endFlag = false;
477 NetFirewallRule info;
478 for (int32_t i = 0; (i < table.rowCount) && !endFlag; i++) {
479 if (resultSet->GoToRow(i) != E_OK) {
480 NETMGR_EXT_LOG_E("GoToRow %{public}d", i);
481 break;
482 }
483 resultSet->GetInt(table.ruleIdIndex, info.ruleId);
484 if (info.ruleId > 0) {
485 GetRuleDataFromResultSet(resultSet, table, info);
486 rules.emplace_back(std::move(info));
487 }
488
489 resultSet->IsEnded(endFlag);
490 }
491 resultSet->Close();
492 return rules.size();
493 }
494
GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> & resultSet,std::vector<InterceptRecord> & rules)495 int32_t NetFirewallDbHelper::GetResultRightRecordEx(const std::shared_ptr<OHOS::NativeRdb::ResultSet> &resultSet,
496 std::vector<InterceptRecord> &rules)
497 {
498 NetInterceptRecordInfo table;
499 int32_t ret = GetResultSetTableInfo(resultSet, table);
500 if (ret < FIREWALL_OK) {
501 NETMGR_EXT_LOG_E("GetResultSetTableInfo failed");
502 return ret;
503 }
504
505 bool endFlag = false;
506 int32_t localPort = 0;
507 int32_t remotePort = 0;
508 int32_t protocol = 0;
509 InterceptRecord info;
510 for (int32_t i = 0; (i < table.rowCount) && !endFlag; i++) {
511 if (resultSet->GoToRow(i) != E_OK) {
512 NETMGR_EXT_LOG_E("GetResultRightRecordEx GoToRow %{public}d", i);
513 break;
514 }
515 resultSet->GetInt(table.timeIndex, info.time);
516 resultSet->GetString(table.localIpIndex, info.localIp);
517 resultSet->GetString(table.remoteIpIndex, info.remoteIp);
518 if (resultSet->GetInt(table.localPortIndex, localPort) == E_OK) {
519 info.localPort = static_cast<uint16_t>(localPort);
520 }
521 if (resultSet->GetInt(table.remotePortIndex, remotePort) == E_OK) {
522 info.remotePort = static_cast<uint16_t>(remotePort);
523 }
524 if (resultSet->GetInt(table.protocolIndex, protocol) == E_OK) {
525 info.protocol = static_cast<uint16_t>(protocol);
526 }
527 resultSet->GetInt(table.appUidIndex, info.appUid);
528 resultSet->GetString(table.domainIndex, info.domain);
529 if (info.time > 0) {
530 rules.emplace_back(std::move(info));
531 }
532 resultSet->IsEnded(endFlag);
533 }
534 int32_t index = 0;
535 resultSet->GetRowIndex(index);
536 resultSet->IsEnded(endFlag);
537 NETMGR_EXT_LOG_I("row=%{public}d pos=%{public}d ret=%{public}zu end=%{public}s", table.rowCount, index,
538 rules.size(), (endFlag ? "yes" : "no"));
539
540 resultSet->Close();
541 return rules.size();
542 }
543
544 template <typename T>
QueryAndGetResult(const NativeRdb::RdbPredicates & rdbPredicates,const std::vector<std::string> & columns,std::vector<T> & rules)545 int32_t NetFirewallDbHelper::QueryAndGetResult(const NativeRdb::RdbPredicates &rdbPredicates,
546 const std::vector<std::string> &columns, std::vector<T> &rules)
547 {
548 auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
549 if (resultSet == nullptr) {
550 NETMGR_EXT_LOG_E("Query error");
551 return FIREWALL_RDB_EXECUTE_FAILTURE;
552 }
553 return GetResultRightRecordEx(resultSet, rules);
554 }
555
QueryAllFirewallRuleRecord(std::vector<NetFirewallRule> & rules)556 int32_t NetFirewallDbHelper::QueryAllFirewallRuleRecord(std::vector<NetFirewallRule> &rules)
557 {
558 std::lock_guard<std::mutex> guard(databaseMutex_);
559 NETMGR_EXT_LOG_I("Query detail: all user");
560 std::vector<std::string> columns;
561 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
562 return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
563 }
564
QueryAllUserEnabledFirewallRules(std::vector<NetFirewallRule> & rules,NetFirewallRuleType type)565 int32_t NetFirewallDbHelper::QueryAllUserEnabledFirewallRules(std::vector<NetFirewallRule> &rules,
566 NetFirewallRuleType type)
567 {
568 std::lock_guard<std::mutex> guard(databaseMutex_);
569 NETMGR_EXT_LOG_I("Query detail: all user");
570 std::vector<std::string> columns;
571 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
572 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_IS_ENABLED, "1");
573 if (type != NetFirewallRuleType::RULE_ALL && type != NetFirewallRuleType::RULE_INVALID) {
574 rdbPredicates.And()->EqualTo(NET_FIREWALL_RULE_TYPE, std::to_string(static_cast<int32_t>(type)));
575 }
576 rdbPredicates.EndWrap();
577 return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
578 }
579
QueryEnabledFirewallRules(int32_t userId,int32_t appUid,std::vector<NetFirewallRule> & rules)580 int32_t NetFirewallDbHelper::QueryEnabledFirewallRules(int32_t userId, int32_t appUid,
581 std::vector<NetFirewallRule> &rules)
582 {
583 std::lock_guard<std::mutex> guard(databaseMutex_);
584 NETMGR_EXT_LOG_I("QueryEnabledFirewallRules : userId=%{public}d ", userId);
585 std::vector<std::string> columns;
586 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
587 rdbPredicates.BeginWrap()
588 ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))
589 ->And()
590 ->EqualTo(NET_FIREWALL_IS_ENABLED, "1")
591 ->And()
592 ->EqualTo(NET_FIREWALL_APP_ID, appUid)
593 ->EndWrap();
594 return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
595 }
596
QueryFirewallRuleRecord(int32_t ruleId,int32_t userId,std::vector<NetFirewallRule> & rules)597 int32_t NetFirewallDbHelper::QueryFirewallRuleRecord(int32_t ruleId, int32_t userId,
598 std::vector<NetFirewallRule> &rules)
599 {
600 std::lock_guard<std::mutex> guard(databaseMutex_);
601 NETMGR_EXT_LOG_I("Query detail: ruleId=%{public}d userId=%{public}d", ruleId, userId);
602 std::vector<std::string> columns;
603 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
604 rdbPredicates.BeginWrap()
605 ->EqualTo(RULE_ID, std::to_string(ruleId))
606 ->And()
607 ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))
608 ->EndWrap();
609
610 return QueryFirewallRuleRecord(rdbPredicates, columns, rules);
611 }
612
QueryFirewallRuleRecord(const NativeRdb::RdbPredicates & rdbPredicates,const std::vector<std::string> & columns,std::vector<NetFirewallRule> & rules)613 int32_t NetFirewallDbHelper::QueryFirewallRuleRecord(const NativeRdb::RdbPredicates &rdbPredicates,
614 const std::vector<std::string> &columns, std::vector<NetFirewallRule> &rules)
615 {
616 int32_t ret = QueryAndGetResult(rdbPredicates, columns, rules);
617 if (ret < 0) {
618 NETMGR_EXT_LOG_E("QueryFirewallRuleRecord error.");
619 return ret;
620 }
621 size_t size = rules.size();
622 if (size == 0) {
623 NETMGR_EXT_LOG_I("QueryFirewallRuleRecord rule empty");
624 return FIREWALL_OK;
625 }
626 NETMGR_EXT_LOG_I("QueryFirewallRuleRecord rule size: %{public}zu", size);
627 return FIREWALL_OK;
628 }
629
DeleteAndNoOtherOperation(const std::string & whereClause,const std::vector<std::string> & whereArgs)630 int32_t NetFirewallDbHelper::DeleteAndNoOtherOperation(const std::string &whereClause,
631 const std::vector<std::string> &whereArgs)
632 {
633 int32_t changedRows = 0;
634 int32_t ret = firewallDatabase_->Delete(FIREWALL_TABLE_NAME, changedRows, whereClause, whereArgs);
635 if (ret < FIREWALL_OK) {
636 (void)firewallDatabase_->RollBack();
637 return FIREWALL_FAILURE;
638 }
639 return ret;
640 }
641
DeleteFirewallRuleRecord(int32_t userId,int32_t ruleId)642 int32_t NetFirewallDbHelper::DeleteFirewallRuleRecord(int32_t userId, int32_t ruleId)
643 {
644 std::lock_guard<std::mutex> guard(databaseMutex_);
645 std::string whereClause = { "userId = ? AND ruleId = ?" };
646 std::vector<std::string> whereArgs = { std::to_string(userId), std::to_string(ruleId) };
647 int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
648 if (ret != FIREWALL_OK) {
649 NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
650 }
651 return ret;
652 }
653
DeleteFirewallRuleRecordByUserId(int32_t userId)654 int32_t NetFirewallDbHelper::DeleteFirewallRuleRecordByUserId(int32_t userId)
655 {
656 std::lock_guard<std::mutex> guard(databaseMutex_);
657 std::string whereClause = { "userId = ?" };
658 std::vector<std::string> whereArgs = { std::to_string(userId) };
659 int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
660 if (ret != FIREWALL_OK) {
661 NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
662 }
663 return ret;
664 }
665
DeleteFirewallRuleRecordByAppId(int32_t appUid)666 int32_t NetFirewallDbHelper::DeleteFirewallRuleRecordByAppId(int32_t appUid)
667 {
668 std::lock_guard<std::mutex> guard(databaseMutex_);
669 std::string whereClause = { "appUid = ?" };
670 std::vector<std::string> whereArgs = { std::to_string(appUid) };
671 int32_t ret = DeleteAndNoOtherOperation(whereClause, whereArgs);
672 if (ret != FIREWALL_OK) {
673 NETMGR_EXT_LOG_E("failed: detale(ruleId): %{public}d", ret);
674 }
675 return ret;
676 }
677
IsFirewallRuleExist(int32_t ruleId,NetFirewallRule & oldRule)678 bool NetFirewallDbHelper::IsFirewallRuleExist(int32_t ruleId, NetFirewallRule &oldRule)
679 {
680 std::lock_guard<std::mutex> guard(databaseMutex_);
681 bool isExist = false;
682 int32_t ret = CheckIfNeedUpdateEx(FIREWALL_TABLE_NAME, isExist, ruleId, oldRule);
683 if (ret < FIREWALL_OK) {
684 NETMGR_EXT_LOG_E("check if need update error: %{public}d", ret);
685 }
686 return isExist;
687 }
688
QueryFirewallRuleByUserIdCount(int32_t userId,int64_t & rowCount)689 int32_t NetFirewallDbHelper::QueryFirewallRuleByUserIdCount(int32_t userId, int64_t &rowCount)
690 {
691 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
692 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
693
694 return Count(rowCount, rdbPredicates);
695 }
696
QueryFirewallRuleAllCount(int64_t & rowCount)697 int32_t NetFirewallDbHelper::QueryFirewallRuleAllCount(int64_t &rowCount)
698 {
699 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
700 return Count(rowCount, rdbPredicates);
701 }
702
QueryFirewallRuleAllDomainCount()703 int32_t NetFirewallDbHelper::QueryFirewallRuleAllDomainCount()
704 {
705 return QuerySql(SQL_SUM + DOMAIN_NUM + SQL_FROM + FIREWALL_TABLE_NAME);
706 }
707
QueryFirewallRuleAllFuzzyDomainCount()708 int32_t NetFirewallDbHelper::QueryFirewallRuleAllFuzzyDomainCount()
709 {
710 return QuerySql(SQL_SUM + FUZZY_NUM + SQL_FROM + FIREWALL_TABLE_NAME);
711 }
712
QueryFirewallRuleDomainByUserIdCount(int32_t userId)713 int32_t NetFirewallDbHelper::QueryFirewallRuleDomainByUserIdCount(int32_t userId)
714 {
715 return QuerySql(SQL_SUM + DOMAIN_NUM + SQL_FROM + FIREWALL_TABLE_NAME + " WHERE (" + NET_FIREWALL_USER_ID + " = " +
716 std::to_string(userId) + ")");
717 }
718
QueryFirewallRule(const int32_t userId,const sptr<RequestParam> & requestParam,sptr<FirewallRulePage> & info)719 int32_t NetFirewallDbHelper::QueryFirewallRule(const int32_t userId, const sptr<RequestParam> &requestParam,
720 sptr<FirewallRulePage> &info)
721 {
722 std::lock_guard<std::mutex> guard(databaseMutex_);
723 int64_t rowCount = 0;
724 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
725 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
726 firewallDatabase_->Count(rowCount, rdbPredicates);
727 info->totalPage = rowCount / requestParam->pageSize;
728 int32_t remainder = rowCount % requestParam->pageSize;
729 if (remainder > 0) {
730 info->totalPage += 1;
731 }
732 NETMGR_EXT_LOG_I("QueryFirewallRule: userId=%{public}d page=%{public}d pageSize=%{public}d total=%{public}d",
733 userId, requestParam->page, requestParam->pageSize, info->totalPage);
734 if (info->totalPage < requestParam->page) {
735 return FIREWALL_FAILURE;
736 }
737 std::vector<std::string> columns;
738 rdbPredicates.Clear();
739 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId));
740 if (requestParam->orderType == NetFirewallOrderType::ORDER_ASC) {
741 rdbPredicates.OrderByAsc(NET_FIREWALL_RULE_NAME);
742 } else {
743 rdbPredicates.OrderByDesc(NET_FIREWALL_RULE_NAME);
744 }
745 rdbPredicates.Limit((requestParam->page - 1) * requestParam->pageSize, requestParam->pageSize)->EndWrap();
746 return QueryFirewallRuleRecord(rdbPredicates, columns, info->data);
747 }
748
Count(int64_t & outValue,const OHOS::NativeRdb::AbsRdbPredicates & predicates)749 int32_t NetFirewallDbHelper::Count(int64_t &outValue, const OHOS::NativeRdb::AbsRdbPredicates &predicates)
750 {
751 std::lock_guard<std::mutex> guard(databaseMutex_);
752 int32_t ret = firewallDatabase_->Count(outValue, predicates);
753 if (ret < FIREWALL_OK) {
754 NETMGR_EXT_LOG_E("Count error");
755 return -1;
756 }
757 return ret;
758 }
759
QuerySql(const std::string & sql)760 int32_t NetFirewallDbHelper::QuerySql(const std::string &sql)
761 {
762 std::lock_guard<std::mutex> guard(databaseMutex_);
763 std::vector<std::string> selectionArgs;
764 auto resultSet = firewallDatabase_->QuerySql(sql, selectionArgs);
765 if (resultSet == nullptr) {
766 NETMGR_EXT_LOG_E("QuerySql error");
767 return FIREWALL_RDB_EXECUTE_FAILTURE;
768 }
769 int32_t rowCount = 0;
770 if (resultSet->GetRowCount(rowCount) != E_OK || resultSet->GoToRow(0) != E_OK) {
771 return FIREWALL_RDB_EXECUTE_FAILTURE;
772 }
773 int32_t value = 0;
774 resultSet->GetInt(0, value);
775 return value;
776 }
777
IsDnsRuleExist(const sptr<NetFirewallRule> & rule)778 bool NetFirewallDbHelper::IsDnsRuleExist(const sptr<NetFirewallRule> &rule)
779 {
780 if (rule->ruleType != NetFirewallRuleType::RULE_DNS) {
781 return false;
782 }
783 std::lock_guard<std::mutex> guard(databaseMutex_);
784 RdbPredicates rdbPredicates(FIREWALL_TABLE_NAME);
785 rdbPredicates.BeginWrap()
786 ->EqualTo(NET_FIREWALL_USER_ID, std::to_string(rule->userId))
787 ->And()
788 ->EqualTo(NET_FIREWALL_RULE_TYPE, std::to_string(static_cast<int32_t>(rule->ruleType)))
789 ->And()
790 ->EqualTo(NET_FIREWALL_APP_ID, std::to_string(rule->appUid))
791 ->And()
792 ->BeginWrap()
793 ->EqualTo(NET_FIREWALL_DNS_PRIMARY, rule->dns.primaryDns)
794 ->Or()
795 ->EqualTo(NET_FIREWALL_DNS_STANDY, rule->dns.standbyDns)
796 ->EndWrap()
797 ->Limit(1)
798 ->EndWrap();
799 std::vector<std::string> columns;
800 auto resultSet = firewallDatabase_->Query(rdbPredicates, columns);
801 if (resultSet == nullptr) {
802 NETMGR_EXT_LOG_E("IsDnsRuleExist Query error");
803 return false;
804 }
805 int32_t rowCount = 0;
806 resultSet->GetRowCount(rowCount);
807 return rowCount > 0;
808 }
809
AddInterceptRecord(const int32_t userId,std::vector<sptr<InterceptRecord>> & records)810 int32_t NetFirewallDbHelper::AddInterceptRecord(const int32_t userId, std::vector<sptr<InterceptRecord>> &records)
811 {
812 std::lock_guard<std::mutex> guard(databaseMutex_);
813 int32_t ret = firewallDatabase_->BeginTransaction();
814 // Aging by date, record up to 8 days of data
815 std::string whereClause = { "userId = ? AND time < ?" };
816 std::vector<std::string> whereArgs = { std::to_string(userId),
817 std::to_string(records.back()->time - RECORD_MAX_SAVE_TIME) };
818 int32_t changedRows = 0;
819 ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
820
821 int64_t currentRows = 0;
822 RdbPredicates rdbPredicates(INTERCEPT_RECORD_TABLE);
823 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
824 firewallDatabase_->Count(currentRows, rdbPredicates);
825 // Aging by number, record up to 1000 pieces of data
826 size_t size = records.size();
827 int64_t leftRows = RECORD_MAX_DATA_NUM - currentRows;
828 if (leftRows < size) {
829 std::string whereClause("id in (select id from ");
830 whereClause += INTERCEPT_RECORD_TABLE;
831 whereClause += " where userId = ? order by id limit ? )";
832 std::vector<std::string> whereArgs = { std::to_string(userId), std::to_string(size - leftRows) };
833 ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
834 }
835 // New data written to the database
836 ValuesBucket values;
837 for (size_t i = 0; i < size; i++) {
838 values.Clear();
839 values.PutInt(NET_FIREWALL_USER_ID, userId);
840 values.PutInt(NET_FIREWALL_RECORD_TIME, records[i]->time);
841 values.PutString(NET_FIREWALL_RECORD_LOCAL_IP, records[i]->localIp);
842 values.PutString(NET_FIREWALL_RECORD_REMOTE_IP, records[i]->remoteIp);
843 values.PutInt(NET_FIREWALL_RECORD_LOCAL_PORT, static_cast<int32_t>(records[i]->localPort));
844 values.PutInt(NET_FIREWALL_RECORD_REMOTE_PORT, static_cast<int32_t>(records[i]->remotePort));
845 values.PutInt(NET_FIREWALL_RECORD_PROTOCOL, static_cast<int32_t>(records[i]->protocol));
846 values.PutInt(NET_FIREWALL_RECORD_UID, records[i]->appUid);
847 values.PutString(NET_FIREWALL_DOMAIN, records[i]->domain);
848
849 ret = firewallDatabase_->Insert(values, INTERCEPT_RECORD_TABLE);
850 if (ret < FIREWALL_OK) {
851 NETMGR_EXT_LOG_E("AddInterceptRecord error: %{public}d", ret);
852 firewallDatabase_->Commit();
853 return -1;
854 }
855 }
856 return firewallDatabase_->Commit();
857 }
858
DeleteInterceptRecord(const int32_t userId)859 int32_t NetFirewallDbHelper::DeleteInterceptRecord(const int32_t userId)
860 {
861 std::lock_guard<std::mutex> guard(databaseMutex_);
862 std::string whereClause = { "userId = ?" };
863 std::vector<std::string> whereArgs = { std::to_string(userId) };
864 int32_t changedRows = 0;
865 int32_t ret = firewallDatabase_->Delete(INTERCEPT_RECORD_TABLE, changedRows, whereClause, whereArgs);
866 if (ret < FIREWALL_OK) {
867 NETMGR_EXT_LOG_E("DeleteInterceptRecord error: %{public}d", ret);
868 return -1;
869 }
870 return ret;
871 }
872
QueryInterceptRecord(const int32_t userId,const sptr<RequestParam> & requestParam,sptr<InterceptRecordPage> & info)873 int32_t NetFirewallDbHelper::QueryInterceptRecord(const int32_t userId, const sptr<RequestParam> &requestParam,
874 sptr<InterceptRecordPage> &info)
875 {
876 std::lock_guard<std::mutex> guard(databaseMutex_);
877 int64_t rowCount = 0;
878 RdbPredicates rdbPredicates(INTERCEPT_RECORD_TABLE);
879 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId))->EndWrap();
880 firewallDatabase_->Count(rowCount, rdbPredicates);
881 info->totalPage = rowCount / requestParam->pageSize;
882 int32_t remainder = rowCount % requestParam->pageSize;
883 if (remainder > 0) {
884 info->totalPage += 1;
885 }
886 NETMGR_EXT_LOG_I("QueryInterceptRecord: userId=%{public}d page=%{public}d pageSize=%{public}d total=%{public}d",
887 userId, requestParam->page, requestParam->pageSize, info->totalPage);
888 if (info->totalPage < requestParam->page) {
889 return FIREWALL_FAILURE;
890 }
891 info->page = requestParam->page;
892 std::vector<std::string> columns;
893 rdbPredicates.Clear();
894 rdbPredicates.BeginWrap()->EqualTo(NET_FIREWALL_USER_ID, std::to_string(userId));
895 if (requestParam->orderType == NetFirewallOrderType::ORDER_ASC) {
896 rdbPredicates.OrderByAsc(NET_FIREWALL_RECORD_TIME);
897 } else {
898 rdbPredicates.OrderByDesc(NET_FIREWALL_RECORD_TIME);
899 }
900 rdbPredicates.Limit((requestParam->page - 1) * requestParam->pageSize, requestParam->pageSize)->EndWrap();
901 return QueryAndGetResult(rdbPredicates, columns, info->data);
902 }
903
FirewallIpToDbIp(const std::vector<NetFirewallIpParam> & ips,std::vector<DataBaseIp> & dbips)904 void NetFirewallDbHelper::FirewallIpToDbIp(const std::vector<NetFirewallIpParam> &ips, std::vector<DataBaseIp> &dbips)
905 {
906 dbips.clear();
907 DataBaseIp dbip;
908 for (const NetFirewallIpParam ¶m : ips) {
909 dbip.family = param.family;
910 dbip.mask = param.mask;
911 dbip.type = param.type;
912 if (dbip.family == FAMILY_IPV4) {
913 memcpy_s(&dbip.ipv4.startIp, sizeof(uint32_t), ¶m.ipv4.startIp, sizeof(uint32_t));
914 memcpy_s(&dbip.ipv4.endIp, sizeof(uint32_t), ¶m.ipv4.endIp, sizeof(uint32_t));
915 } else {
916 memcpy_s(&dbip.ipv6.startIp, sizeof(in6_addr), ¶m.ipv6.startIp, sizeof(in6_addr));
917 memcpy_s(&dbip.ipv6.endIp, sizeof(in6_addr), ¶m.ipv6.endIp, sizeof(in6_addr));
918 }
919 dbips.emplace_back(std::move(dbip));
920 }
921 }
DbIpToFirewallIp(const std::vector<DataBaseIp> & dbips,std::vector<NetFirewallIpParam> & ips)922 void NetFirewallDbHelper::DbIpToFirewallIp(const std::vector<DataBaseIp> &dbips, std::vector<NetFirewallIpParam> &ips)
923 {
924 ips.clear();
925 NetFirewallIpParam dbip;
926 for (const DataBaseIp ¶m : dbips) {
927 dbip.family = param.family;
928 dbip.mask = param.mask;
929 dbip.type = param.type;
930 if (dbip.family == FAMILY_IPV4) {
931 memcpy_s(&dbip.ipv4.startIp, sizeof(uint32_t), ¶m.ipv4.startIp, sizeof(uint32_t));
932 memcpy_s(&dbip.ipv4.endIp, sizeof(uint32_t), ¶m.ipv4.endIp, sizeof(uint32_t));
933 } else {
934 memcpy_s(&dbip.ipv6.startIp, sizeof(in6_addr), ¶m.ipv6.startIp, sizeof(in6_addr));
935 memcpy_s(&dbip.ipv6.endIp, sizeof(in6_addr), ¶m.ipv6.endIp, sizeof(in6_addr));
936 }
937 ips.emplace_back(std::move(dbip));
938 }
939 }
FirewallPortToDbPort(const std::vector<NetFirewallPortParam> & ports,std::vector<DataBasePort> & dbports)940 void NetFirewallDbHelper::FirewallPortToDbPort(const std::vector<NetFirewallPortParam> &ports,
941 std::vector<DataBasePort> &dbports)
942 {
943 dbports.clear();
944 DataBasePort dbport;
945 for (const NetFirewallPortParam ¶m : ports) {
946 dbport.startPort = param.startPort;
947 dbport.endPort = param.endPort;
948 dbports.emplace_back(std::move(dbport));
949 }
950 }
951
DbPortToFirewallPort(const std::vector<DataBasePort> & dbports,std::vector<NetFirewallPortParam> & ports)952 void NetFirewallDbHelper::DbPortToFirewallPort(const std::vector<DataBasePort> &dbports,
953 std::vector<NetFirewallPortParam> &ports)
954 {
955 ports.clear();
956 NetFirewallPortParam dbport;
957 for (const DataBasePort ¶m : dbports) {
958 dbport.startPort = param.startPort;
959 dbport.endPort = param.endPort;
960 ports.emplace_back(std::move(dbport));
961 }
962 }
963 } // namespace NetManagerStandard
964 } // namespace OHOS