1 /*
2  * Copyright (C) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "nstackx_socket.h"
17 #include "nstackx_log.h"
18 #include "nstackx_error.h"
19 #include "nstackx_util.h"
20 #include "nstackx_dev.h"
21 #include "locale.h"
22 #include "securec.h"
23 
24 #define NSTACKX_MAX_LISTEN_NUMBER 3
25 #define NSTACKX_TCP_SOCKET_BUFFER_SIZE (1 * 1024 * 1024)
26 
27 #define TAG "nStackXSocket"
28 
/*
 * Release a Socket wrapper: close its descriptor via CloseSocketInner,
 * mark the descriptor invalid and free the heap object.
 * Passing NULL is a no-op.
 */
void CloseSocket(Socket *socket)
{
    if (socket != NULL) {
        CloseSocketInner(socket->sockfd);
        socket->sockfd = INVALID_SOCKET;
        free(socket);
    }
}
38 
/*
 * Debug helper: read SO_SNDBUF and then SO_RCVBUF of fd and log each value.
 * Stops (after logging an error) at the first getsockopt failure.
 */
static void GetTcpSocketBufSize(SocketDesc fd)
{
    int32_t size;
    socklen_t sizeLen = sizeof(size);

    if (getsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, &sizeLen) < 0) {
        LOGE(TAG, "getsockopt SO_SNDBUF failed");
        return;
    }
    LOGD(TAG, "SO_SNDBUF = %d", size);

    if (getsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, &sizeLen) < 0) {
        LOGE(TAG, "getsockopt SO_RCVBUF failed");
        return;
    }
    LOGD(TAG, "SO_RCVBUF = %d", size);
}
59 
/*
 * Set both SO_SNDBUF and SO_RCVBUF of fd to bufSize, logging the values
 * before and after the change.
 * Returns NSTACKX_EOK on success, NSTACKX_EFAILED for a negative bufSize or
 * when either setsockopt call fails (SO_RCVBUF is not attempted if SO_SNDBUF fails).
 */
static int32_t SetTcpSocketBufSize(SocketDesc fd, int32_t bufSize)
{
    const socklen_t sizeLen = sizeof(bufSize);

    if (bufSize < 0) {
        return NSTACKX_EFAILED;
    }

    /* Log the current sizes so the before/after pair appears in the debug log. */
    GetTcpSocketBufSize(fd);

    if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeLen) < 0) {
        LOGE(TAG, "setsockopt SO_SNDBUF failed");
        return NSTACKX_EFAILED;
    }
    LOGD(TAG, "setsockopt SO_SNDBUF = %d", bufSize);

    if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeLen) < 0) {
        LOGE(TAG, "setsockopt SO_RCVBUF failed");
        return NSTACKX_EFAILED;
    }
    LOGD(TAG, "setsockopt SO_RCVBUF = %d", bufSize);

    GetTcpSocketBufSize(fd);
    return NSTACKX_EOK;
}
86 
/*
 * Create a TCP socket in clientSocket and start connecting it to sockAddr.
 *
 * Buffer sizing, non-blocking mode and interface binding are best-effort:
 * their failures are logged but do not abort the connect. When
 * localInterface is NULL the socket is bound via BindToDevInTheSameLan,
 * otherwise via BindToTargetDev(localInterface).
 *
 * Because the socket is (normally) non-blocking, connect() may return with
 * the operation still pending; that case is treated as success here and
 * completion of the handshake is NOT verified by this function.
 *
 * Returns NSTACKX_EOK (clientSocket->dstAddr set to *sockAddr) or
 * NSTACKX_EFAILED (clientSocket->sockfd is INVALID_SOCKET).
 */
