1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "rs_profiler_socket.h"
17 
18 #include <fcntl.h>
19 #include <netinet/in.h>
20 #include <netinet/tcp.h>
21 #include <securec.h>
22 #include <sys/select.h>
23 #include <sys/socket.h>
24 #include <sys/un.h>
25 #include <unistd.h>
26 
27 #include "rs_profiler_utils.h"
28 
29 namespace OHOS::Rosen {
30 
// Converts a duration in milliseconds into a timeval (whole seconds + remaining microseconds).
static timeval GetTimeoutDesc(uint32_t milliseconds)
{
    // 1000 is both "milliseconds per second" and "microseconds per millisecond".
    constexpr uint32_t scale = 1000u;

    const uint32_t wholeSeconds = milliseconds / scale;
    const uint32_t leftoverMilliseconds = milliseconds % scale;

    timeval result {};
    result.tv_sec = wholeSeconds;
    result.tv_usec = leftoverMilliseconds * scale;
    return result;
}
40 
// Reads the socket's current send timeout (SO_SNDTIMEO).
// The getsockopt() result is not checked: on failure the zero-initialised timeval is returned.
static timeval GetTimeout(int32_t socket)
{
    timeval current {};
    socklen_t length = sizeof(current);
    getsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<char*>(&current), &length);
    return current;
}
48 
// Applies `timeout` as the socket's send timeout (SO_SNDTIMEO); the setsockopt() result is ignored.
static void SetTimeout(int32_t socket, const timeval& timeout)
{
    const char* const optionValue = reinterpret_cast<const char*>(&timeout);
    const socklen_t optionLength = sizeof(timeout);
    setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, optionValue, optionLength);
}
53 
SetTimeout(int32_t socket,uint32_t milliseconds)54 static void SetTimeout(int32_t socket, uint32_t milliseconds)
55 {
56     SetTimeout(socket, GetTimeoutDesc(milliseconds));
57 }
58 
// Returns `flags` with the bits of `flag` set (enable == true) or cleared (enable == false).
static int32_t ToggleFlag(uint32_t flags, uint32_t flag, bool enable)
{
    uint32_t result = flags & ~flag;
    if (enable) {
        result |= flag;
    }
    return result;
}
63 
SetBlocking(int32_t socket,bool enable)64 static void SetBlocking(int32_t socket, bool enable)
65 {
66     fcntl(socket, F_SETFL, ToggleFlag(fcntl(socket, F_GETFL, 0), O_NONBLOCK, !enable));
67 }
68 
SetCloseOnExec(int32_t socket,bool enable)69 static void SetCloseOnExec(int32_t socket, bool enable)
70 {
71     fcntl(socket, F_SETFD, ToggleFlag(fcntl(socket, F_GETFD, 0), FD_CLOEXEC, enable));
72 }
73 
// Builds an fd_set whose only member is `socket`.
// The caller must ensure 0 <= socket < FD_SETSIZE.
static fd_set GetFdSet(int32_t socket)
{
    fd_set single;
    FD_ZERO(&single);
    FD_SET(socket, &single);
    return single;
}
81 
// Reports whether `socket` is a member of `set`.
static bool IsFdSet(int32_t socket, const fd_set& set)
{
    const bool member = FD_ISSET(socket, &set);
    return member;
}
86 
~Socket()87 Socket::~Socket()
88 {
89     Shutdown();
90 }
91 
GetState() const92 SocketState Socket::GetState() const
93 {
94     return state_;
95 }
96 
SetState(SocketState state)97 void Socket::SetState(SocketState state)
98 {
99     state_ = state;
100 }
101 
Shutdown()102 void Socket::Shutdown()
103 {
104     shutdown(socket_, SHUT_RDWR);
105     close(socket_);
106     socket_ = -1;
107 
108     shutdown(client_, SHUT_RDWR);
109     close(client_);
110     client_ = -1;
111 
112     state_ = SocketState::SHUTDOWN;
113 }
114 
Open(uint16_t port)115 void Socket::Open(uint16_t port)
116 {
117     socket_ = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
118     if (socket_ == -1) {
119         Shutdown();
120         return;
121     }
122 
123     const std::string socketName = "render_service_" + std::to_string(port);
124     sockaddr_un address {};
125     address.sun_family = AF_UNIX;
126     address.sun_path[0] = 0;
127     ::memmove_s(address.sun_path + 1, sizeof(address.sun_path) - 1, socketName.data(), socketName.size());
128 
129     const size_t addressSize = offsetof(sockaddr_un, sun_path) + socketName.size() + 1;
130     if (bind(socket_, reinterpret_cast<sockaddr*>(&address), addressSize) == -1) {
131         Shutdown();
132         return;
133     }
134 
135     const int32_t maxConnections = 5;
136     if (listen(socket_, maxConnections) != 0) {
137         Shutdown();
138         return;
139     }
140 
141     SetBlocking(socket_, false);
142     SetCloseOnExec(socket_, true);
143 
144     state_ = SocketState::CREATE;
145 }
146 
AcceptClient()147 void Socket::AcceptClient()
148 {
149     client_ = accept4(socket_, nullptr, nullptr, SOCK_CLOEXEC);
150     if (client_ == -1) {
151         if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
152             Shutdown();
153         }
154     } else {
155         SetBlocking(client_, false);
156         SetCloseOnExec(client_, true);
157 
158         int32_t nodelay = 1;
159         setsockopt(client_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&nodelay), sizeof(nodelay));
160 
161         state_ = SocketState::ACCEPT;
162     }
163 }
164 
SendWhenReady(const void * data,size_t size)165 void Socket::SendWhenReady(const void* data, size_t size)
166 {
167     if (!data || (size == 0)) {
168         return;
169     }
170 
171     SetBlocking(client_, true);
172 
173     const timeval previousTimeout = GetTimeout(client_);
174 
175     const uint32_t timeoutMilliseconds = 40;
176     SetTimeout(client_, timeoutMilliseconds);
177 
178     const char* bytes = reinterpret_cast<const char*>(data);
179     size_t sent = 0;
180     while (sent < size) {
181         const ssize_t sentBytes = send(client_, bytes, size - sent, 0);
182         if ((sentBytes <= 0) && (errno != EINTR)) {
183             Shutdown();
184             return;
185         }
186         auto actualSentBytes = static_cast<size_t>(sentBytes);
187         sent += actualSentBytes;
188         bytes += actualSentBytes;
189     }
190 
191     SetTimeout(client_, previousTimeout);
192     SetBlocking(client_, false);
193 }
194 
Receive(void * data,size_t & size)195 bool Socket::Receive(void* data, size_t& size)
196 {
197     if (!data || (size == 0)) {
198         return true;
199     }
200 
201     SetBlocking(client_, false);
202 
203     const ssize_t receivedBytes = recv(client_, static_cast<char*>(data), size, 0);
204     if (receivedBytes > 0) {
205         size = static_cast<size_t>(receivedBytes);
206     } else {
207         size = 0;
208         if ((errno == EWOULDBLOCK) || (errno == EAGAIN) || (errno == EINTR)) {
209             return true;
210         }
211         Shutdown();
212         return false;
213     }
214     return true;
215 }
216 
ReceiveWhenReady(void * data,size_t size)217 bool Socket::ReceiveWhenReady(void* data, size_t size)
218 {
219     if (!data || (size == 0)) {
220         return true;
221     }
222 
223     const timeval previousTimeout = GetTimeout(client_);
224     const uint32_t bandwitdth = 10000; // KB/ms
225     const uint32_t timeoutPad = 100;
226     const uint32_t timeout = size / bandwitdth + timeoutPad;
227 
228     SetBlocking(client_, true);
229     SetTimeout(client_, timeout);
230 
231     size_t received = 0;
232     char* bytes = static_cast<char*>(data);
233     while (received < size) {
234         // receivedBytes can only be -1 or [0, size - received] (from recv man)
235         const ssize_t receivedBytes = recv(client_, bytes, size - received, 0);
236         if ((receivedBytes == -1) && (errno != EINTR)) {
237             Shutdown();
238             return false;
239         }
240 
241         // so receivedBytes here always [0, size - received]
242         // then received can't be > `size` and it can't be overflowed
243         auto actualReceivedBytes = static_cast<size_t>(receivedBytes);
244         received += actualReceivedBytes;
245         bytes += actualReceivedBytes;
246     }
247 
248     SetTimeout(client_, previousTimeout);
249     SetBlocking(client_, false);
250     return true;
251 }
252 
GetStatus(bool & readyToReceive,bool & readyToSend) const253 void Socket::GetStatus(bool& readyToReceive, bool& readyToSend) const
254 {
255     readyToReceive = false;
256     readyToSend = false;
257 
258     if (client_ == -1) {
259         return;
260     }
261 
262     fd_set send = GetFdSet(client_);
263     fd_set receive = GetFdSet(client_);
264 
265     constexpr uint32_t timeoutMilliseconds = 10;
266     timeval timeout = GetTimeoutDesc(timeoutMilliseconds);
267     if (select(client_ + 1, &receive, &send, nullptr, &timeout) == 0) {
268         return;
269     }
270 
271     readyToReceive = IsFdSet(client_, receive);
272     readyToSend = IsFdSet(client_, send);
273 }
274 
275 } // namespace OHOS::Rosen