1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2013 The FreeBSD Foundation
5 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6 * All rights reserved.
7 *
8 * This software was developed by Pawel Jakub Dawidek under sponsorship from
9 * the FreeBSD Foundation.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30 * SUCH DAMAGE.
31 */
32
33#include <sys/cdefs.h>
34__FBSDID("$FreeBSD$");
35
#include <sys/param.h>
#include <sys/select.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
47
48#ifdef HAVE_PJDLOG
49#include <pjdlog.h>
50#endif
51
52#include "common_impl.h"
53#include "msgio.h"
54
55#ifndef	HAVE_PJDLOG
56#include <assert.h>
57#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
58#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
59#define	PJDLOG_ABORT(...)		abort()
60#endif
61
#ifdef __linux__
/*
 * Linux: arbitrary size, but must be lower than SCM_MAX_FD.
 *
 * NOTE(review): fd_send() and fd_recv() use PKG_MAX_SIZE as the number of
 * descriptors placed in one package, whereas this expression multiplies 63
 * by CMSG_SPACE(sizeof(int)), which is a byte count.  The resulting limit is
 * therefore larger than 63 descriptors; confirm that it really stays below
 * SCM_MAX_FD before relying on it.
 */
#define	PKG_MAX_SIZE	((64U - 1) * CMSG_SPACE(sizeof(int)))
#else
/*
 * To work around limitations in 32-bit emulation on 64-bit kernels, use a
 * machine-independent limit on the number of FDs per message.  Each control
 * message contains 1 FD and requires 12 bytes for the header, 4 pad bytes,
 * 4 bytes for the descriptor, and another 4 pad bytes (24 bytes in total,
 * which is the divisor used below).
 */