static int32_t ConnectTcpServerWithTargetDev(Socket *clientSocket, const struct sockaddr_in *sockAddr,
                                             const char *localInterface)
{
    socklen_t addrLen = sizeof(struct sockaddr_in);

    clientSocket->sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (clientSocket->sockfd == INVALID_SOCKET) {
        LOGE(TAG, "socket create failed, error :%d", GetErrno());
        return NSTACKX_EFAILED;
    }
    /* Buffer sizes must be set before connect() to take effect on the connection (see tcp(7)). */
    if (SetTcpSocketBufSize(clientSocket->sockfd, NSTACKX_TCP_SOCKET_BUFFER_SIZE) != NSTACKX_EOK) {
        LOGE(TAG, "set socket buf failed");
    }
    if (SetSocketNonBlock(clientSocket->sockfd) != NSTACKX_EOK) {
        LOGE(TAG, "set socket nonblock failed");
    }
    if (localInterface == NULL) {
        BindToDevInTheSameLan(clientSocket->sockfd, sockAddr);
    } else {
        LOGI(TAG, "bind to target interface %s", localInterface);
        if (BindToTargetDev(clientSocket->sockfd, localInterface) != NSTACKX_EOK) {
            LOGE(TAG, "can't bind to target interface %s", localInterface);
        } else {
            LOGI(TAG, "bind to target interface %s successfully", localInterface);
        }
    }
    int32_t ret = connect(clientSocket->sockfd, (struct sockaddr *)sockAddr, addrLen);
    if (ret != 0) {
        if (!SocketOpInProgress()) {
            LOGE(TAG, "connect error, %d", GetErrno());
            goto FAIL_SOCKET;
        }
        /* Non-blocking connect still pending: do not report it as an established connection. */
        LOGI(TAG, "connect in progress");
    } else {
        LOGI(TAG, "connect success");
    }

    clientSocket->dstAddr = *sockAddr;
    return NSTACKX_EOK;

FAIL_SOCKET:
    CloseSocketInner(clientSocket->sockfd);
    clientSocket->sockfd = INVALID_SOCKET;
    return NSTACKX_EFAILED;
}
130 
/*
 * Create a non-blocking UDP socket in clientSocket and connect() it to
 * sockAddr, so the kernel fixes the peer and picks the local address.
 *
 * Interface binding is best-effort (failures only logged). When
 * localInterface is NULL the socket is bound via BindToDevInTheSameLan,
 * otherwise via BindToTargetDev(localInterface).
 *
 * Returns NSTACKX_EOK with clientSocket->dstAddr = *sockAddr and
 * clientSocket->srcAddr = the local address reported by getsockname();
 * returns NSTACKX_EFAILED with clientSocket->sockfd = INVALID_SOCKET.
 */
static int32_t ConnectUdpServerWithTargetDev(Socket *clientSocket, const struct sockaddr_in *sockAddr,
                                             const char *localInterface)
{
    int32_t ret;
    struct sockaddr_in tmpAddr;
    socklen_t srcAddrLen = sizeof(struct sockaddr_in);
    clientSocket->sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (clientSocket->sockfd == INVALID_SOCKET) {
        LOGE(TAG, "socket create failed, error :%d", GetErrno());
        return NSTACKX_EFAILED;
    }
    if (SetSocketNonBlock(clientSocket->sockfd) != NSTACKX_EOK) {
        LOGE(TAG, "set socket nonblock failed");
        goto FAIL_SOCKET;
    }

    if (localInterface == NULL) {
        BindToDevInTheSameLan(clientSocket->sockfd, sockAddr);
    } else {
        if (BindToTargetDev(clientSocket->sockfd, localInterface) != NSTACKX_EOK) {
            LOGE(TAG, "can't bind to target interface %s", localInterface);
        } else {
            LOGI(TAG, "bind to target interface %s successfully", localInterface);
        }
    }
    /* Pass the size of the structure actually supplied (sockaddr_in), not of the generic sockaddr. */
    ret = connect(clientSocket->sockfd, (struct sockaddr *)sockAddr, (socklen_t)sizeof(struct sockaddr_in));
    if (ret != 0) {
        LOGE(TAG, "connect to udp server failed %d", GetErrno());
        goto FAIL_SOCKET;
    }

    (void)memset_s(&tmpAddr, sizeof(tmpAddr), 0, sizeof(tmpAddr));
    ret = getsockname(clientSocket->sockfd, (struct sockaddr *)&tmpAddr, &srcAddrLen);
    if (ret != 0) {
        LOGE(TAG, "getsockname failed %d", GetErrno());
        goto FAIL_SOCKET;
    }
    clientSocket->dstAddr = *sockAddr;
    clientSocket->srcAddr = tmpAddr;
    return NSTACKX_EOK;
FAIL_SOCKET:
    CloseSocketInner(clientSocket->sockfd);
    clientSocket->sockfd = INVALID_SOCKET;
    return NSTACKX_EFAILED;
}
176 
/*
 * Create a non-blocking, listening TCP socket in serverSocket.
 *
 * The socket is bound to sockAddr's port and, when sockAddr carries a
 * non-zero address, to that address (otherwise INADDR_ANY); a zero port lets
 * the kernel choose one. The actually bound address is stored in
 * serverSocket->srcAddr. Device binding and buffer sizing are best-effort.
 *
 * Returns NSTACKX_EOK, or NSTACKX_EFAILED (sockAddr NULL, or any mandatory
 * step failed; in the latter case serverSocket->sockfd is INVALID_SOCKET).
 */
