1258065Spjd/*-
2258065Spjd * Copyright (c) 2013 The FreeBSD Foundation
3258065Spjd * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
4258065Spjd * All rights reserved.
5258065Spjd *
6258065Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
7258065Spjd * the FreeBSD Foundation.
8258065Spjd *
9258065Spjd * Redistribution and use in source and binary forms, with or without
10258065Spjd * modification, are permitted provided that the following conditions
11258065Spjd * are met:
12258065Spjd * 1. Redistributions of source code must retain the above copyright
13258065Spjd *    notice, this list of conditions and the following disclaimer.
14258065Spjd * 2. Redistributions in binary form must reproduce the above copyright
15258065Spjd *    notice, this list of conditions and the following disclaimer in the
16258065Spjd *    documentation and/or other materials provided with the distribution.
17258065Spjd *
18258065Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19258065Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20258065Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21258065Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22258065Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23258065Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24258065Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25258065Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26258065Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27258065Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28258065Spjd * SUCH DAMAGE.
29258065Spjd */
30258065Spjd
31258065Spjd#include <sys/cdefs.h>
32258065Spjd__FBSDID("$FreeBSD: stable/11/lib/libnv/msgio.c 324831 2017-10-21 19:34:54Z oshogbo $");
33258065Spjd
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
44258065Spjd
45258065Spjd#ifdef HAVE_PJDLOG
46258065Spjd#include <pjdlog.h>
47258065Spjd#endif
48258065Spjd
49258065Spjd#include "common_impl.h"
50258065Spjd#include "msgio.h"
51258065Spjd
52258065Spjd#ifndef	HAVE_PJDLOG
53258065Spjd#include <assert.h>
54258065Spjd#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
55258065Spjd#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
56258065Spjd#define	PJDLOG_ABORT(...)		abort()
57258065Spjd#endif
58258065Spjd
59271578Spjd#define	PKG_MAX_SIZE	(MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
60271578Spjd
/*
 * Fill in a control message header so that it carries a single descriptor
 * as SCM_RIGHTS data at level SOL_SOCKET.
 *
 * Returns 0 on success.  When fd_is_valid() rejects the descriptor, errno
 * is set to EBADF, -1 is returned and the header is left untouched.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	if (fd_is_valid(fd)) {
		cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
		cmsg->cmsg_level = SOL_SOCKET;
		cmsg->cmsg_type = SCM_RIGHTS;
		memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
		return (0);
	}

	errno = EBADF;
	return (-1);
}
79258065Spjd
/*
 * Extract the descriptor carried by a control message.
 *
 * The header must be non-NULL, at level SOL_SOCKET, of type SCM_RIGHTS and
 * sized for exactly one int; otherwise errno is set to EINVAL and -1 is
 * returned.  On success the descriptor stored in the message is returned.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;

	if (cmsg == NULL) {
		errno = EINVAL;
		return (-1);
	}
	if (cmsg->cmsg_len != CMSG_LEN(sizeof(received)) ||
	    cmsg->cmsg_type != SCM_RIGHTS ||
	    cmsg->cmsg_level != SOL_SOCKET) {
		errno = EINVAL;
		return (-1);
	}

	memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
	/*
	 * Without MSG_CMSG_CLOEXEC the close-on-exec flag cannot be applied
	 * atomically at receive time.  Set it here anyway so that the
	 * resulting descriptor flags are consistent; a failure is ignored.
	 */
	(void) fcntl(received, F_SETFD, FD_CLOEXEC);