#define	PKG_MAX_SIZE	(MCLBYTES / 24)
#endif
74
/*
 * Turn the control message header pointed to by cmsg into an SOL_SOCKET /
 * SCM_RIGHTS message whose payload is the single descriptor fd.
 *
 * The descriptor must be non-negative.  Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_level = SOL_SOCKET;
	/* The payload starts at CMSG_DATA(); copy the raw descriptor there. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
88
/*
 * Block until the descriptor becomes ready: for reading when doread is true,
 * for writing otherwise.
 *
 * poll(2) is used instead of select(2) because FD_SET() on a descriptor
 * number that is not smaller than FD_SETSIZE writes outside of the fd_set.
 * poll(2) has no such limit on the descriptor value.
 *
 * The result of the wait is deliberately ignored: every caller performs the
 * actual send/receive right afterwards and reports failures from there.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A negative timeout means "wait without a time limit". */
	(void)poll(&pfd, 1, -1);
}
101
/*
 * Receive a single message from sock into msg.
 *
 * Before every attempt the function waits for the socket to become readable.
 * A recvmsg() call interrupted by a signal (EINTR) is retried; any other
 * failure is returned to the caller.  MSG_CMSG_CLOEXEC is passed to
 * recvmsg() when the platform defines it.
 *
 * Returns 0 on success and -1 on failure with errno left as set by
 * recvmsg().
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
#ifdef MSG_CMSG_CLOEXEC
	const int rflags = MSG_CMSG_CLOEXEC;
#else
	const int rflags = 0;
#endif

	PJDLOG_ASSERT(sock >= 0);

	for (;;) {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
		if (errno != EINTR)
			return (-1);
		/* Interrupted by a signal: wait and try again. */
	}
}
127
/*
 * Send the message described by msg over sock.
 *
 * Before every attempt the function waits for the socket to become
 * writable.  A sendmsg() call interrupted by a signal (EINTR) is retried;
 * any other failure is returned to the caller.
 *
 * Returns 0 on success and -1 on failure with errno left as set by
 * sendmsg().
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	for (;;) {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
		if (errno != EINTR)
			return (-1);
		/* Interrupted by a signal: wait and try again. */
	}
}
146
147#ifdef __FreeBSD__
148int
149cred_send(int sock)
150{
151	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
152	struct msghdr msg;
153	struct cmsghdr *cmsg;
154	struct iovec iov;
155	uint8_t dummy;
156
157	bzero(credbuf, sizeof(credbuf));
158	bzero(&msg, sizeof(msg));
159	bzero(&iov, sizeof(iov));
160
161	/*
162	 * XXX: We send one byte along with the control message, because
163	 *      setting msg_iov to NULL only works if this is the first
164	 *      packet send over the socket. Once we send some data we
165	 *      won't be able to send credentials anymore. This is most
166	 *      likely a kernel bug.
167	 */
168	dummy = 0;
169	iov.iov_base = &dummy;
170	iov.iov_len = sizeof(dummy);
171
172	msg.msg_iov = &iov;
173	msg.msg_iovlen = 1;
174	msg.msg_control = credbuf;
175	msg.msg_controllen = sizeof(credbuf);
176
177	cmsg = CMSG_FIRSTHDR(&msg);
178	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
179	cmsg->cmsg_level = SOL_SOCKET;
180	cmsg->cmsg_type = SCM_CREDS;
181
182	if (msg_send(sock, &msg) == -1)
183		return (-1);
184
185	return (0);
186}
187
188int
189cred_recv(int sock, struct cmsgcred *cred)
190{
191	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
192	struct msghdr msg;
193	struct cmsghdr *cmsg;
194	struct iovec iov;
195	uint8_t dummy;
196
197	bzero(credbuf, sizeof(credbuf));
198	bzero(&msg, sizeof(msg));
199	bzero(&iov, sizeof(iov));
200
201	iov.iov_base = &dummy;
202	iov.iov_len = sizeof(dummy);
203
204	msg.msg_iov = &iov;
205	msg.msg_iovlen = 1;
206	msg.msg_control = credbuf;
207	msg.msg_controllen = sizeof(credbuf);
208
209	if (msg_recv(sock, &msg) == -1)
210		return (-1);
211
212	cmsg = CMSG_FIRSTHDR(&msg);
213	if (cmsg == NULL ||
214	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
215	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
216		errno = EINVAL;
217		return (-1);
218	}
219	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
220
221	return (0);
222}
223#endif
224
/*
 * Send one package of nfds descriptors from fds over sock.
 *
 * A control buffer with room for nfds control messages of one descriptor
 * each is allocated, one SCM_RIGHTS message is filled in per descriptor and
 * everything is sent in a single message together with one dummy data byte.
 *
 * Returns 0 on success and -1 on failure.  errno is preserved across the
 * release of the control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct cmsghdr *hdr;
	struct msghdr msg;
	struct iovec iov;
	unsigned int idx;
	int saved_errno, rv;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&msg, 0, sizeof(msg));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	dummy = 0;
	iov.iov_len = sizeof(dummy);
	iov.iov_base = &dummy;

	msg.msg_iovlen = 1;
	msg.msg_iov = &iov;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	rv = -1;

	/* One control message per descriptor, in the order given. */
	idx = 0;
	hdr = CMSG_FIRSTHDR(&msg);
	while (idx < nfds && hdr != NULL) {
		if (msghdr_add_fd(hdr, fds[idx]) == -1)
			goto out;
		idx++;
		hdr = CMSG_NXTHDR(&msg, hdr);
	}

	if (msg_send(sock, &msg) != -1)
		rv = 0;
out:
	/* free() must not disturb the errno reported to the caller. */
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (rv);
}
273
/*
 * Receive one package of exactly nfds descriptors from sock into fds.
 *
 * The control data of the received message must consist solely of
 * SOL_SOCKET / SCM_RIGHTS control messages that together carry exactly nfds
 * descriptors.  Anything else (a different control message, too few or too
 * many descriptors) makes the call fail with errno set to EINVAL.
 *
 * On such a failure every descriptor found in the received control data is
 * closed, including descriptors that were never copied into fds, so that a
 * malformed or oversized message cannot leak descriptors into the process.
 *
 * Returns 0 on success and -1 on failure.  errno is preserved across the
 * release of the control buffer.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	unsigned int i;
	int serrno, ret;
	struct iovec iov;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&msg, 0, sizeof(msg));
	memset(&iov, 0, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	/*
	 * Copy descriptors out of the control messages.  The loop stops early
	 * (leaving cmsg != NULL or i < nfds) as soon as something unexpected
	 * is seen; the check below turns that into an error.
	 */
	i = 0;
	cmsg = CMSG_FIRSTHDR(&msg);
	while (cmsg != NULL && i < nfds) {
		unsigned int n;

		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			break;
		n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		/* i < nfds holds here, so the subtraction cannot wrap. */
		if (n > nfds - i)
			break;
		memcpy(fds + i, CMSG_DATA(cmsg), sizeof(int) * n);
		cmsg = CMSG_NXTHDR(&msg, cmsg);
		i += n;
	}

	if (cmsg != NULL || i < nfds) {
		/*
		 * We need to close all received descriptors, even if we have
		 * different control message (eg. SCM_CREDS) in between.
		 * Walk the control buffer itself rather than fds, so that
		 * descriptors which were not copied are closed as well.
		 */
		for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
		    cmsg = CMSG_NXTHDR(&msg, cmsg)) {
			size_t cnt, k;
			int fd;

			/*
			 * A header shorter than CMSG_LEN(0) cannot be walked
			 * past safely; stop instead of risking an endless
			 * loop.
			 */
			if (cmsg->cmsg_len < CMSG_LEN(0))
				break;
			if (cmsg->cmsg_level != SOL_SOCKET ||
			    cmsg->cmsg_type != SCM_RIGHTS)
				continue;
			cnt = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
			for (k = 0; k < cnt; k++) {
				memcpy(&fd, CMSG_DATA(cmsg) + k * sizeof(int),
				    sizeof(fd));
				if (fd >= 0)
					(void)close(fd);
			}
		}
		errno = EINVAL;
		goto end;
	}