static int32_t CreateTcpServer(Socket *serverSocket, const struct sockaddr_in *sockAddr)
{
    int32_t reuse = 1;
    struct sockaddr_in localAddr;
    socklen_t len = sizeof(localAddr);

    if (sockAddr == NULL) {
        LOGE(TAG, "sockAddr is null");
        return NSTACKX_EFAILED;
    }

    serverSocket->sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (serverSocket->sockfd == INVALID_SOCKET) {
        LOGE(TAG, "create socket failed, error :%d", GetErrno());
        return NSTACKX_EFAILED;
    }
    if (SetSocketNonBlock(serverSocket->sockfd) != NSTACKX_EOK) {
        LOGE(TAG, "set socket nonblock failed");
        goto FAIL_SOCKET;
    }

    if (setsockopt(serverSocket->sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) != 0) {
        LOGE(TAG, "Failed to set server socket! error :%d", GetErrno());
        goto FAIL_SOCKET;
    }

    (void)memset_s(&localAddr, sizeof(localAddr), 0, sizeof(localAddr));
    /* Bind to the requested port, and to the requested IP address if one was given (else INADDR_ANY). */
    localAddr.sin_family = AF_INET;
    localAddr.sin_port = sockAddr->sin_port;
    if (sockAddr->sin_addr.s_addr != 0) {
        localAddr.sin_addr.s_addr = sockAddr->sin_addr.s_addr;
    } else {
        localAddr.sin_addr.s_addr = INADDR_ANY;
    }

    if (bind(serverSocket->sockfd, (struct sockaddr *)&localAddr, len) != 0) {
        LOGE(TAG, "Failed to bind socket error :%d", GetErrno());
        goto FAIL_SOCKET;
    }

    if (sockAddr->sin_addr.s_addr != 0 &&
        BindToDevice(serverSocket->sockfd, sockAddr) != NSTACKX_EOK) {
        LOGE(TAG, "Failed to bind socket to device");
    }

    if (getsockname(serverSocket->sockfd, (struct sockaddr *)&(serverSocket->srcAddr), &len) != 0) {
        LOGE(TAG, "Failed to get socket name! error :%d", GetErrno());
        goto FAIL_SOCKET;
    }

    /* Accepted sockets inherit SO_SNDBUF/SO_RCVBUF from the listener, but tcp(7) requires the
       sizes to be set BEFORE listen() to take effect; doing it afterwards also races with
       connections that arrive in between. */
    if (SetTcpSocketBufSize(serverSocket->sockfd, NSTACKX_TCP_SOCKET_BUFFER_SIZE) != NSTACKX_EOK) {
        LOGE(TAG, "Failed to set socket buff size:%d", NSTACKX_TCP_SOCKET_BUFFER_SIZE);
    }

    if (listen(serverSocket->sockfd, NSTACKX_MAX_LISTEN_NUMBER) != 0) {
        LOGE(TAG, "Failed to listen TCP port! error :%d", GetErrno());
        goto FAIL_SOCKET;
    }

    return NSTACKX_EOK;
FAIL_SOCKET:
    CloseSocketInner(serverSocket->sockfd);
    serverSocket->sockfd = INVALID_SOCKET;
    return NSTACKX_EFAILED;
}
240 
/*
 * Create a non-blocking UDP socket in serverSocket, bound to sockAddr's port
 * and, when sockAddr carries a non-zero address, to that address (otherwise
 * INADDR_ANY). Device binding is best-effort. The bound local address is
 * written to serverSocket->srcAddr.
 *
 * Returns NSTACKX_EOK, or NSTACKX_EFAILED. If a step after socket creation
 * fails, the socket is closed and serverSocket->sockfd is INVALID_SOCKET.
 */
