1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "connect_context.h"
17 #include "constant.h"
18 #include "napi_utils.h"
19 #include "netstack_common_utils.h"
20 #include "netstack_log.h"
21 #include "securec.h"
22 #include <utility>
23 
24 namespace OHOS::NetStack::Websocket {
ConnectContext(napi_env env,EventManager * manager)25 ConnectContext::ConnectContext(napi_env env, EventManager *manager) : BaseContext(env, manager) {}
ConnectContext(napi_env env,const std::shared_ptr<EventManager> & sharedManager)26 ConnectContext::ConnectContext(napi_env env, const std::shared_ptr<EventManager> &sharedManager)
27     : BaseContext(env, sharedManager)
28 {
29 }
30 
31 ConnectContext::~ConnectContext() = default;
32 
// Normalizes `url` in place so that a query string is always preceded by a path:
// "ws://host?a=1" becomes "ws://host/?a=1". URLs that already have a '/' between the
// authority and the '?', or that have no '?' at all, are left untouched.
static void AddSlashBeforeQuery(std::string &url)
{
    if (url.empty()) {
        return;
    }

    // Start scanning right after the scheme separator, or at the beginning when there is none.
    const std::string schemeSep = "://";
    const size_t sepPos = url.find(schemeSep);
    size_t authorityStart = (sepPos == std::string::npos) ? 0 : sepPos + schemeSep.length();

    // Skip any extra slashes that directly follow the separator.
    const size_t firstNonSlash = url.find_first_not_of('/', authorityStart);
    if (firstNonSlash != std::string::npos) {
        authorityStart = firstNonSlash;
    }

    const size_t questionMark = url.find('?', authorityStart);
    if (questionMark == std::string::npos) {
        return; // no query component, nothing to fix
    }

    // Insert only when no path slash appears before the query.
    const size_t firstSlash = url.find('/', authorityStart);
    if (firstSlash == std::string::npos || firstSlash > questionMark) {
        url.insert(questionMark, 1, '/');
    }
}
54 
ParseParams(napi_value * params,size_t paramsCount)55 void ConnectContext::ParseParams(napi_value *params, size_t paramsCount)
56 {
57     if (!CheckParamsType(params, paramsCount)) {
58         ParseCallback(params, paramsCount);
59         return;
60     }
61 
62     if (paramsCount == FUNCTION_PARAM_ONE) {
63         if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
64             url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
65             AddSlashBeforeQuery(url);
66             SetParseOK(true);
67         }
68         return;
69     }
70     if (paramsCount == FUNCTION_PARAM_TWO) {
71         if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
72             url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
73             AddSlashBeforeQuery(url);
74         }
75         if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_function) {
76             return SetParseOK(SetCallback(params[1]) == napi_ok);
77         }
78         if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object) {
79             ParseHeader(params[1]);
80             ParseCaPath(params[1]);
81             ParseClientCert(params[1]);
82             if (!ParseProxy(params[1]) || !ParseProtocol(params[1])) {
83                 return;
84             }
85             return SetParseOK(true);
86         }
87     }
88     if (paramsCount == FUNCTION_PARAM_THREE) {
89         ParseParamsCountThree(params);
90     }
91 }
92 
ParseCallback(napi_value const * params,size_t paramsCount)93 void ConnectContext::ParseCallback(napi_value const *params, size_t paramsCount)
94 {
95     if (paramsCount == FUNCTION_PARAM_ONE) {
96         if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_ONE - 1]) == napi_function) {
97             SetCallback(params[FUNCTION_PARAM_ONE - 1]);
98         }
99         return;
100     }
101 
102     if (paramsCount == FUNCTION_PARAM_TWO) {
103         if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_TWO - 1]) == napi_function) {
104             SetCallback(params[FUNCTION_PARAM_TWO - 1]);
105         }
106         return;
107     }
108 
109     if (paramsCount == FUNCTION_PARAM_THREE) {
110         if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
111             SetCallback(params[FUNCTION_PARAM_THREE - 1]);
112         }
113         return;
114     }
115 }
116 
ParseParamsCountThree(napi_value const * params)117 void ConnectContext::ParseParamsCountThree(napi_value const *params)
118 {
119     if (NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string) {
120         url = NapiUtils::GetStringFromValueUtf8(GetEnv(), params[0]);
121         AddSlashBeforeQuery(url);
122     }
123     if (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object) {
124         ParseHeader(params[1]);
125         ParseCaPath(params[1]);
126         ParseClientCert(params[1]);
127         if (!ParseProxy(params[1]) || !ParseProtocol(params[1])) {
128             if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
129                 SetCallback(params[FUNCTION_PARAM_THREE - 1]);
130                 return;
131             }
132             return;
133         }
134     }
135     if (NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function) {
136         return SetParseOK(SetCallback(params[FUNCTION_PARAM_THREE - 1]) == napi_ok);
137     }
138 }
139 
ParseHeader(napi_value optionsValue)140 void ConnectContext::ParseHeader(napi_value optionsValue)
141 {
142     if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::HEADER)) {
143         return;
144     }
145     napi_value jsHeader = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::HEADER);
146     if (NapiUtils::GetValueType(GetEnv(), jsHeader) != napi_object) {
147         return;
148     }
149     auto names = NapiUtils::GetPropertyNames(GetEnv(), jsHeader);
150     std::for_each(names.begin(), names.end(), [jsHeader, this](const std::string &name) {
151         auto value = NapiUtils::GetStringPropertyUtf8(GetEnv(), jsHeader, name);
152         if (!value.empty()) {
153             // header key ignores key but value not
154             header[CommonUtils::ToLower(name)] = value;
155         }
156     });
157 }
158 
ParseCaPath(napi_value optionsValue)159 void ConnectContext::ParseCaPath(napi_value optionsValue)
160 {
161     if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::CAPATH)) {
162         NETSTACK_LOGI("ConnectContext CAPATH not found");
163         return;
164     }
165     napi_value jsCaPath = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::CAPATH);
166     if (NapiUtils::GetValueType(GetEnv(), jsCaPath) != napi_string) {
167         return;
168     }
169     caPath_ = NapiUtils::GetStringPropertyUtf8(GetEnv(), optionsValue, ContextKey::CAPATH);
170 }
171 
GetClientCert(std::string & cert,Secure::SecureChar & key,Secure::SecureChar & keyPassword)172 void ConnectContext::GetClientCert(std::string &cert, Secure::SecureChar &key, Secure::SecureChar &keyPassword)
173 {
174     cert = clientCert_;
175     key = clientKey_;
176     keyPassword = keyPassword_;
177 }
178 
SetClientCert(std::string & cert,Secure::SecureChar & key,Secure::SecureChar & keyPassword)179 void ConnectContext::SetClientCert(std::string &cert, Secure::SecureChar &key, Secure::SecureChar &keyPassword)
180 {
181     clientCert_ = cert;
182     clientKey_ = key;
183     keyPassword_ = keyPassword;
184 }
185 
ParseClientCert(napi_value optionsValue)186 void ConnectContext::ParseClientCert(napi_value optionsValue)
187 {
188     if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::CLIENT_CERT)) {
189         NETSTACK_LOGI("ConnectContext CLIENT_CERT not found");
190         return;
191     }
192     napi_value jsCert = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::CLIENT_CERT);
193     napi_valuetype type = NapiUtils::GetValueType(GetEnv(), jsCert);
194     if (type != napi_object || type == napi_undefined) {
195         return;
196     }
197     std::string certPath = NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::CERT_PATH);
198     Secure::SecureChar keyPath =
199         Secure::SecureChar(NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::KEY_PATH));
200     Secure::SecureChar keyPassword =
201         Secure::SecureChar(NapiUtils::GetStringPropertyUtf8(GetEnv(), jsCert, ContextKey::KEY_PASSWD));
202     SetClientCert(certPath, keyPath, keyPassword);
203 }
204 
ParseProxy(napi_value optionsValue)205 bool ConnectContext::ParseProxy(napi_value optionsValue)
206 {
207     if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::PROXY)) {
208         SetWebsocketProxyType(WebsocketProxyType::USE_SYSTEM);
209         NETSTACK_LOGD("websocket connect proxy not found, use system proxy");
210         return true;
211     }
212     napi_value websocketProxyValue = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::PROXY);
213     napi_valuetype type = NapiUtils::GetValueType(GetEnv(), websocketProxyValue);
214     if (type == napi_string) {
215         std::string proxyStr = NapiUtils::GetStringFromValueUtf8(GetEnv(), websocketProxyValue);
216         if (proxyStr == ContextKey::NOT_USE_PROXY) {
217             SetWebsocketProxyType(WebsocketProxyType::NOT_USE);
218             return true;
219         } else if (proxyStr == ContextKey::USE_SYSTEM_PROXY) {
220             SetWebsocketProxyType(WebsocketProxyType::USE_SYSTEM);
221             return true;
222         } else {
223             NETSTACK_LOGE("websocket proxy param parse failed!");
224             return false;
225         }
226     }
227     if (type != napi_object) {
228         NETSTACK_LOGE("websocket proxy param parse failed!");
229         return false;
230     }
231 
232     std::string exclusionList;
233     std::string host =
234         NapiUtils::GetStringPropertyUtf8(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_HOST);
235     int32_t port = NapiUtils::GetInt32Property(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_PORT);
236     if (NapiUtils::HasNamedProperty(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_EXCLUSION_LIST)) {
237         napi_value exclusionListValue =
238             NapiUtils::GetNamedProperty(GetEnv(), websocketProxyValue, ContextKey::WEBSOCKET_PROXY_EXCLUSION_LIST);
239         uint32_t listLength = NapiUtils::GetArrayLength(GetEnv(), exclusionListValue);
240         for (uint32_t index = 0; index < listLength; ++index) {
241             napi_value exclusionValue = NapiUtils::GetArrayElement(GetEnv(), exclusionListValue, index);
242             std::string exclusion = NapiUtils::GetStringFromValueUtf8(GetEnv(), exclusionValue);
243             if (index != 0) {
244                 exclusionList.append(ContextKey::WEBSOCKET_PROXY_EXCLUSIONS_SEPARATOR);
245             }
246             exclusionList += exclusion;
247         }
248     }
249     SetSpecifiedWebsocketProxy(host, port, exclusionList);
250     SetWebsocketProxyType(WebsocketProxyType::USE_SPECIFIED);
251     return true;
252 }
253 
ParseProtocol(napi_value optionsValue)254 bool ConnectContext::ParseProtocol(napi_value optionsValue)
255 {
256     if (!NapiUtils::HasNamedProperty(GetEnv(), optionsValue, ContextKey::PROTCOL)) {
257         NETSTACK_LOGD("websocket connect protocol not found");
258         return true;
259     }
260     napi_value jsProtocol = NapiUtils::GetNamedProperty(GetEnv(), optionsValue, ContextKey::PROTCOL);
261     if (NapiUtils::GetValueType(GetEnv(), jsProtocol) == napi_string) {
262         SetProtocol(NapiUtils::GetStringPropertyUtf8(GetEnv(), optionsValue, ContextKey::PROTCOL));
263         return true;
264     }
265     NETSTACK_LOGE("websocket connect protocol param parse failed");
266     return false;
267 }
268 
CheckParamsType(napi_value * params,size_t paramsCount)269 bool ConnectContext::CheckParamsType(napi_value *params, size_t paramsCount)
270 {
271     if (paramsCount == FUNCTION_PARAM_ONE) {
272         return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string;
273     }
274     if (paramsCount == FUNCTION_PARAM_TWO) {
275         return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string &&
276                (NapiUtils::GetValueType(GetEnv(), params[1]) == napi_function ||
277                 NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object);
278     }
279     if (paramsCount == FUNCTION_PARAM_THREE) {
280         return NapiUtils::GetValueType(GetEnv(), params[0]) == napi_string &&
281                NapiUtils::GetValueType(GetEnv(), params[1]) == napi_object &&
282                NapiUtils::GetValueType(GetEnv(), params[FUNCTION_PARAM_THREE - 1]) == napi_function;
283     }
284     return false;
285 }
286 
SetProtocol(std::string protocol)287 void ConnectContext::SetProtocol(std::string protocol)
288 {
289     websocketProtocol_ = std::move(protocol);
290 }
291 
GetProtocol() const292 std::string ConnectContext::GetProtocol() const
293 {
294     return websocketProtocol_;
295 }
296 
SetWebsocketProxyType(WebsocketProxyType type)297 void ConnectContext::SetWebsocketProxyType(WebsocketProxyType type)
298 {
299     usingWebsocketProxyType_ = type;
300 }
301 
GetUsingWebsocketProxyType() const302 WebsocketProxyType ConnectContext::GetUsingWebsocketProxyType() const
303 {
304     return usingWebsocketProxyType_;
305 }
306 
SetSpecifiedWebsocketProxy(const std::string & host,int32_t port,const std::string & exclusionList)307 void ConnectContext::SetSpecifiedWebsocketProxy(const std::string &host, int32_t port, const std::string &exclusionList)
308 {
309     websocketProxyHost_ = host;
310     websocketProxyPort_ = port;
311     websocketProxyExclusions_ = exclusionList;
312 }
313 
GetSpecifiedWebsocketProxy(std::string & host,uint32_t & port,std::string & exclusionList) const314 void ConnectContext::GetSpecifiedWebsocketProxy(std::string &host, uint32_t &port, std::string &exclusionList) const
315 {
316     host = websocketProxyHost_;
317     port = websocketProxyPort_;
318     exclusionList = websocketProxyExclusions_;
319 }
320 
GetErrorCode() const321 int32_t ConnectContext::GetErrorCode() const
322 {
323     if (BaseContext::IsPermissionDenied()) {
324         return PERMISSION_DENIED_CODE;
325     }
326     if (BaseContext::IsNoAllowedHost()) {
327         return WEBSOCKET_NOT_ALLOWED_HOST;
328     }
329 
330     auto err = BaseContext::GetErrorCode();
331     if (err == PARSE_ERROR_CODE) {
332         return PARSE_ERROR_CODE;
333     }
334     if (WEBSOCKET_ERR_MAP.find(err) != WEBSOCKET_ERR_MAP.end()) {
335         return err;
336     }
337     return WEBSOCKET_UNKNOWN_OTHER_ERROR;
338 }
339 
GetErrorMessage() const340 std::string ConnectContext::GetErrorMessage() const
341 {
342     if (BaseContext::IsPermissionDenied()) {
343         return PERMISSION_DENIED_MSG;
344     }
345     if (BaseContext::IsNoAllowedHost()) {
346         return WEBSOCKET_ERR_MAP.at(WEBSOCKET_NOT_ALLOWED_HOST);
347     }
348 
349     auto err = BaseContext::GetErrorCode();
350     if (err == PARSE_ERROR_CODE) {
351         return PARSE_ERROR_MSG;
352     }
353     auto it = WEBSOCKET_ERR_MAP.find(err);
354     if (it != WEBSOCKET_ERR_MAP.end()) {
355         return it->second;
356     }
357     it = WEBSOCKET_ERR_MAP.find(WEBSOCKET_UNKNOWN_OTHER_ERROR);
358     if (it != WEBSOCKET_ERR_MAP.end()) {
359         return it->second;
360     }
361     return {};
362 }
363 
IsAtomicService() const364 bool ConnectContext::IsAtomicService() const
365 {
366     return isAtomicService_;
367 }
368 
SetAtomicService(bool isAtomicService)369 void ConnectContext::SetAtomicService(bool isAtomicService)
370 {
371     isAtomicService_ = isAtomicService;
372 }
373 
SetBundleName(const std::string & bundleName)374 void ConnectContext::SetBundleName(const std::string &bundleName)
375 {
376     bundleName_ = bundleName;
377 }
378 
GetBundleName() const379 std::string ConnectContext::GetBundleName() const
380 {
381     return bundleName_;
382 }
383 } // namespace OHOS::NetStack::Websocket
384