#endif

	return (received);
}
104258065Spjd
/*
 * Block until the descriptor is ready for the requested direction: reading
 * when doread is true, writing otherwise.
 *
 * poll(2) is used rather than select(2) because FD_SET() on a descriptor
 * numbered FD_SETSIZE or higher writes past the end of the fd_set.  The
 * result of poll(2) is deliberately ignored: the wait is best-effort and
 * any real error is reported by the I/O call the caller issues next.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;

	/* A negative timeout means wait without a time limit. */
	(void)poll(&pfd, 1, -1);
}
117258065Spjd
/*
 * Receive one message into *msg, waiting for the socket to become readable
 * before each attempt and retrying when recvmsg(2) fails with EINTR.
 *
 * MSG_CMSG_CLOEXEC is passed to recvmsg(2) when the platform defines it.
 * Returns 0 on success, or -1 with errno left as set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	int flags;

	PJDLOG_ASSERT(sock >= 0);

	flags = 0;
#ifdef MSG_CMSG_CLOEXEC
	flags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, flags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
143258065Spjd
/*
 * Send one message, waiting for the socket to become writable before each
 * attempt and retrying when sendmsg(2) fails with EINTR.
 *
 * Returns 0 on success, or -1 with errno left as set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
162258065Spjd
163258065Spjdint
164258065Spjdcred_send(int sock)
165258065Spjd{
166258065Spjd	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
167258065Spjd	struct msghdr msg;
168258065Spjd	struct cmsghdr *cmsg;
169258065Spjd	struct iovec iov;
170258065Spjd	uint8_t dummy;
171258065Spjd
172258065Spjd	bzero(credbuf, sizeof(credbuf));
173258065Spjd	bzero(&msg, sizeof(msg));
174258065Spjd	bzero(&iov, sizeof(iov));
175258065Spjd
176258065Spjd	/*
177258065Spjd	 * XXX: We send one byte along with the control message, because
178258065Spjd	 *      setting msg_iov to NULL only works if this is the first
179258065Spjd	 *      packet send over the socket. Once we send some data we
180258065Spjd	 *      won't be able to send credentials anymore. This is most
181258065Spjd	 *      likely a kernel bug.
182258065Spjd	 */
183258065Spjd	dummy = 0;
184258065Spjd	iov.iov_base = &dummy;
185258065Spjd	iov.iov_len = sizeof(dummy);
186258065Spjd
187258065Spjd	msg.msg_iov = &iov;
188258065Spjd	msg.msg_iovlen = 1;
189258065Spjd	msg.msg_control = credbuf;
190258065Spjd	msg.msg_controllen = sizeof(credbuf);
191258065Spjd
192258065Spjd	cmsg = CMSG_FIRSTHDR(&msg);
193258065Spjd	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
194258065Spjd	cmsg->cmsg_level = SOL_SOCKET;
195258065Spjd	cmsg->cmsg_type = SCM_CREDS;
196258065Spjd
197258065Spjd	if (msg_send(sock, &msg) == -1)
198258065Spjd		return (-1);
199258065Spjd
200258065Spjd	return (0);
201258065Spjd}
202258065Spjd
203258065Spjdint
204258065Spjdcred_recv(int sock, struct cmsgcred *cred)
205258065Spjd{
206258065Spjd	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
207258065Spjd	struct msghdr msg;
208258065Spjd	struct cmsghdr *cmsg;
209258065Spjd	struct iovec iov;
210258065Spjd	uint8_t dummy;
211258065Spjd
212258065Spjd	bzero(credbuf, sizeof(credbuf));
213258065Spjd	bzero(&msg, sizeof(msg));
214258065Spjd	bzero(&iov, sizeof(iov));
215258065Spjd
216258065Spjd	iov.iov_base = &dummy;
217258065Spjd	iov.iov_len = sizeof(dummy);
218258065Spjd
219258065Spjd	msg.msg_iov = &iov;
220258065Spjd	msg.msg_iovlen = 1;
221258065Spjd	msg.msg_control = credbuf;
222258065Spjd	msg.msg_controllen = sizeof(credbuf);
223258065Spjd
224258065Spjd	if (msg_recv(sock, &msg) == -1)
225258065Spjd		return (-1);
226258065Spjd
227258065Spjd	cmsg = CMSG_FIRSTHDR(&msg);
228258065Spjd	if (cmsg == NULL ||
229258065Spjd	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
230258065Spjd	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
231258065Spjd		errno = EINVAL;
232258065Spjd		return (-1);
233258065Spjd	}
234258065Spjd	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
235258065Spjd
236258065Spjd	return (0);
237258065Spjd}
238258065Spjd
/*
 * Send nfds descriptors in a single message, one SCM_RIGHTS control
 * message per descriptor, accompanied by one zero data byte.
 *
 * The control buffer is heap-allocated for the duration of the call.
 * Returns 0 on success and -1 on failure; errno from the failing step is
 * preserved across the release of the buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr hdr;
	struct cmsghdr *cur;
	struct iovec vec;
	unsigned int idx;
	int saved, rv;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&hdr, 0, sizeof(hdr));

	/*
	 * XXX: A data byte is sent along with the control data; see the
	 *      comment in cred_send() for the reason.
	 */
	pad = 0;
	vec.iov_base = &pad;
	vec.iov_len = sizeof(pad);

	hdr.msg_iov = &vec;
	hdr.msg_iovlen = 1;
	hdr.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	hdr.msg_control = calloc(1, hdr.msg_controllen);
	if (hdr.msg_control == NULL)
		return (-1);

	rv = -1;

	/* Fill one control message per descriptor. */
	idx = 0;
	cur = CMSG_FIRSTHDR(&hdr);
	while (idx < nfds && cur != NULL) {
		if (msghdr_add_fd(cur, fds[idx]) == -1)
			goto out;
		idx++;
		cur = CMSG_NXTHDR(&hdr, cur);
	}

	if (msg_send(sock, &hdr) != -1)
		rv = 0;