static int32_t CreateUdpServer(Socket *serverSocket, const struct sockaddr_in *sockAddr)
{
    struct sockaddr_in bindAddr;
    socklen_t bindLen = sizeof(bindAddr);

    if (sockAddr == NULL) {
        LOGE(TAG, "sockAddr is null");
        return NSTACKX_EFAILED;
    }

    serverSocket->sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (serverSocket->sockfd == INVALID_SOCKET) {
        LOGE(TAG, "create socket failed, error :%d", GetErrno());
        return NSTACKX_EFAILED;
    }

    do {
        if (SetSocketNonBlock(serverSocket->sockfd) != NSTACKX_EOK) {
            LOGE(TAG, "set socket nonblock failed");
            break;
        }

        (void)memset_s(&bindAddr, sizeof(bindAddr), 0, sizeof(bindAddr));
        bindAddr.sin_family = AF_INET;
        bindAddr.sin_port = sockAddr->sin_port;
        bindAddr.sin_addr.s_addr = (sockAddr->sin_addr.s_addr != 0) ? sockAddr->sin_addr.s_addr : INADDR_ANY;
        if (bind(serverSocket->sockfd, (struct sockaddr *)&bindAddr, bindLen) != 0) {
            LOGE(TAG, "Failed to bind socket, error :%d", GetErrno());
            break;
        }

        /* Only pin to a device when a concrete local address was requested. */
        if (sockAddr->sin_addr.s_addr != 0 &&
            BindToDevice(serverSocket->sockfd, sockAddr) != NSTACKX_EOK) {
            LOGE(TAG, "Failed to bind socket to device");
        }

        if (getsockname(serverSocket->sockfd, (struct sockaddr *)(&serverSocket->srcAddr), &bindLen) != 0) {
            LOGE(TAG, "Failed to get socket name! error :%d", GetErrno());
            break;
        }
        return NSTACKX_EOK;
    } while (0);

    CloseSocketInner(serverSocket->sockfd);
    serverSocket->sockfd = INVALID_SOCKET;
    return NSTACKX_EFAILED;
}
290 
ClientSocketWithTargetDev(SocketProtocol protocol,const struct sockaddr_in * sockAddr,const char * localInterface)291 Socket *ClientSocketWithTargetDev(SocketProtocol protocol, const struct sockaddr_in *sockAddr,
292                                   const char *localInterface)
293 {
294     int32_t ret;
295     if (sockAddr == NULL) {
296         return NULL;
297     }
298     Socket *socket = calloc(1, sizeof(Socket));
299     if (socket == NULL) {
300         LOGE(TAG, "malloc Socket failed\n");
301         return NULL;
302     }
303 
304     switch (protocol) {
305         case NSTACKX_PROTOCOL_TCP:
306             socket->protocol = NSTACKX_PROTOCOL_TCP;
307             ret = ConnectTcpServerWithTargetDev(socket, sockAddr, localInterface);
308             break;
309         case NSTACKX_PROTOCOL_UDP:
310             socket->protocol = NSTACKX_PROTOCOL_UDP;
311             ret = ConnectUdpServerWithTargetDev(socket, sockAddr, localInterface);
312             break;
313         case NSTACKX_PROTOCOL_D2D:
314             LOGE(TAG, "d2d not support");
315             ret = NSTACKX_EFAILED;
316             break;
317         default:
318             LOGE(TAG, "protocol not support");
319             ret = NSTACKX_EFAILED;
320             break;
321     }
322 
323     if (ret != NSTACKX_EOK) {
324         LOGE(TAG, "Create client socket failed! %d", ret);
325         free(socket);
326         return NULL;
327     }
328     socket->isServer = NSTACKX_FALSE;
329     return socket;
330 }
331 
ClientSocket(SocketProtocol protocol,const struct sockaddr_in * sockAddr)332 Socket *ClientSocket(SocketProtocol protocol, const struct sockaddr_in *sockAddr)
333 {
334     return ClientSocketWithTargetDev(protocol, sockAddr, NULL);
335 }
336 
ServerSocket(SocketProtocol protocol,const struct sockaddr_in * sockAddr)337 Socket *ServerSocket(SocketProtocol protocol, const struct sockaddr_in *sockAddr)
338 {
339     int32_t ret;
340     if (sockAddr == NULL) {
341         return NULL;
342     }
343     Socket *socket = calloc(1, sizeof(Socket));
344     if (socket == NULL) {
345         LOGE(TAG, "malloc Socket failed\n");
346         return NULL;
347     }
348 
349     switch (protocol) {
350         case NSTACKX_PROTOCOL_TCP:
351             socket->protocol = NSTACKX_PROTOCOL_TCP;
352             ret = CreateTcpServer(socket, sockAddr);
353             break;
354         case NSTACKX_PROTOCOL_UDP:
355             socket->protocol = NSTACKX_PROTOCOL_UDP;
356             ret = CreateUdpServer(socket, sockAddr);
357             break;
358         case NSTACKX_PROTOCOL_D2D:
359             socket->protocol = NSTACKX_PROTOCOL_D2D;
360             ret = NSTACKX_EFAILED;
361             LOGE(TAG, "d2d not support");
362             break;
363         default:
364             LOGE(TAG, "protocol not support");
365             ret = NSTACKX_EFAILED;
366             break;
367     }
368 
369     if (ret != NSTACKX_EOK) {
370         LOGE(TAG, "Create server socket failed! %d", ret);
371         free(socket);
372         return NULL;
373     }
374     socket->isServer = NSTACKX_TRUE;
375     return socket;
376 }
377 
CheckAcceptSocketValid(const Socket * serverSocket)378 static int32_t CheckAcceptSocketValid(const Socket *serverSocket)
379 {
380     if (serverSocket == NULL || serverSocket->isServer == NSTACKX_FALSE ||
381         serverSocket->protocol != NSTACKX_PROTOCOL_TCP) {
382         LOGE(TAG, "invalue Socket for accept");
383         return NSTACKX_EINVAL;
384     }
385     return NSTACKX_EOK;
386 }
387 
/*
 * Prepare a freshly accepted descriptor: bind it to the device owning its
 * local address (best-effort) and switch it to non-blocking mode.
 * Returns NSTACKX_EFAILED if getsockname or SetSocketNonBlock fails.
 */
