1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2013 The FreeBSD Foundation
5 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6 * All rights reserved.
7 *
8 * This software was developed by Pawel Jakub Dawidek under sponsorship from
9 * the FreeBSD Foundation.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30 * SUCH DAMAGE.
31 */
32
33#include <sys/cdefs.h>
34__FBSDID("$FreeBSD$");
35
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
46
47#ifdef HAVE_PJDLOG
48#include <pjdlog.h>
49#endif
50
51#include "common_impl.h"
52#include "msgio.h"
53
/*
 * When pjdlog is not available, map its assertion and abort macros onto
 * the standard assert(3) and abort(3).  Any extra message arguments given
 * to PJDLOG_RASSERT() are dropped, and the checks compile out under NDEBUG.
 */
#if !defined(HAVE_PJDLOG)
#include <assert.h>
#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
#define	PJDLOG_ABORT(...)		abort()
#endif	/* !HAVE_PJDLOG */

/*
 * Maximum number of descriptors sent in a single message.  A
 * machine-independent limit is used to work around limitations in 32-bit
 * emulation on 64-bit kernels.  Each control message carries exactly one
 * descriptor and costs 12 bytes of header, 4 pad bytes, 4 bytes for the
 * descriptor itself and another 4 pad bytes -- 24 bytes in total.
 */
#define	PKG_MAX_SIZE	(MCLBYTES / (12 + 4 + 4 + 4))
68
/*
 * Initialize "cmsg" as a SOL_SOCKET/SCM_RIGHTS control message carrying
 * the single descriptor "fd".  CMSG_LEN(sizeof(int)) bytes are written
 * starting at "cmsg".  Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
82
/*
 * Extract the descriptor from a control message built by msghdr_add_fd().
 *
 * Returns the descriptor when "cmsg" is a SOL_SOCKET/SCM_RIGHTS message
 * whose length is exactly CMSG_LEN(sizeof(int)); otherwise (including a
 * NULL "cmsg") returns -1 with errno set to EINVAL.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;

	if (cmsg != NULL && cmsg->cmsg_level == SOL_SOCKET &&
	    cmsg->cmsg_type == SCM_RIGHTS &&
	    cmsg->cmsg_len == CMSG_LEN(sizeof(received))) {
		memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
		/*
		 * Without MSG_CMSG_CLOEXEC the close-on-exec flag could not be
		 * set atomically by recvmsg(2); set it now, best effort, so
		 * received descriptors are treated consistently.
		 */
		(void)fcntl(received, F_SETFD, FD_CLOEXEC);