out:
	/* free(3) must not disturb the errno reported to the caller. */
	saved = errno;
	free(hdr.msg_control);
	errno = saved;
	return (rv);
}
287258065Spjd
/*
 * Receive exactly nfds descriptors from a single message into fds[],
 * expecting one SCM_RIGHTS control message per descriptor.
 *
 * Returns 0 on success.  Returns -1 if receiving fails.  If the control
 * data does not hold exactly nfds valid descriptor messages, every
 * descriptor that can be extracted from the message is closed, errno is
 * set to EINVAL and -1 is returned.  errno is preserved across the release
 * of the control buffer.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr hdr;
	struct cmsghdr *cur;
	struct iovec vec;
	unsigned int idx;
	int saved, rv, stray;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&hdr, 0, sizeof(hdr));
	memset(&vec, 0, sizeof(vec));

	/*
	 * XXX: A data byte accompanies the control data; see the comment in
	 *      cred_send() for the reason.
	 */
	vec.iov_base = &pad;
	vec.iov_len = sizeof(pad);

	hdr.msg_iov = &vec;
	hdr.msg_iovlen = 1;
	hdr.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	hdr.msg_control = calloc(1, hdr.msg_controllen);
	if (hdr.msg_control == NULL)
		return (-1);

	rv = -1;

	if (msg_recv(sock, &hdr) == -1)
		goto out;

	/* Pull one descriptor out of each control message. */
	idx = 0;
	cur = CMSG_FIRSTHDR(&hdr);
	while (idx < nfds && cur != NULL) {
		fds[idx] = msghdr_get_fd(cur);
		if (fds[idx] < 0)
			break;
		idx++;
		cur = CMSG_NXTHDR(&hdr, cur);
	}

	if (cur == NULL && idx >= nfds) {
		rv = 0;
		goto out;
	}

	/*
	 * Too few, too many or malformed control messages.  Walk the whole
	 * control buffer again and close every descriptor found, even if a
	 * different control message (eg. SCM_CREDS) sits in between.
	 */
	for (cur = CMSG_FIRSTHDR(&hdr); cur != NULL;
	    cur = CMSG_NXTHDR(&hdr, cur)) {
		stray = msghdr_get_fd(cur);
		if (stray >= 0)
			close(stray);
	}
	errno = EINVAL;
