/*
 * Copyright (c) 2021 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fd_holder.h"

#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "beget_ext.h"
#include "fd_holder_internal.h"
#include "init_utils.h"
#include "securec.h"

/*
 * Create an AF_UNIX SOCK_DGRAM socket and connect it to INIT_HOLDER_SOCKET_PATH.
 * Returns the connected socket fd, or -1 on failure. The socket is closed on
 * every failure path after it has been created.
 */
static int BuildClientSocket(void)
{
    int sockFd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    BEGET_ERROR_CHECK(sockFd >= 0, return -1, "Failed to build socket, err = %d", errno);

    struct sockaddr_un addr;
    (void)memset_s(&addr, sizeof(addr), 0, sizeof(addr));
    addr.sun_family = AF_UNIX;
    int ret = strncpy_s(addr.sun_path, sizeof(addr.sun_path),
        INIT_HOLDER_SOCKET_PATH, strlen(INIT_HOLDER_SOCKET_PATH));
    BEGET_ERROR_CHECK(ret == 0, close(sockFd); return -1, "Failed to build socket path");

    // Address length = offset of sun_path + path length + terminating NUL.
    socklen_t len = (socklen_t)(offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1);
    ret = connect(sockFd, (struct sockaddr *)&addr, len);
    BEGET_ERROR_CHECK(ret >= 0, close(sockFd); return -1, "Failed to connect to socket, err = %d", errno);
    return sockFd;
}

/*
 * Format the request text "<serviceName>|<hold action>|<poll flag>" into buffer.
 *
 * @param buffer      destination buffer (must be non-NULL)
 * @param size        size of buffer in bytes (must be non-zero)
 * @param serviceName service name placed in the first field (must be non-NULL)
 * @param hold        selects ACTION_HOLD (true) or ACTION_GET (false)
 * @param poll        selects WITHPOLL (true) or WITHOUTPOLL (false)
 * @return 0 on success; -1 on invalid arguments, on the rejected combination
 *         hold == false && poll == true, or when snprintf_s fails.
 */
STATIC int BuildSendData(char *buffer, size_t size, const char *serviceName, bool hold, bool poll)
{
    if (buffer == NULL || size == 0 || serviceName == NULL) {
        return -1;
    }
    if (!hold && poll) {
        BEGET_LOGE("Get fd with poll set, invalid parameter");
        return -1;
    }

    const char *holdString = hold ? ACTION_HOLD : ACTION_GET;
    const char *pollString = poll ? WITHPOLL : WITHOUTPOLL;
    if (snprintf_s(buffer, size, size - 1, "%s|%s|%s", serviceName, holdString, pollString) == -1) {
        BEGET_LOGE("Failed to build send data");
        return -1;
    }
    return 0;
}

/*
 * Send a "hold" request for serviceName together with fds to init over the
 * holder socket.
 *
 * BuildControlMessage (defined elsewhere) fills msghdr's control data for the
 * fds. Whatever it leaves in msghdr.msg_control is freed here on every path,
 * matching the previous per-branch cleanup.
 *
 * @return 0 when sendmsg succeeded, -1 otherwise.
 */
static int ServiceSendFds(const char *serviceName, int *fds, int fdCount, bool doPoll)
{
    int sock = BuildClientSocket();
    BEGET_CHECK(sock >= 0, return -1);

    int result = -1;
    struct iovec iovec = {};
    struct msghdr msghdr = {
        .msg_iov = &iovec,
        .msg_iovlen = 1,
    };
    char sendBuffer[MAX_FD_HOLDER_BUFFER] = {};

    if (BuildSendData(sendBuffer, sizeof(sendBuffer), serviceName, true, doPoll) < 0) {
        BEGET_LOGE("Failed to build send data");
        goto out;
    }
    BEGET_LOGV("Send data: [%s]", sendBuffer);
    iovec.iov_base = sendBuffer;
    iovec.iov_len = strlen(sendBuffer);

    if (BuildControlMessage(&msghdr, fds, fdCount, true) < 0) {
        BEGET_LOGE("Failed to build control message");
        goto out;
    }
    if (TEMP_FAILURE_RETRY(sendmsg(sock, &msghdr, MSG_NOSIGNAL)) < 0) {
        BEGET_LOGE("Failed to send fds to init, err = %d", errno);
        goto out;
    }
    BEGET_LOGI("Send fds done");
    result = 0;

out:
    // Single cleanup path: release control data (free(NULL) is a no-op) and the socket.
    free(msghdr.msg_control);
    msghdr.msg_control = NULL;
    msghdr.msg_controllen = 0;
    close(sock);
    return result;
}

/* Shared argument check for ServiceSaveFd / ServiceSaveFdWithPoll. */
static bool IsValidSaveFdRequest(const char *serviceName, const int *fds, int fdCount)
{
    return serviceName != NULL && fds != NULL && fdCount >= 0 && fdCount <= MAX_HOLD_FDS;
}