static int32_t SetAcceptSocket(SocketDesc acceptFd)
{
    struct sockaddr_in boundAddr;
    socklen_t boundLen = sizeof(boundAddr);

    (void)memset_s(&boundAddr, sizeof(boundAddr), 0, sizeof(boundAddr));
    if (getsockname(acceptFd, (struct sockaddr *)&boundAddr, &boundLen) != 0) {
        LOGE(TAG, "get socket name fail %d", GetErrno());
        return NSTACKX_EFAILED;
    }

    /* Expected to fail without system authority (e.g. third-party devices), so only warn. */
    int32_t bindRet = BindToDevice(acceptFd, &boundAddr);
    if (bindRet != NSTACKX_EOK) {
        LOGW(TAG, "Accept client bind to device failed");
    }

    int32_t nonBlockRet = SetSocketNonBlock(acceptFd);
    if (nonBlockRet != NSTACKX_EOK) {
        LOGE(TAG, "set socket nonblock failed");
        return NSTACKX_EFAILED;
    }
    return NSTACKX_EOK;
}
408 
/*
 * Accept one pending connection on a listening TCP server Socket.
 *
 * Returns a heap-allocated, non-blocking client Socket whose dstAddr is the
 * peer address (release with CloseSocket), or NULL when serverSocket is not a
 * TCP server, allocation fails, accept() fails or post-accept setup fails.
 */