#ifndef MSG_CMSG_CLOEXEC
	/*
	 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
	 * close-on-exec flag atomically, but we still want to set it for
	 * consistency.
	 */
	for (i = 0; i < nfds; i++) {
		(void) fcntl(fds[i], F_SETFD, FD_CLOEXEC);
	}
#endif

	ret = 0;
end:
	/* free() must not disturb the errno reported to the caller. */
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
364
365int
366fd_recv(int sock, int *fds, size_t nfds)
367{
368	unsigned int i, step, j;
369	int ret, serrno;
370
371	if (nfds == 0 || fds == NULL) {
372		errno = EINVAL;
373		return (-1);
374	}
375
376	ret = i = step = 0;
377	while (i < nfds) {
378		if (PKG_MAX_SIZE < nfds - i)
379			step = PKG_MAX_SIZE;
380		else
381			step = nfds - i;
382		ret = fd_package_recv(sock, fds + i, step);
383		if (ret != 0) {
384			/* Close all received descriptors. */
385			serrno = errno;
386			for (j = 0; j < i; j++)
387				close(fds[j]);
388			errno = serrno;
389			break;
390		}
391		i += step;
392	}
393
394	return (ret);
395}
396
397int
398fd_send(int sock, const int *fds, size_t nfds)
399{
400	unsigned int i, step;
401	int ret;
402
403	if (nfds == 0 || fds == NULL) {
404		errno = EINVAL;
405		return (-1);
406	}
407
408	ret = i = step = 0;
409	while (i < nfds) {
410		if (PKG_MAX_SIZE < nfds - i)
411			step = PKG_MAX_SIZE;
412		else
413			step = nfds - i;
414		ret = fd_package_send(sock, fds + i, step);
415		if (ret != 0)
416			break;
417		i += step;
418	}
419
420	return (ret);
421}
422
423int
424buf_send(int sock, void *buf, size_t size)
425{
426	ssize_t done;
427	unsigned char *ptr;
428
429	PJDLOG_ASSERT(sock >= 0);
430	PJDLOG_ASSERT(size > 0);
431	PJDLOG_ASSERT(buf != NULL);
432
433	ptr = buf;
434	do {
435		fd_wait(sock, false);
436		done = send(sock, ptr, size, 0);
437		if (done == -1) {
438			if (errno == EINTR)
439				continue;
440			return (-1);
441		} else if (done == 0) {
442			errno = ENOTCONN;
443			return (-1);
444		}
445		size -= done;
446		ptr += done;
447	} while (size > 0);
448
449	return (0);
450}
451
452int
453buf_recv(int sock, void *buf, size_t size)
454{
455	ssize_t done;
456	unsigned char *ptr;
457
458	PJDLOG_ASSERT(sock >= 0);
459	PJDLOG_ASSERT(buf != NULL);
460
461	ptr = buf;
462	while (size > 0) {
463		fd_wait(sock, true);
464		done = recv(sock, ptr, size, 0);
465		if (done == -1) {
466			if (errno == EINTR)
467				continue;
468			return (-1);
469		} else if (done == 0) {
470			errno = ENOTCONN;
471			return (-1);
472		}
473		size -= done;
474		ptr += done;
475	}
476
477	return (0);
478}
479