out:
	/* free(3) must not disturb the errno reported to the caller. */
	saved = errno;
	free(hdr.msg_control);
	errno = saved;
	return (rv);
}
354258065Spjd
355258065Spjdint
356271578Spjdfd_recv(int sock, int *fds, size_t nfds)
357271578Spjd{
358271578Spjd	unsigned int i, step, j;
359271578Spjd	int ret, serrno;
360271578Spjd
361271578Spjd	if (nfds == 0 || fds == NULL) {
362271578Spjd		errno = EINVAL;
363271578Spjd		return (-1);
364271578Spjd	}
365271578Spjd
366271578Spjd	ret = i = step = 0;
367271578Spjd	while (i < nfds) {
368271578Spjd		if (PKG_MAX_SIZE < nfds - i)
369271578Spjd			step = PKG_MAX_SIZE;
370271578Spjd		else
371271578Spjd			step = nfds - i;
372271578Spjd		ret = fd_package_recv(sock, fds + i, step);
373271578Spjd		if (ret != 0) {
374271578Spjd			/* Close all received descriptors. */
375271578Spjd			serrno = errno;
376271578Spjd			for (j = 0; j < i; j++)
377271578Spjd				close(fds[j]);
378271578Spjd			errno = serrno;
379271578Spjd			break;
380271578Spjd		}
381271578Spjd		i += step;
382271578Spjd	}
383271578Spjd
384271578Spjd	return (ret);
385271578Spjd}
386271578Spjd
387271578Spjdint
388271578Spjdfd_send(int sock, const int *fds, size_t nfds)
389271578Spjd{
390271578Spjd	unsigned int i, step;
391271578Spjd	int ret;
392271578Spjd
393271578Spjd	if (nfds == 0 || fds == NULL) {
394271578Spjd		errno = EINVAL;
395271578Spjd		return (-1);
396271578Spjd	}
397271578Spjd
398271578Spjd	ret = i = step = 0;
399271578Spjd	while (i < nfds) {
400271578Spjd		if (PKG_MAX_SIZE < nfds - i)
401271578Spjd			step = PKG_MAX_SIZE;
402271578Spjd		else
403271578Spjd			step = nfds - i;
404271578Spjd		ret = fd_package_send(sock, fds + i, step);
405271578Spjd		if (ret != 0)
406271578Spjd			break;
407271578Spjd		i += step;
408271578Spjd	}
409271578Spjd
410271578Spjd	return (ret);
411271578Spjd}
412271578Spjd
/*
 * Send exactly size bytes from buf over the socket.
 *
 * Waits for the socket to become writable before each send(2), retries
 * when send(2) fails with EINTR and continues after short writes until
 * everything has been sent.
 *
 * Returns 0 on success.  Returns -1 on any other send(2) error, or with
 * errno set to ENOTCONN when send(2) reports that zero bytes were written.
 */
int
buf_send(int sock, void *buf, size_t size)
{
	unsigned char *cursor;
	ssize_t sent;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(size > 0);
	PJDLOG_ASSERT(buf != NULL);

	cursor = buf;
	do {
		fd_wait(sock, false);
		sent = send(sock, cursor, size, 0);
		if (sent == -1) {
			/* An interrupted call is retried; anything else fails. */
			if (errno != EINTR)
				return (-1);
		} else if (sent == 0) {
			errno = ENOTCONN;
			return (-1);
		} else {
			cursor += sent;
			size -= sent;
		}
	} while (size > 0);

	return (0);
}
441258065Spjd
/*
 * Receive exactly size bytes from the socket into buf.
 *
 * Waits for the socket to become readable before each recv(2), retries
 * when recv(2) fails with EINTR and continues after short reads until the
 * buffer is full.  A size of zero returns immediately.
 *
 * Returns 0 on success.  Returns -1 on any other recv(2) error, or with
 * errno set to ENOTCONN when recv(2) returns zero bytes.
 */
int
buf_recv(int sock, void *buf, size_t size)
{
	unsigned char *cursor;
	ssize_t got;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(buf != NULL);

	for (cursor = buf; size > 0; ) {
		fd_wait(sock, true);
		got = recv(sock, cursor, size, 0);
		if (got == -1) {
			/* An interrupted call is retried; anything else fails. */
			if (errno != EINTR)
				return (-1);
		} else if (got == 0) {
			errno = ENOTCONN;
			return (-1);
		} else {
			cursor += got;
			size -= got;
		}
	}

	return (0);
}
469