1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "sg_classify_client.h"
17
18 #include <future>
19
20 #include "iremote_broker.h"
21 #include "iservice_registry.h"
22 #include "securec.h"
23
24 #include "risk_analysis_manager_callback_service.h"
25 #include "risk_analysis_manager_proxy.h"
26 #include "security_guard_define.h"
27 #include "security_guard_log.h"
28
namespace {
// How long, in milliseconds, RequestSecurityModelResultSync waits for the reply before giving up.
constexpr int32_t TIMEOUT_REPLY{500};
}  // namespace
32
33 using namespace OHOS;
34 using namespace OHOS::Security::SecurityGuard;
35
36 static std::mutex g_mutex;
37
RequestSecurityModelResult(const std::string & devId,uint32_t modelId,const std::string & param,ResultCallback callback)38 static int32_t RequestSecurityModelResult(const std::string &devId, uint32_t modelId,
39 const std::string ¶m, ResultCallback callback)
40 {
41 auto registry = SystemAbilityManagerClient::GetInstance().GetSystemAbilityManager();
42 if (registry == nullptr) {
43 SGLOGE("GetSystemAbilityManager error");
44 return NULL_OBJECT;
45 }
46
47 auto object = registry->GetSystemAbility(RISK_ANALYSIS_MANAGER_SA_ID);
48 auto proxy = iface_cast<RiskAnalysisManagerProxy>(object);
49 if (proxy == nullptr) {
50 SGLOGE("proxy is null");
51 return NULL_OBJECT;
52 }
53
54 sptr<RiskAnalysisManagerCallbackService> stub = new (std::nothrow) RiskAnalysisManagerCallbackService(callback);
55 if (stub == nullptr) {
56 SGLOGE("stub is null");
57 return NULL_OBJECT;
58 }
59 int32_t ret = proxy->RequestSecurityModelResult(devId, modelId, param, stub);
60 SGLOGI("RequestSecurityModelResult result, ret=%{public}d", ret);
61 return ret;
62 }
63
64 namespace OHOS::Security::SecurityGuard {
RequestSecurityModelResultSync(const std::string & devId,uint32_t modelId,const std::string & param,SecurityModelResult & result)65 int32_t RequestSecurityModelResultSync(const std::string &devId, uint32_t modelId,
66 const std::string ¶m, SecurityModelResult &result)
67 {
68 if (devId.length() >= DEVICE_ID_MAX_LEN) {
69 return BAD_PARAM;
70 }
71 std::unique_lock<std::mutex> lock(g_mutex);
72 auto promise = std::make_shared<std::promise<SecurityModelResult>>();
73 auto future = promise->get_future();
74 auto func = [promise, param] (const std::string &devId, uint32_t modelId,
75 const std::string &result) mutable -> int32_t {
76 SecurityModelResult modelResult = {
77 .devId = devId,
78 .modelId = modelId,
79 .param = param,
80 .result = result
81 };
82 promise->set_value(modelResult);
83 return SUCCESS;
84 };
85
86 int32_t code = RequestSecurityModelResult(devId, modelId, param, func);
87 if (code != SUCCESS) {
88 SGLOGE("RequestSecurityModelResult error, code=%{public}d", code);
89 return code;
90 }
91 std::chrono::milliseconds span(TIMEOUT_REPLY);
92 if (future.wait_for(span) == std::future_status::timeout) {
93 SGLOGE("wait timeout");
94 return TIME_OUT;
95 }
96 result = future.get();
97 return SUCCESS;
98 }
99
RequestSecurityModelResultAsync(const std::string & devId,uint32_t modelId,const std::string & param,SecurityGuardRiskCallback callback)100 int32_t RequestSecurityModelResultAsync(const std::string &devId, uint32_t modelId,
101 const std::string ¶m, SecurityGuardRiskCallback callback)
102 {
103 if (devId.length() >= DEVICE_ID_MAX_LEN) {
104 return BAD_PARAM;
105 }
106 std::unique_lock<std::mutex> lock(g_mutex);
107 auto func = [callback, param] (const std::string &devId,
108 uint32_t modelId, const std::string &result) -> int32_t {
109 callback(SecurityModelResult{devId, modelId, param, result});
110 return SUCCESS;
111 };
112
113 return RequestSecurityModelResult(devId, modelId, param, func);
114 }
115 }
116
117 #ifdef __cplusplus
118 extern "C" {
119 #endif
120
FillingRequestResult(const OHOS::Security::SecurityGuard::SecurityModelResult & cppResult,::SecurityModelResult * result)121 static int32_t FillingRequestResult(const OHOS::Security::SecurityGuard::SecurityModelResult &cppResult,
122 ::SecurityModelResult *result)
123 {
124 if (cppResult.devId.length() >= DEVICE_ID_MAX_LEN || cppResult.result.length() >= RESULT_MAX_LEN) {
125 return BAD_PARAM;
126 }
127
128 result->modelId = cppResult.modelId;
129 errno_t rc = memcpy_s(result->devId.identity, DEVICE_ID_MAX_LEN, cppResult.devId.c_str(), cppResult.devId.length());
130 if (rc != EOK) {
131 return NULL_OBJECT;
132 }
133 result->devId.length = cppResult.devId.length();
134
135 rc = memcpy_s(result->result, RESULT_MAX_LEN, cppResult.result.c_str(), cppResult.result.length());
136 if (rc != EOK) {
137 return NULL_OBJECT;
138 }
139 result->resultLen = cppResult.result.length();
140
141 SGLOGD("modelId=%{public}u, result=%{public}s", cppResult.modelId, cppResult.result.c_str());
142 return SUCCESS;
143 }
144
CovertDevId(const DeviceIdentify * devId)145 static std::string CovertDevId(const DeviceIdentify *devId)
146 {
147 std::vector<char> id(DEVICE_ID_MAX_LEN, '\0');
148 std::copy(&devId->identity[0], &devId->identity[DEVICE_ID_MAX_LEN - 1], id.begin());
149 return std::string{id.data()};
150 }
151
RequestSecurityModelResultSync(const DeviceIdentify * devId,uint32_t modelId,::SecurityModelResult * result)152 int32_t RequestSecurityModelResultSync(const DeviceIdentify *devId, uint32_t modelId, ::SecurityModelResult *result)
153 {
154 if (devId == nullptr || result == nullptr || devId->length >= DEVICE_ID_MAX_LEN) {
155 return BAD_PARAM;
156 }
157 OHOS::Security::SecurityGuard::SecurityModelResult tmp;
158 int32_t ret = OHOS::Security::SecurityGuard::RequestSecurityModelResultSync(CovertDevId(devId), modelId, "", tmp);
159 FillingRequestResult(tmp, result);
160 return ret;
161 }
162
RequestSecurityModelResultAsync(const DeviceIdentify * devId,uint32_t modelId,::SecurityGuardRiskCallback callback)163 int32_t RequestSecurityModelResultAsync(const DeviceIdentify *devId, uint32_t modelId,
164 ::SecurityGuardRiskCallback callback)
165 {
166 if (devId == nullptr || devId->length >= DEVICE_ID_MAX_LEN) {
167 return BAD_PARAM;
168 }
169 auto cppCallBack = [callback](const OHOS::Security::SecurityGuard::SecurityModelResult &tmp) {
170 ::SecurityModelResult result{};
171 FillingRequestResult(tmp, &result);
172 callback(&result);
173 };
174 return OHOS::Security::SecurityGuard::RequestSecurityModelResultAsync(CovertDevId(devId), modelId, "", cppCallBack);
175 }
176
177 #ifdef __cplusplus
178 }
179 #endif
180