/*
 * Ask init to hold fdCount fds for serviceName (no poll flag).
 * Returns 0 on success, -1 on invalid parameters or send failure.
 */
int ServiceSaveFd(const char *serviceName, int *fds, int fdCount)
{
    if (!IsValidSaveFdRequest(serviceName, fds, fdCount)) {
        BEGET_LOGE("Invalid parameters");
        return -1;
    }
    return ServiceSendFds(serviceName, fds, fdCount, false);
}

/*
 * Ask init to hold fdCount fds for serviceName with the poll flag set.
 * Returns 0 on success, -1 on invalid parameters or send failure.
 */
int ServiceSaveFdWithPoll(const char *serviceName, int *fds, int fdCount)
{
    if (!IsValidSaveFdRequest(serviceName, fds, fdCount)) {
        BEGET_LOGE("Invalid parameters");
        return -1;
    }
    return ServiceSendFds(serviceName, fds, fdCount, true);
}

/*
 * Parse one decimal fd token.
 *
 * Leading whitespace (skipped by strtol) and trailing whitespace are tolerated.
 * Empty text, any other trailing characters, negative values and values that do
 * not fit in an int are rejected.
 *
 * @return 0 and stores the value in *fd on success, -1 otherwise.
 */
static int ParseFdNumber(const char *text, int *fd)
{
    if (text == NULL || fd == NULL) {
        return -1;
    }
    char *end = NULL;
    errno = 0;
    long value = strtol(text, &end, DECIMAL_BASE);
    if (errno != 0 || end == text) {
        return -1;
    }
    while (*end != '\0' && isspace((unsigned char)*end)) {
        end++;
    }
    if (*end != '\0' || value < 0 || value > INT_MAX) {
        return -1;
    }
    *fd = (int)value;
    return 0;
}

/*
 * Read the fds handed over for serviceName from the environment variable
 * ENV_FD_HOLD_PREFIX<serviceName>, whose value is a space-separated list of
 * decimal fd numbers (at most MAX_HOLD_FDS entries are split out).
 *
 * @param serviceName service whose fds are requested
 * @param outfdCount  receives the number of fds returned; set to 0 on failure
 * @return a calloc'ed array of fds that the caller must free(), or NULL on
 *         any failure (missing env, formatting/copy failure, allocation
 *         failure, or an entry that is not a valid fd number).
 */
int *ServiceGetFd(const char *serviceName, size_t *outfdCount)
{
    if (serviceName == NULL || outfdCount == NULL) {
        BEGET_LOGE("Invalid parameters");
        return NULL;
    }
    *outfdCount = 0;

    char path[MAX_FD_HOLDER_BUFFER] = {};
    int ret = snprintf_s(path, MAX_FD_HOLDER_BUFFER, MAX_FD_HOLDER_BUFFER - 1, ENV_FD_HOLD_PREFIX"%s", serviceName);
    BEGET_ERROR_CHECK(ret > 0, return NULL, "Failed snprintf_s err=%d", errno);

    const char *value = getenv(path);
    if (value == NULL) {
        BEGET_LOGE("Cannot get env %s\n", path);
        return NULL;
    }

    // Copy into a local buffer; the zero-initialized last byte stays NUL.
    char fdBuffer[MAX_FD_HOLDER_BUFFER] = {};
    ret = strncpy_s(fdBuffer, MAX_FD_HOLDER_BUFFER - 1, value, strlen(value));
    // securec reports failures through its return code, not errno.
    BEGET_ERROR_CHECK(ret == 0, return NULL, "Failed strncpy_s err=%d", ret);
    BEGET_LOGV("fds = %s", fdBuffer);

    int fdCount = 0;
    char **fdList = SplitStringExt(fdBuffer, " ", &fdCount, MAX_HOLD_FDS);
    BEGET_ERROR_CHECK(fdList != NULL, return NULL, "Cannot get fd list");

    int *fds = calloc((size_t)fdCount, sizeof(*fds));
    if (fds == NULL) {
        BEGET_LOGE("Allocate memory for fd failed. err = %d", errno);
        FreeStringVector(fdList, fdCount);
        return NULL;
    }

    for (int i = 0; i < fdCount; i++) {
        if (ParseFdNumber(fdList[i], &fds[i]) != 0) {
            BEGET_LOGE("Failed to convert \' %s \' to fd number", fdList[i]);
            free(fds);
            fds = NULL;
            break;
        }
    }

    // Always free the list with its real element count so no token leaks,
    // including on the parse-error path.
    FreeStringVector(fdList, fdCount);
    if (fds != NULL) {
        *outfdCount = (size_t)fdCount;
    }
    return fds;
}