#endif
		return (received);
	}

	errno = EINVAL;
	return (-1);
}
107
/*
 * Block until "fd" is readable (doread is true) or writable (doread is
 * false).  The outcome of the wait is deliberately ignored: every caller
 * performs the actual I/O right afterwards and that call reports errors.
 *
 * poll(2) is used instead of select(2) because FD_SET() on a descriptor
 * greater than or equal to FD_SETSIZE writes past the end of an on-stack
 * fd_set.  poll(2) has no such limit.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A negative timeout waits indefinitely, like select(2) with NULL. */
	(void)poll(&pfd, 1, -1);
}
120
/*
 * Receive one message from "sock" into "msg", waiting for the socket to
 * become readable first and retrying when interrupted by a signal.  When
 * MSG_CMSG_CLOEXEC is available it is passed to recvmsg(2) so received
 * descriptors get close-on-exec atomically.
 *
 * Returns 0 whenever recvmsg(2) succeeds (a zero-byte read included), or
 * -1 with errno set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	int rflags = 0;

	PJDLOG_ASSERT(sock >= 0);

#ifdef MSG_CMSG_CLOEXEC
	rflags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
146
/*
 * Send "msg" over "sock", waiting for the socket to become writable first
 * and retrying when interrupted by a signal.
 *
 * Returns 0 when sendmsg(2) succeeds, or -1 with errno set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
165
166int
167cred_send(int sock)
168{
169	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
170	struct msghdr msg;
171	struct cmsghdr *cmsg;
172	struct iovec iov;
173	uint8_t dummy;
174
175	bzero(credbuf, sizeof(credbuf));
176	bzero(&msg, sizeof(msg));
177	bzero(&iov, sizeof(iov));
178
179	/*
180	 * XXX: We send one byte along with the control message, because
181	 *      setting msg_iov to NULL only works if this is the first
182	 *      packet send over the socket. Once we send some data we
183	 *      won't be able to send credentials anymore. This is most
184	 *      likely a kernel bug.
185	 */
186	dummy = 0;
187	iov.iov_base = &dummy;
188	iov.iov_len = sizeof(dummy);
189
190	msg.msg_iov = &iov;
191	msg.msg_iovlen = 1;
192	msg.msg_control = credbuf;
193	msg.msg_controllen = sizeof(credbuf);
194
195	cmsg = CMSG_FIRSTHDR(&msg);
196	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
197	cmsg->cmsg_level = SOL_SOCKET;
198	cmsg->cmsg_type = SCM_CREDS;
199
200	if (msg_send(sock, &msg) == -1)
201		return (-1);
202
203	return (0);
204}
205
206int
207cred_recv(int sock, struct cmsgcred *cred)
208{
209	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
210	struct msghdr msg;
211	struct cmsghdr *cmsg;
212	struct iovec iov;
213	uint8_t dummy;
214
215	bzero(credbuf, sizeof(credbuf));
216	bzero(&msg, sizeof(msg));
217	bzero(&iov, sizeof(iov));
218
219	iov.iov_base = &dummy;
220	iov.iov_len = sizeof(dummy);
221
222	msg.msg_iov = &iov;
223	msg.msg_iovlen = 1;
224	msg.msg_control = credbuf;
225	msg.msg_controllen = sizeof(credbuf);
226
227	if (msg_recv(sock, &msg) == -1)
228		return (-1);
229
230	cmsg = CMSG_FIRSTHDR(&msg);
231	if (cmsg == NULL ||
232	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
233	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
234		errno = EINVAL;
235		return (-1);
236	}
237	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
238
239	return (0);
240}
241
/*
 * Send "nfds" descriptors from "fds" over "sock" in a single message, one
 * descriptor per SCM_RIGHTS control message, together with one zero byte
 * of data.
 *
 * Returns 0 on success or -1 with errno set (allocation or send failure).
 * errno is preserved across the release of the control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct iovec iov;
	struct msghdr msg;
	struct cmsghdr *cmsg;
	unsigned int idx;
	int rv, saved_errno;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&msg, 0, sizeof(msg));

	/*
	 * XXX: A data byte is required alongside the control messages; see
	 *      cred_send() for the reason.
	 */
	pad = 0;
	iov.iov_base = &pad;
	iov.iov_len = sizeof(pad);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	rv = -1;

	idx = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (idx < nfds && cmsg != NULL) {
		if (msghdr_add_fd(cmsg, fds[idx]) == -1)
			goto out;
		idx++;
		cmsg = CMSG_NXTHDR(&msg, cmsg);
	}

	if (msg_send(sock, &msg) != -1)
		rv = 0;

out:
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (rv);
}
290
/*
 * Close every descriptor carried by any SOL_SOCKET/SCM_RIGHTS control
 * message in "msg", however many descriptors each message holds.
 * msghdr_get_fd() only accepts single-descriptor messages, so relying on it
 * for cleanup would leak descriptors from a message that packs several of
 * them.  Each payload is clamped to the bytes covered by msg_controllen so
 * a truncated message is never read past its end.
 */
static void
msghdr_close_fds(struct msghdr *msg)
{
	struct cmsghdr *cmsg;
	unsigned char *base, *data;
	size_t avail, len, off, k;
	int fd;

	base = msg->msg_control;
	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(msg, cmsg)) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			continue;
		data = CMSG_DATA(cmsg);
		off = (size_t)(data - base);
		if (off >= msg->msg_controllen)
			continue;
		avail = msg->msg_controllen - off;
		len = cmsg->cmsg_len - CMSG_LEN(0);
		if (len > avail)
			len = avail;
		for (k = 0; k + sizeof(fd) <= len; k += sizeof(fd)) {
			memcpy(&fd, data + k, sizeof(fd));
			if (fd >= 0)
				(void)close(fd);
		}
	}
}