Socket *AcceptSocket(Socket *serverSocket)
{
    struct sockaddr_in peerAddr;
    socklen_t peerLen = sizeof(peerAddr);
    (void)memset_s(&peerAddr, peerLen, 0, peerLen);

    if (CheckAcceptSocketValid(serverSocket) != NSTACKX_EOK) {
        LOGE(TAG, "invalue Socket for accept \n");
        return NULL;
    }

    Socket *peer = calloc(1, sizeof(*peer));
    if (peer == NULL) {
        LOGE(TAG, "client socket malloc failed\n");
        return NULL;
    }
    peer->protocol = NSTACKX_PROTOCOL_TCP;
    peer->isServer = NSTACKX_FALSE;

    peer->sockfd = accept(serverSocket->sockfd, (struct sockaddr *)&peerAddr, &peerLen);
    if (peer->sockfd == INVALID_SOCKET) {
        LOGE(TAG, "accept return error: %d", GetErrno());
        free(peer);
        return NULL;
    }

    if (SetAcceptSocket(peer->sockfd) != NSTACKX_EOK) {
        LOGE(TAG, "set accept socket failed");
        CloseSocketInner(peer->sockfd);
        peer->sockfd = INVALID_SOCKET;
        free(peer);
        return NULL;
    }

    peer->dstAddr = peerAddr;
    return peer;
}
449 
CheckSocketError(void)450 int32_t CheckSocketError(void)
451 {
452     int32_t ret;
453     if (SocketOpWouldBlock()) {
454         ret = NSTACKX_EAGAIN;
455     } else {
456         LOGE(TAG, "sendto/recvfrom error: %d", GetErrno());
457         ret = NSTACKX_EFAILED;
458     }
459     return ret;
460 }
461 
SocketSendUdp(const Socket * socket,const uint8_t * buffer,size_t length)462 static int32_t SocketSendUdp(const Socket *socket, const uint8_t *buffer, size_t length)
463 {
464     socklen_t dstAddrLen = sizeof(struct sockaddr_in);
465 
466     int32_t ret = (int32_t)sendto(socket->sockfd, buffer, length, 0, (struct sockaddr *)&socket->dstAddr, dstAddrLen);
467     if (ret <= 0) {
468         ret = CheckSocketError();
469     }
470     return ret;
471 }
472 
SocketSend(const Socket * socket,const uint8_t * buffer,size_t length)473 int32_t SocketSend(const Socket *socket, const uint8_t *buffer, size_t length)
474 {
475     int32_t ret = NSTACKX_EFAILED;
476 
477     if (socket == NULL || buffer == NULL) {
478         LOGE(TAG, "invalue socket input");
479         return ret;
480     }
481 
482     if (socket->protocol == NSTACKX_PROTOCOL_TCP) {
483         ret = (int32_t)send(socket->sockfd, buffer, length, 0);
484     } else if (socket->protocol == NSTACKX_PROTOCOL_UDP) {
485         ret = SocketSendUdp(socket, buffer, length);
486     } else {
487         LOGE(TAG, "protocol not support %d", socket->protocol);
488     }
489 
490     return ret;
491 }
492 
SocketRecvTcp(const Socket * socket,uint8_t * buffer,size_t length,struct sockaddr_in * srcAddr,const socklen_t * addrLen)493 static int32_t SocketRecvTcp(const Socket *socket, uint8_t *buffer, size_t length, struct sockaddr_in *srcAddr,
494                              const socklen_t *addrLen)
495 {
496     int32_t ret = (int32_t)recv(socket->sockfd, buffer, length, 0);
497     if (srcAddr != NULL && *addrLen >= (socklen_t)sizeof(struct sockaddr_in)) {
498         *srcAddr = socket->dstAddr;
499     }
500     return ret;
501 }
502 
SocketRecvUdp(const Socket * socket,uint8_t * buffer,size_t length,struct sockaddr_in * srcAddr,const socklen_t * addrLen)503 static int32_t SocketRecvUdp(const Socket *socket, uint8_t *buffer, size_t length, struct sockaddr_in *srcAddr,
504                              const socklen_t *addrLen)
505 {
506     struct sockaddr_in addr;
507     socklen_t len = sizeof(struct sockaddr_in);
508     (void)memset_s(&addr, sizeof(addr), 0, sizeof(addr));
509     int32_t ret = (int32_t)recvfrom(socket->sockfd, buffer, length, 0, (struct sockaddr *)&addr, &len);
510     if (ret < 0) {
511         ret = CheckSocketError();
512     } else if (ret == 0 || addr.sin_port == 0 || addr.sin_family != AF_INET) {
513         ret = NSTACKX_EAGAIN;
514     } else {
515         if (srcAddr != NULL && *addrLen >= (socklen_t)sizeof(struct sockaddr_in)) {
516             *srcAddr = addr;
517         }
518     }
519     return ret;
520 }
521 
/*
 * Receive data from socket, dispatching on its protocol (TCP or UDP).
 * srcAddr/addrLen are forwarded to the protocol-specific helper.
 * Returns NSTACKX_EFAILED for a NULL socket or an unsupported protocol.
 */
int32_t SocketRecv(Socket *socket, uint8_t *buffer, size_t length, struct sockaddr_in *srcAddr,
                   const socklen_t *addrLen)
{
    if (socket == NULL) {
        LOGE(TAG, "invalue socket input");
        return NSTACKX_EFAILED;
    }

    switch (socket->protocol) {
        case NSTACKX_PROTOCOL_TCP:
            return SocketRecvTcp(socket, buffer, length, srcAddr, addrLen);
        case NSTACKX_PROTOCOL_UDP:
            return SocketRecvUdp(socket, buffer, length, srcAddr, addrLen);
        default:
            LOGE(TAG, "protocol not support %d", socket->protocol);
            return NSTACKX_EFAILED;
    }
}
542