/*
 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16 #include "connect_context.h"
17 #include "constant.h"
18 #include "napi_utils.h"
19 #include "netstack_common_utils.h"
20 #include "netstack_log.h"
21 #include "securec.h"
22 #include <utility>
23
24 namespace OHOS::NetStack::Websocket {
ConnectContext(napi_env env,EventManager * manager)25 ConnectContext::ConnectContext(napi_env env, EventManager *manager) : BaseContext(env, manager) {}
ConnectContext(napi_env env,const std::shared_ptr<EventManager> & sharedManager)26 ConnectContext::ConnectContext(napi_env env, const std::shared_ptr<EventManager> &sharedManager)
27 : BaseContext(env, sharedManager)
28 {
29 }
30
31 ConnectContext::~ConnectContext() = default;
32
// Normalises a URL whose query string follows the host directly (e.g. "ws://host?x=1")
// by inserting a root path "/" before the '?' ("ws://host/?x=1").
// The scan starts after the "://" scheme separator (if present) and any slashes that
// immediately follow it. URLs without a '?', or with a '/' before the '?', are left untouched.
static void AddSlashBeforeQuery(std::string &url)
{
    if (url.empty()) {
        return;
    }

    const std::string schemeSeparator = "://";
    const size_t separatorAt = url.find(schemeSeparator);
    size_t hostStart = (separatorAt == std::string::npos) ? 0 : separatorAt + schemeSeparator.length();

    // Skip any extra leading slashes (e.g. "ws:///host") so they are not mistaken for a path.
    const size_t firstNonSlash = url.find_first_not_of('/', hostStart);
    if (firstNonSlash != std::string::npos) {
        hostStart = firstNonSlash;
    }

    const size_t questionMark = url.find('?', hostStart);
    if (questionMark == std::string::npos) {
        return; // no query string, nothing to fix
    }

    // No path separator before the query means the path is empty: give it a root "/".
    const size_t firstSlash = url.find('/', hostStart);
    if (firstSlash == std::string::npos || firstSlash > questionMark) {
        url.insert(questionMark, 1, '/');
    }
}
54
ParseParams(napi_value * params,size_t paramsCount)55 void ConnectContext::ParseParams(napi_value *params, size_t paramsCount)
56 {
57 if (!CheckParamsType(params, paramsCount)) {
58 ParseCallback(params, paramsCount);
59 return;
60 }
61
62 if (paramsCount == FUNCTION_PARAM_ONE) {
63 if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
64 url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
65 AddSlashBeforeQuery(url);
66 SetParseOK(true);
67 }
68 return;
69 }
70 if (paramsCount == FUNCTION_PARAM_TWO) {
71 if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
72 url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
73 AddSlashBeforeQuery(url);
74 }
75 if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_function) {
76 return SetParseOK(SetCallback(params[1]) == napi_ok);
77 }
78 if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object) {
79 ParseHeader(params[1]);
80 ParseCaPath(params[1]);
81 ParseClientCert(params[1]);
82 if (!ParseProxy(params[1]) || !ParseProtocol(params[1])) {
83 return;
84 }
85 return SetParseOK(true);
86 }
87 }
88 if (paramsCount == FUNCTION_PARAM_THREE) {
89 ParseParamsCountThree(params);
90 }
91 }
92
ParseCallback(napi_value const * params,size_t paramsCount)93 void ConnectContext::ParseCallback(napi_value const *params, size_t paramsCount)
94 {
95 if (paramsCount == FUNCTION_PARAM_ONE) {
96 if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_ONE - 1]) == napi_function) {
97 SetCallback(params[FUNCTION_PARAM_ONE - 1]);
98 }
99 return;
100 }
101
102 if (paramsCount == FUNCTION_PARAM_TWO) {
103 if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_TWO - 1]) == napi_function) {
104 SetCallback(params[FUNCTION_PARAM_TWO - 1]);
105 }
106 return;
107 }
108
109 if (paramsCount == FUNCTION_PARAM_THREE) {
110 if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
111 SetCallback(params[FUNCTION_PARAM_THREE - 1]);
112 }
113 return;
114 }
115 }
116
ParseParamsCountThree(napi_value const * params)117 void ConnectContext::ParseParamsCountThree(napi_value const *params)
118 {
119 if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
120 url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
121 AddSlashBeforeQuery(url);
122 }
123 if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object) {
124 ParseHeader(params[1]);
125 ParseCaPath(params[1]);
126 ParseClientCert(params[1]);
127 if (!ParseProxy(params[1]) || !ParseProtocol(params[1])) {
128 if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
129 SetCallback(params[FUNCTION_PARAM_THREE - 1]);
130 return;
131 }
132 return;
133 }
134 }
135 if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
136 return SetParseOK(SetCallback(params[FUNCTION_PARAM_THREE - 1]) == napi_ok);
137 }
138 }
139
ParseHeader(napi_value optionsValue)140 void ConnectContext::ParseHeader(napi_value optionsValue)
141 {
142 if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::HEADER)) {
143 return;
144 }
145 napi_value jsHeader = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::HEADER);
146 if (NapiUtils::GetValueType(GetEnv(), jsHeader) != napi_object) {
147 return;
148 }
149 auto names = NapiUtils::GetPropertyNames(GetEnv(), jsHeader);
150 std::for_each(names.begin(), names.end(), [jsHeader, this](const std::string &name) {
151 auto value = NapiUtils::GetStringPropertyUtf8(GetEnv(), jsHeader, name);
152 if (!value.empty()) {
153 // header key ignores key but value not
154 header[CommonUtils::ToLower(name)] = value;
155 }
156 });
157 }
158
ParseCaPath(napi_value optionsValue)159 void ConnectContext::ParseCaPath(napi_value optionsValue)
160 {
161 if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::CAPATH)) {
162 NETSTACK_LOGI("ConnectContext CAPATH not found");
163 return;
164 }
165 napi_value jsCaPath = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::CAPATH);
166 if (NapiUtils::GetValueType(GetEnv(), jsCaPath) != napi_string) {
167 return;
168 }
169 caPath_ = NapiUtils::GetStringPropertyUtf8(GetEnv(), optionsValue, ContextKey::CAPATH);
170 }
171
GetClientCert(std::string & cert,Secure::SecureChar & key,Secure::SecureChar & keyPassword)172 void ConnectContext::GetClientCert(std::string &cert, Secure::SecureChar &key, Secure::SecureChar &keyPassword)
173 {
174 cert = clientCert_;
175 key = clientKey_;
176 keyPassword = keyPassword_;
177 }
178
SetClientCert(std::string & cert,Secure::SecureChar & key,Secure::SecureChar & keyPassword)179 void ConnectContext::SetClientCert(std::string &cert, Secure::SecureChar &key, Secure::SecureChar &keyPassword)
180 {
181 clientCert_ = cert;
182 clientKey_ = key;
183 keyPassword_ = keyPassword;
184 }
185
ParseClientCert(napi_value optionsValue)186 void ConnectContext::ParseClientCert(napi_value optionsValue)
187 {
188 if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::CLIENT_CERT)) {
189 NETSTACK_LOGI("ConnectContext CLIENT_CERT not found");
190 return;
191 }
192 napi_value jsCert = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::CLIENT_CERT);
193 napi_valuetype type = NapiUtils::GetValueType(GetEnv(), jsCert);
194 if (type != napi_object || type == napi_undefined) {
195 return;
196 }
197 std::string certPath = NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::CERT_PATH);
198 Secure::SecureChar keyPath =
199 Secure::SecureChar(NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::KEY_PATH));
200 Secure::SecureChar keyPassword =
201 Secure::SecureChar(NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::KEY_PASSWD));
202 SetClientCert(certPath, keyPath, keyPassword);
203 }
204
ParseProxy(napi_value optionsValue)205 bool ConnectContext::ParseProxy(napi_value optionsValue)
206 {
207 if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::PROXY)) {
208 SetWebsocketProxyType(WebsocketProxyType::USE_SYSTEM);
209 NETSTACK_LOGD("websocket connect proxy not found, use system proxy");
210 return true;
211 }
212 napi_value websocketProxyValue = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::PROXY);
213 napi_valuetype type = NapiUtils::GetValueType(GetEnv(), websocketProxyValue);
214 if (type == napi_string) {
215 std::string proxyStr = NapiUtils::GetStringFromValueUtf8(GetEnv(), websocketProxyValue);
216 if (proxyStr == ContextKey::NOT_USE_PROXY) {
217 SetWebsocketProxyType(WebsocketProxyType::NOT_USE);
218 return true;
219 } else if (proxyStr == ContextKey::USE_SYSTEM_PROXY) {
220 SetWebsocketProxyType(WebsocketProxyType::USE_SYSTEM);
221 return true;
222 } else {
223 NETSTACK_LOGE("websocket proxy param parse failed!");
224 return false;
225 }
226 }
227 if (type != napi_object) {
228 NETSTACK_LOGE("websocket proxy param parse failed!");
229 return false;
230 }
231
232 std::string exclusionList;
233 std::string host =
234 NapiUtils::GetStringPropertyUtf8(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_HOST);
235 int32_t port = NapiUtils::GetInt32Property(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_PORT);
236 if (NapiUtils::HasNamedProperty(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_EXCLUSION_LIST)) {
237 napi_value exclusionListValue =
238 NapiUtils::GetNamedProperty(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_EXCLUSION_LIST);
239 uint32_t listLength = NapiUtils::GetArrayLength(GetEnv(), exclusionListValue);
240 for (uint32_t index = 0; index < listLength; ++index) {
241 napi_value exclusionValue = NapiUtils::GetArrayElement(GetEnv(), exclusionListValue, index);
242 std::string exclusion = NapiUtils::GetStringFromValueUtf8(GetEnv(), exclusionValue);
243 if (index != 0) {
244 exclusionList.append(ContextKey::WEBSOCKET_PROXY_EXCLUSIONS_SEPARATOR);
245 }
246 exclusionList += exclusion;
247 }
248 }
249 SetSpecifiedWebsocketProxy(host, port, exclusionList);
250 SetWebsocketProxyType(WebsocketProxyType::USE_SPECIFIED);
251 return true;
252 }
253
ParseProtocol(napi_value optionsValue)254 bool ConnectContext::ParseProtocol(napi_value optionsValue)
255 {
256 if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::PROTCOL)) {
257 NETSTACK_LOGD("websocket connect protocol not found");
258 return true;
259 }
260 napi_value jsProtocol = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::PROTCOL);
261 if (NapiUtils::GetValueType(GetEnv(), jsProtocol) == napi_string) {
262 SetProtocol(NapiUtils::GetStringPropertyUtf8(GetEnv(), optionsValue, ContextKey::PROTCOL));
263 return true;
264 }
265 NETSTACK_LOGE("websocket connect protocol param parse failed");
266 return false;
267 }
268
CheckParamsType(napi_value * params,size_t paramsCount)269 bool ConnectContext::CheckParamsType(napi_value *params, size_t paramsCount)
270 {
271 if (paramsCount == FUNCTION_PARAM_ONE) {
272 return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string;
273 }
274 if (paramsCount == FUNCTION_PARAM_TWO) {
275 return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string &&
276 (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_function ||
277 NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object);
278 }
279 if (paramsCount == FUNCTION_PARAM_THREE) {
280 return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string &&
281 NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object &&
282 NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function;
283 }
284 return false;
285 }
286
SetProtocol(std::string protocol)287 void ConnectContext::SetProtocol(std::string protocol)
288 {
289 websocketProtocol_ = std::move(protocol);
290 }
291
GetProtocol() const292 std::string ConnectContext::GetProtocol() const
293 {
294 return websocketProtocol_;
295 }
296
SetWebsocketProxyType(WebsocketProxyType type)297 void ConnectContext::SetWebsocketProxyType(WebsocketProxyType type)
298 {
299 usingWebsocketProxyType_ = type;
300 }
301
GetUsingWebsocketProxyType() const302 WebsocketProxyType ConnectContext::GetUsingWebsocketProxyType() const
303 {
304 return usingWebsocketProxyType_;
305 }
306
SetSpecifiedWebsocketProxy(const std::string & host,int32_t port,const std::string & exclusionList)307 void ConnectContext::SetSpecifiedWebsocketProxy(const std::string &host, int32_t port, const std::string &exclusionList)
308 {
309 websocketProxyHost_ = host;
310 websocketProxyPort_ = port;
311 websocketProxyExclusions_ = exclusionList;
312 }
313
GetSpecifiedWebsocketProxy(std::string & host,uint32_t & port,std::string & exclusionList) const314 void ConnectContext::GetSpecifiedWebsocketProxy(std::string &host, uint32_t &port, std::string &exclusionList) const
315 {
316 host = websocketProxyHost_;
317 port = websocketProxyPort_;
318 exclusionList = websocketProxyExclusions_;
319 }
320
GetErrorCode() const321 int32_t ConnectContext::GetErrorCode() const
322 {
323 if (BaseContext::IsPermissionDenied()) {
324 return PERMISSION_DENIED_CODE;
325 }
326 if (BaseContext::IsNoAllowedHost()) {
327 return WEBSOCKET_NOT_ALLOWED_HOST;
328 }
329
330 auto err = BaseContext::GetErrorCode();
331 if (err == PARSE_ERROR_CODE) {
332 return PARSE_ERROR_CODE;
333 }
334 if (WEBSOCKET_ERR_MAP.find(err) != WEBSOCKET_ERR_MAP.end()) {
335 return err;
336 }
337 return WEBSOCKET_UNKNOWN_OTHER_ERROR;
338 }
339
GetErrorMessage() const340 std::string ConnectContext::GetErrorMessage() const
341 {
342 if (BaseContext::IsPermissionDenied()) {
343 return PERMISSION_DENIED_MSG;
344 }
345 if (BaseContext::IsNoAllowedHost()) {
346 return WEBSOCKET_ERR_MAP.at(WEBSOCKET_NOT_ALLOWED_HOST);
347 }
348
349 auto err = BaseContext::GetErrorCode();
350 if (err == PARSE_ERROR_CODE) {
351 return PARSE_ERROR_MSG;
352 }
353 auto it = WEBSOCKET_ERR_MAP.find(err);
354 if (it != WEBSOCKET_ERR_MAP.end()) {
355 return it->second;
356 }
357 it = WEBSOCKET_ERR_MAP.find(WEBSOCKET_UNKNOWN_OTHER_ERROR);
358 if (it != WEBSOCKET_ERR_MAP.end()) {
359 return it->second;
360 }
361 return {};
362 }
363
IsAtomicService() const364 bool ConnectContext::IsAtomicService() const
365 {
366 return isAtomicService_;
367 }
368
SetAtomicService(bool isAtomicService)369 void ConnectContext::SetAtomicService(bool isAtomicService)
370 {
371 isAtomicService_ = isAtomicService;
372 }
373
SetBundleName(const std::string & bundleName)374 void ConnectContext::SetBundleName(const std::string &bundleName)
375 {
376 bundleName_ = bundleName;
377 }
378
GetBundleName() const379 std::string ConnectContext::GetBundleName() const
380 {
381 return bundleName_;
382 }
383 } // namespace OHOS::NetStack::Websocket
384