/*
 * Receive exactly "nfds" descriptors from "sock" into "fds".  The expected
 * layout is the one fd_package_send() produces: one data byte plus one
 * single-descriptor SCM_RIGHTS control message per descriptor.
 *
 * Returns 0 on success.  On failure returns -1 with errno set.  If the
 * control data does not match the expected layout, every descriptor found
 * in SCM_RIGHTS messages is closed and errno is set to EINVAL.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	unsigned int i;
	int serrno, ret;
	struct iovec iov;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	bzero(&msg, sizeof(msg));
	bzero(&iov, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		fds[i] = msghdr_get_fd(cmsg);
		if (fds[i] < 0)
			break;
	}

	if (cmsg != NULL || i < nfds) {
		/*
		 * Unexpected layout (too few or too many messages, a foreign
		 * control message such as SCM_CREDS, or several descriptors in
		 * one message): close everything that was received so nothing
		 * leaks, including descriptors already stored in fds[].
		 */
		msghdr_close_fds(&msg);
		errno = EINVAL;
		goto end;
	}

	ret = 0;
end:
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
357
358int
359fd_recv(int sock, int *fds, size_t nfds)
360{
361	unsigned int i, step, j;
362	int ret, serrno;
363
364	if (nfds == 0 || fds == NULL) {
365		errno = EINVAL;
366		return (-1);
367	}
368
369	ret = i = step = 0;
370	while (i < nfds) {
371		if (PKG_MAX_SIZE < nfds - i)
372			step = PKG_MAX_SIZE;
373		else
374			step = nfds - i;
375		ret = fd_package_recv(sock, fds + i, step);
376		if (ret != 0) {
377			/* Close all received descriptors. */
378			serrno = errno;
379			for (j = 0; j < i; j++)
380				close(fds[j]);
381			errno = serrno;
382			break;
383		}
384		i += step;
385	}
386
387	return (ret);
388}
389
390int
391fd_send(int sock, const int *fds, size_t nfds)
392{
393	unsigned int i, step;
394	int ret;
395
396	if (nfds == 0 || fds == NULL) {
397		errno = EINVAL;
398		return (-1);
399	}
400
401	ret = i = step = 0;
402	while (i < nfds) {
403		if (PKG_MAX_SIZE < nfds - i)
404			step = PKG_MAX_SIZE;
405		else
406			step = nfds - i;
407		ret = fd_package_send(sock, fds + i, step);
408		if (ret != 0)
409			break;
410		i += step;
411	}
412
413	return (ret);
414}
415
416int
417buf_send(int sock, void *buf, size_t size)
418{
419	ssize_t done;
420	unsigned char *ptr;
421
422	PJDLOG_ASSERT(sock >= 0);
423	PJDLOG_ASSERT(size > 0);
424	PJDLOG_ASSERT(buf != NULL);
425
426	ptr = buf;
427	do {
428		fd_wait(sock, false);
429		done = send(sock, ptr, size, 0);
430		if (done == -1) {
431			if (errno == EINTR)
432				continue;
433			return (-1);
434		} else if (done == 0) {
435			errno = ENOTCONN;
436			return (-1);
437		}
438		size -= done;
439		ptr += done;
440	} while (size > 0);
441
442	return (0);
443}
444
445int
446buf_recv(int sock, void *buf, size_t size)
447{
448	ssize_t done;
449	unsigned char *ptr;
450
451	PJDLOG_ASSERT(sock >= 0);
452	PJDLOG_ASSERT(buf != NULL);
453
454	ptr = buf;
455	while (size > 0) {
456		fd_wait(sock, true);
457		done = recv(sock, ptr, size, 0);
458		if (done == -1) {
459			if (errno == EINTR)
460				continue;
461			return (-1);
462		} else if (done == 0) {
463			errno = ENOTCONN;
464			return (-1);
465		}
466		size -= done;
467		ptr += done;
468	}
469
470	return (0);
471}
472