/* * Copyright (c) 2023 Huawei Device Co., Ltd. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include "netstack_log.h" #include "websocket_client_innerapi.h" static constexpr const char *PATH_START = "/"; static constexpr const char *NAME_END = ":"; static constexpr const char *STATUS_LINE_SEP = " "; static constexpr const size_t STATUS_LINE_ELEM_NUM = 2; static constexpr const char *PREFIX_HTTPS = "https"; static constexpr const char *PREFIX_WSS = "wss"; static constexpr const int MAX_URI_LENGTH = 1024; static constexpr const int MAX_HDR_LENGTH = 1024; static constexpr const int MAX_HEADER_LENGTH = 8192; static constexpr const size_t MAX_DATA_LENGTH = 4 * 1024 * 1024; static constexpr const int FD_LIMIT_PER_THREAD = 1 + 1 + 1; static constexpr const int CLOSE_RESULT_FROM_SERVER_CODE = 1001; static constexpr const int CLOSE_RESULT_FROM_CLIENT_CODE = 1000; static constexpr const char *LINK_DOWN = "The link is down"; static constexpr const char *CLOSE_REASON_FORM_SERVER = "websocket close from server"; static constexpr const int FUNCTION_PARAM_TWO = 2; static constexpr const char *WEBSOCKET_CLIENT_THREAD_RUN = "OS_NET_WSCli"; static std::atomic g_clientID(0); namespace OHOS::NetStack::WebSocketClient { static const lws_retry_bo_t RETRY = { .secs_since_valid_ping = 0, /* force PINGs after secs idle */ .secs_since_valid_hangup = 10, /* hangup after secs idle */ .jitter_percent = 20, }; WebSocketClient::WebSocketClient() { clientContext = new ClientContext(); clientContext->SetClientId(++g_clientID); } WebSocketClient::~WebSocketClient() { delete clientContext; clientContext = nullptr; } ClientContext *WebSocketClient::GetClientContext() const { return clientContext; } void RunService(WebSocketClient *Client) { if (Client->GetClientContext()->GetContext() == nullptr) { return; } while (!Client->GetClientContext()->IsThreadStop()) { lws_service(Client->GetClientContext()->GetContext(), 0); } } int HttpDummy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { int ret = lws_callback_http_dummy(wsi, reason, user, in, len); return ret; } struct CallbackDispatcher { lws_callback_reasons reason; int (*callback)(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len); }; int LwsCallbackClientAppendHandshakeHeader(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGD("ClientId:%{public}d, Lws Callback AppendHandshakeHeader,", client->GetClientContext()->GetClientId()); auto payload = reinterpret_cast(in); if (payload == nullptr || (*payload) == nullptr || len == 0) { return -1; } auto payloadEnd = (*payload) + len; for (const auto &pair : client->GetClientContext()->header) { std::string name = pair.first + NAME_END; if (lws_add_http_header_by_name(wsi, reinterpret_cast(name.c_str()), reinterpret_cast(pair.second.c_str()), static_cast(strlen(pair.second.c_str())), payload, payloadEnd)) { return -1; } } return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackWsPeerInitiatedClose(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Lws Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGD("ClientId:%{public}d,Callback WsPeerInitiatedClose", client->GetClientContext()->GetClientId()); if (in == nullptr || len < sizeof(uint16_t)) { NETSTACK_LOGE("Lws Callback WsPeerInitiatedClose"); client->GetClientContext()->Close(LWS_CLOSE_STATUS_NORMAL, ""); return HttpDummy(wsi, reason, user, in, len); } uint16_t closeStatus = ntohs(*reinterpret_cast(in)); std::string closeReason; closeReason.append(reinterpret_cast(in) + sizeof(uint16_t), len - sizeof(uint16_t)); client->GetClientContext()->Close(static_cast(closeStatus), closeReason); return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackClientWritable(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Lws Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGD("ClientId:%{public}d,Callback CallbackClientWritable,", client->GetClientContext()->GetClientId()); if (client->GetClientContext()->IsClosed()) { NETSTACK_LOGD("ClientId:%{public}d,Callback ClientWritable need to close", client->GetClientContext()->GetClientId()); lws_close_reason( wsi, client->GetClientContext()->closeStatus, reinterpret_cast(const_cast(client->GetClientContext()->closeReason.c_str())), strlen(client->GetClientContext()->closeReason.c_str())); // here do not emit error, because we close it return -1; } SendData sendData = client->GetClientContext()->Pop(); if (sendData.data == nullptr || sendData.length == 0) { return HttpDummy(wsi, reason, user, in, len); } const char *message = sendData.data; size_t messageLen = sendData.length; auto buffer = std::make_unique(LWS_PRE + messageLen); if (buffer == nullptr) { return -1; } int result = memcpy_s(buffer.get() + LWS_PRE, LWS_PRE + messageLen, message, messageLen); if (result != 0) { return -1; } free(sendData.data); int bytesSent = lws_write(wsi, buffer.get() + LWS_PRE, messageLen, sendData.protocol); NETSTACK_LOGD("ClientId:%{public}d,Client Writable send data length = %{public}d", client->GetClientContext()->GetClientId(), bytesSent); return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackClientConnectionError(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); NETSTACK_LOGE("ClientId:%{public}d,Callback ClientConnectionError", client->GetClientContext()->GetClientId()); std::string buf; char *data = static_cast(in); buf.assign(data, len); ErrorResult errorResult; errorResult.errorCode = WebSocketErrorCode::WEBSOCKET_CONNECTION_ERROR; errorResult.errorMessage = data; client->onErrorCallback_(client, errorResult); return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackClientReceive(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); NETSTACK_LOGD("ClientId:%{public}d,Callback ClientReceive", client->GetClientContext()->GetClientId()); std::string buf; char *data = static_cast(in); buf.assign(data, len); client->onMessageCallback_(client, data, len); return HttpDummy(wsi, reason, user, in, len); } std::vector Split(const std::string &str, const std::string &sep, size_t size) { std::string s = str; std::vector res; while (!s.empty()) { if (res.size() + 1 == size) { res.emplace_back(s); break; } auto pos = s.find(sep); if (pos == std::string::npos) { res.emplace_back(s); break; } res.emplace_back(s.substr(0, pos)); s = s.substr(pos + sep.size()); } return res; } int LwsCallbackClientFilterPreEstablish(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Callback ClientContext is nullptr"); return -1; } client->GetClientContext()->openStatus = lws_http_client_http_response(wsi); NETSTACK_LOGD("ClientId:%{public}d, libwebsockets Callback ClientFilterPreEstablish openStatus = %{public}d", client->GetClientContext()->GetClientId(), client->GetClientContext()->openStatus); char statusLine[MAX_HDR_LENGTH] = {0}; if (lws_hdr_copy(wsi, statusLine, MAX_HDR_LENGTH, WSI_TOKEN_HTTP) < 0 || strlen(statusLine) == 0) { return HttpDummy(wsi, reason, user, in, len); } auto vec = Split(statusLine, STATUS_LINE_SEP, STATUS_LINE_ELEM_NUM); if (vec.size() >= FUNCTION_PARAM_TWO) { client->GetClientContext()->openMessage = vec[1]; } return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackClientEstablished(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("libwebsockets Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGI("ClientId:%{public}d,Callback ClientEstablished", client->GetClientContext()->GetClientId()); OpenResult openResult; openResult.status = client->GetClientContext()->openStatus; openResult.message = client->GetClientContext()->openMessage.c_str(); client->onOpenCallback_(client, openResult); return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackClientClosed(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGI("ClientId:%{public}d,Callback ClientClosed", client->GetClientContext()->GetClientId()); std::string buf; char *data = static_cast(in); buf.assign(data, len); CloseResult closeResult; closeResult.code = CLOSE_RESULT_FROM_SERVER_CODE; closeResult.reason = CLOSE_REASON_FORM_SERVER; client->onCloseCallback_(client, closeResult); client->GetClientContext()->SetThreadStop(true); if ((client->GetClientContext()->closeReason).empty()) { client->GetClientContext()->Close(client->GetClientContext()->closeStatus, LINK_DOWN); } return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackWsiDestroy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { WebSocketClient *client = static_cast(user); if (client->GetClientContext() == nullptr) { NETSTACK_LOGE("Callback ClientContext is nullptr"); return -1; } NETSTACK_LOGI("Lws Callback LwsCallbackWsiDestroy"); return HttpDummy(wsi, reason, user, in, len); } int LwsCallbackProtocolDestroy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { NETSTACK_LOGI("Lws Callback ProtocolDestroy"); return HttpDummy(wsi, reason, user, in, len); } int LwsCallback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { constexpr CallbackDispatcher dispatchers[] = { {LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER, LwsCallbackClientAppendHandshakeHeader}, {LWS_CALLBACK_WS_PEER_INITIATED_CLOSE, LwsCallbackWsPeerInitiatedClose}, {LWS_CALLBACK_CLIENT_WRITEABLE, LwsCallbackClientWritable}, {LWS_CALLBACK_CLIENT_CONNECTION_ERROR, LwsCallbackClientConnectionError}, {LWS_CALLBACK_CLIENT_RECEIVE, LwsCallbackClientReceive}, {LWS_CALLBACK_CLIENT_FILTER_PRE_ESTABLISH, LwsCallbackClientFilterPreEstablish}, {LWS_CALLBACK_CLIENT_ESTABLISHED, LwsCallbackClientEstablished}, {LWS_CALLBACK_CLIENT_CLOSED, LwsCallbackClientClosed}, {LWS_CALLBACK_WSI_DESTROY, LwsCallbackWsiDestroy}, {LWS_CALLBACK_PROTOCOL_DESTROY, LwsCallbackProtocolDestroy}, }; auto it = std::find_if(std::begin(dispatchers), std::end(dispatchers), [&reason](const CallbackDispatcher &dispatcher) { return dispatcher.reason == reason; }); if (it != std::end(dispatchers)) { return it->callback(wsi, reason, user, in, len); } return HttpDummy(wsi, reason, user, in, len); } static struct lws_protocols protocols[] = {{"lws-minimal-client1", LwsCallback, 0, 0, 0, NULL, 0}, LWS_PROTOCOL_LIST_TERM}; static void FillContextInfo(lws_context_creation_info &info) { info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT; info.port = CONTEXT_PORT_NO_LISTEN; info.protocols = protocols; info.fd_limit_per_thread = FD_LIMIT_PER_THREAD; } bool ParseUrl(const std::string url, char *prefix, char *address, char *path, int *port) { char uri[MAX_URI_LENGTH] = {0}; if (strcpy_s(uri, MAX_URI_LENGTH, url.c_str()) < 0) { NETSTACK_LOGE("strcpy_s failed"); return false; } const char *tempPrefix = nullptr; const char *tempAddress = nullptr; const char *tempPath = nullptr; (void)lws_parse_uri(uri, &tempPrefix, &tempAddress, port, &tempPath); if (strcpy_s(prefix, MAX_URI_LENGTH, tempPrefix) < 0) { NETSTACK_LOGE("strcpy_s failed"); return false; } if (strcpy_s(address, MAX_URI_LENGTH, tempAddress) < 0) { NETSTACK_LOGE("strcpy_s failed"); return false; } if (strcpy_s(path, MAX_URI_LENGTH, tempPath) < 0) { NETSTACK_LOGE("strcpy_s failed"); return false; } return true; } int CreatConnectInfo(const std::string url, lws_context *lwsContext, WebSocketClient *client) { lws_client_connect_info connectInfo = {}; char prefix[MAX_URI_LENGTH] = {0}; char address[MAX_URI_LENGTH] = {0}; char pathWithoutStart[MAX_URI_LENGTH] = {0}; int port = 0; if (!ParseUrl(url, prefix, address, pathWithoutStart, &port)) { return WebSocketErrorCode::WEBSOCKET_CONNECTION_PARSEURL_ERROR; } std::string path = PATH_START + std::string(pathWithoutStart); connectInfo.context = lwsContext; connectInfo.address = address; connectInfo.port = port; connectInfo.path = path.c_str(); connectInfo.host = address; connectInfo.origin = address; connectInfo.local_protocol_name = "lws-minimal-client1"; connectInfo.retry_and_idle_policy = &RETRY; if (strcmp(prefix, PREFIX_HTTPS) == 0 || strcmp(prefix, PREFIX_WSS) == 0) { connectInfo.ssl_connection = LCCSCF_USE_SSL | LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK | LCCSCF_ALLOW_INSECURE | LCCSCF_ALLOW_SELFSIGNED; } lws *wsi = nullptr; connectInfo.pwsi = &wsi; connectInfo.userdata = client; if (lws_client_connect_via_info(&connectInfo) == nullptr) { NETSTACK_LOGE("Connect lws_context_destroy"); return WebSocketErrorCode::WEBSOCKET_CONNECTION_TO_SERVER_FAIL; } return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } int WebSocketClient::Connect(std::string url, struct OpenOptions options) { NETSTACK_LOGI("ClientId:%{public}d, Connect start", this->GetClientContext()->GetClientId()); if (!options.headers.empty()) { if (options.headers.size() > MAX_HEADER_LENGTH) { return WebSocketErrorCode::WEBSOCKET_ERROR_NO_HEADR_EXCEEDS; } for (const auto &item : options.headers) { const std::string &key = item.first; const std::string &value = item.second; this->GetClientContext()->header[key] = value; } } lws_context_creation_info info = {}; FillContextInfo(info); lws_context *lwsContext = lws_create_context(&info); if (lwsContext == nullptr) { return WebSocketErrorCode::WEBSOCKET_CONNECTION_NO_MEMOERY; } this->GetClientContext()->SetContext(lwsContext); int ret = CreatConnectInfo(url, lwsContext, this); if (ret != WEBSOCKET_NONE_ERR) { NETSTACK_LOGE("websocket CreatConnectInfo error"); GetClientContext()->SetContext(nullptr); lws_context_destroy(lwsContext); return ret; } std::thread serviceThread(RunService, this); #if defined(MAC_PLATFORM) || defined(IOS_PLATFORM) pthread_setname_np(WEBSOCKET_CLIENT_THREAD_RUN); #else pthread_setname_np(serviceThread.native_handle(), WEBSOCKET_CLIENT_THREAD_RUN); #endif serviceThread.detach(); return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } int WebSocketClient::Send(char *data, size_t length) { if (data == nullptr) { return WebSocketErrorCode::WEBSOCKET_SEND_DATA_NULL; } if (length == 0) { return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } if (length > MAX_DATA_LENGTH) { return WebSocketErrorCode::WEBSOCKET_DATA_LENGTH_EXCEEDS; } if (this->GetClientContext() == nullptr) { return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX; } lws_write_protocol protocol = (strlen(data) == length) ? LWS_WRITE_TEXT : LWS_WRITE_BINARY; auto dataCopy = reinterpret_cast(malloc(length)); if (dataCopy == nullptr) { NETSTACK_LOGE("webSocketClient malloc error"); return WEBSOCKET_SEND_NO_MEMOERY_ERROR; } else if (memcpy_s(dataCopy, length, data, length) != EOK) { free(dataCopy); NETSTACK_LOGE("webSocketClient malloc copy error"); return WEBSOCKET_SEND_NO_MEMOERY_ERROR; } this->GetClientContext()->Push(dataCopy, length, protocol); return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } int WebSocketClient::Close(CloseOption options) { NETSTACK_LOGI("Close start"); if (this->GetClientContext() == nullptr) { return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX; } if (this->GetClientContext()->openStatus == 0) return WebSocketErrorCode::WEBSOCKET_ERROR_HAVE_NO_CONNECT; if (options.reason == nullptr || options.code == 0) { options.reason = ""; options.code = CLOSE_RESULT_FROM_CLIENT_CODE; } this->GetClientContext()->Close(static_cast(options.code), options.reason); return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } int WebSocketClient::Registcallback(OnOpenCallback onOpen, OnMessageCallback onMessage, OnErrorCallback onError, OnCloseCallback onClose) { onMessageCallback_ = onMessage; onCloseCallback_ = onClose; onErrorCallback_ = onError; onOpenCallback_ = onOpen; return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } int WebSocketClient::Destroy() { NETSTACK_LOGI("Destroy start"); if (this->GetClientContext()->GetContext() == nullptr) { return WebSocketErrorCode::WEBSOCKET_ERROR_HAVE_NO_CONNECT_CONTEXT; } this->GetClientContext()->SetContext(nullptr); lws_context_destroy(this->GetClientContext()->GetContext()); return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } } // namespace OHOS::NetStack::WebSocketClient