1/*	$NetBSD: unfdpass.c,v 1.12 2021/08/08 20:54:48 nia Exp $	*/
2
3/*-
4 * Copyright (c) 1998 The NetBSD Foundation, Inc.
5 * All rights reserved.
6 *
7 * This code is derived from software contributed to The NetBSD Foundation
8 * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
9 * NASA Ames Research Center.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
21 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
23 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
24 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30 * POSSIBILITY OF SUCH DAMAGE.
31 */
32
33/*
34 * Test passing of file descriptors and credentials over Unix domain sockets.
35 */
36
37#include <sys/param.h>
38#include <sys/socket.h>
39#include <sys/time.h>
40#include <sys/wait.h>
41#include <sys/un.h>
42#include <sys/uio.h>
43#include <sys/stat.h>
44
45#include <err.h>
46#include <errno.h>
47#include <fcntl.h>
48#include <signal.h>
49#include <stdio.h>
50#include <string.h>
51#include <stdlib.h>
52#include <unistd.h>
53
/* Path of the Unix-domain socket, relative to the current directory. */
#define	SOCK_NAME	"test-sock"

int	main(int, char *[]);
void	child(void);
void	catch_sigchld(int);
void	usage(char *progname);

/* Size of the buffers used to write the test files and to read them back. */
#define	FILE_SIZE	128
/*
 * Bytes of ordinary data to send along with the control messages.  While
 * this is negative (the default) the iovec code in main() and child() is
 * compiled out, so each message carries control data only.
 */
#define	MSG_SIZE	-1
/* Number of descriptors passed in the SCM_RIGHTS control message. */
#define	NFILES		24

/* Payload size of the SCM_RIGHTS control message: NFILES ints. */
#define	FDCM_DATASIZE	(sizeof(int) * NFILES)
/*
 * Payload size reserved for the SCM_CREDS control message; presumably a
 * struct sockcred with room for NGROUPS groups (NOTE(review): confirm
 * against the SOCKCREDSIZE definition).
 */
#define	CRCM_DATASIZE	(SOCKCREDSIZE(NGROUPS))

/*
 * Size of the receiver's control buffer: room for one descriptor message
 * plus one credential message.
 */
#define	MESSAGE_SIZE	(CMSG_SPACE(FDCM_DATASIZE) +			\
			 CMSG_SPACE(CRCM_DATASIZE))

/*
 * Option flags, incremented by the getopt() loop in main():
 *   chroot_rcvr    -r     receiver does chroot(".") before accept()
 *   pass_dir       -d/-D  sender passes a directory in the last slot
 *   pass_root_dir  -D     that directory is "/" instead of "."
 *   exit_early     -e     receiver exits right after fork() (early GC)
 *   exit_later     -E     receiver exits after accept() and sleep() (later GC)
 *   pass_sock      -S     sender passes its own connected socket in slot 0
 *   make_pretzel   -p     number of socketpair rounds the sender runs first
 */
int chroot_rcvr = 0;
int pass_dir = 0;
int pass_root_dir = 0;
int exit_early = 0;
int exit_later = 0;
int pass_sock = 0;
int make_pretzel = 0;
78
79/* ARGSUSED */
80int
81main(argc, argv)
82	int argc;
83	char *argv[];
84{
85#if MSG_SIZE >= 0
86	struct iovec iov;
87#endif
88	char *progname=argv[0];
89	struct msghdr msg;
90	int listensock, sock, fd, i;
91	char fname[16], buf[FILE_SIZE];
92	struct cmsghdr *cmp;
93	void *message;
94	int *files = NULL;
95	struct sockcred *sc = NULL;
96	struct sockaddr_un sun, csun;
97	socklen_t csunlen;
98	pid_t pid;
99	int ch;
100
101	message = malloc(CMSG_SPACE(MESSAGE_SIZE));
102	if (message == NULL)
103		err(1, "unable to malloc message buffer");
104	memset(message, 0, CMSG_SPACE(MESSAGE_SIZE));
105
106	while ((ch = getopt(argc, argv, "DESdepr")) != -1) {
107		switch(ch) {
108
109		case 'e':
110			exit_early++; /* test early GC */
111			break;
112
113		case 'E':
114			exit_later++; /* test later GC */
115			break;
116
117		case 'd':
118			pass_dir++;
119			break;
120
121		case 'D':
122			pass_dir++;
123			pass_root_dir++;
124			break;
125
126		case 'S':
127			pass_sock++;
128			break;
129
130		case 'r':
131			chroot_rcvr++;
132			break;
133
134		case 'p':
135			make_pretzel++;
136			break;
137
138		case '?':
139		default:
140			usage(progname);
141		}
142	}
143
144
145	/*
146	 * Create the test files.
147	 */
148	for (i = 0; i < NFILES; i++) {
149		(void) sprintf(fname, "file%d", i + 1);
150		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
151			err(1, "open %s", fname);
152		(void) sprintf(buf, "This is file %d.\n", i + 1);
153		if (write(fd, buf, strlen(buf)) != strlen(buf))
154			err(1, "write %s", fname);
155		(void) close(fd);
156	}
157
158	/*
159	 * Create the listen socket.
160	 */
161	if ((listensock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
162		err(1, "socket");
163
164	(void) unlink(SOCK_NAME);
165	(void) memset(&sun, 0, sizeof(sun));
166	sun.sun_family = AF_LOCAL;
167	(void) strcpy(sun.sun_path, SOCK_NAME);
168	sun.sun_len = SUN_LEN(&sun);
169
170	i = 1;
171	if (setsockopt(listensock, SOL_LOCAL, LOCAL_CREDS, &i, sizeof(i)) == -1)
172		err(1, "setsockopt");
173
174	if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
175		err(1, "bind");
176
177	if (listen(listensock, 1) == -1)
178		err(1, "listen");
179
180	/*
181	 * Create the sender.
182	 */
183	(void) signal(SIGCHLD, catch_sigchld);
184	pid = fork();
185	switch (pid) {
186	case -1:
187		err(1, "fork");
188		/* NOTREACHED */
189
190	case 0:
191		child();
192		/* NOTREACHED */
193	}
194
195	if (exit_early)
196		exit(0);
197
198	if (chroot_rcvr &&
199	    ((chroot(".") < 0)))
200		err(1, "chroot");
201
202	/*
203	 * Wait for the sender to connect.
204	 */
205	csunlen = sizeof(csun);
206	if ((sock = accept(listensock, (struct sockaddr *)&csun,
207	    &csunlen)) == -1)
208		err(1, "accept");
209
210	/*
211	 * Give sender a chance to run.  We will get going again
212	 * once the SIGCHLD arrives.
213	 */
214	(void) sleep(10);
215
216	if (exit_later)
217		exit(0);
218
219	/*
220	 * Grab the descriptors and credentials passed to us.
221	 */
222
223	/* Expect 2 messages; descriptors and creds. */
224	do {
225		(void) memset(&msg, 0, sizeof(msg));
226		msg.msg_control = message;
227		msg.msg_controllen = MESSAGE_SIZE;
228#if MSG_SIZE >= 0
229		iov.iov_base = buf;
230		iov.iov_len = MSG_SIZE;
231		msg.msg_iov = &iov;
232		msg.msg_iovlen = 1;
233#endif
234
235		if (recvmsg(sock, &msg, 0) == -1)
236			err(1, "recvmsg");
237
238		(void) close(sock);
239		sock = -1;
240
241		if (msg.msg_controllen == 0)
242			errx(1, "no control messages received");
243
244		if (msg.msg_flags & MSG_CTRUNC)
245			errx(1, "lost control message data");
246
247		for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
248		     cmp = CMSG_NXTHDR(&msg, cmp)) {
249			if (cmp->cmsg_level != SOL_SOCKET)
250				errx(1, "bad control message level %d",
251				    cmp->cmsg_level);
252
253			switch (cmp->cmsg_type) {
254			case SCM_RIGHTS:
255				if (cmp->cmsg_len != CMSG_LEN(FDCM_DATASIZE))
256					errx(1, "bad fd control message "
257					    "length %d", cmp->cmsg_len);
258
259				files = (int *)CMSG_DATA(cmp);
260				break;
261
262			case SCM_CREDS:
263				if (cmp->cmsg_len < CMSG_LEN(SOCKCREDSIZE(1)))
264					errx(1, "bad cred control message "
265					    "length %d", cmp->cmsg_len);
266
267				sc = (struct sockcred *)CMSG_DATA(cmp);
268				break;
269
270			default:
271				errx(1, "unexpected control message");
272				/* NOTREACHED */
273			}
274		}
275
276		/*
277		 * Read the files and print their contents.
278		 */
279		if (files == NULL)
280			warnx("didn't get fd control message");
281		else {
282			for (i = 0; i < NFILES; i++) {
283				struct stat st;
284				(void) memset(buf, 0, sizeof(buf));
285				fstat(files[i], &st);
286				if (S_ISDIR(st.st_mode)) {
287					printf("file %d is a directory\n", i+1);
288				} else if (S_ISSOCK(st.st_mode)) {
289					printf("file %d is a socket\n", i+1);
290					sock = files[i];
291				} else {
292					int c;
293					c = read (files[i], buf, sizeof(buf));
294					if (c < 0)
295						err(1, "read file %d", i + 1);
296					else if (c == 0)
297						printf("[eof on %d]\n", i + 1);
298					else
299						printf("%s", buf);
300				}
301			}
302		}
303		/*
304		 * Double-check credentials.
305		 */
306		if (sc == NULL)
307			warnx("didn't get cred control message");
308		else {
309			if (sc->sc_uid == getuid() &&
310			    sc->sc_euid == geteuid() &&
311			    sc->sc_gid == getgid() &&
312			    sc->sc_egid == getegid())
313				printf("Credentials match.\n");
314			else
315				printf("Credentials do NOT match.\n");
316		}
317	} while (sock != -1);
318
319	/*
320	 * All done!
321	 */
322	exit(0);
323}
324
/*
 * Print a usage synopsis to stderr and exit with status 1.
 * The option letters listed here must match the getopt() string in main().
 */
void
usage(char *progname)
{
	fprintf(stderr, "usage: %s [-deprDES]\n", progname);
	exit(1);
}
332
/*
 * SIGCHLD handler: reap every child that has already changed state,
 * without blocking.
 *
 * errno is saved and restored because the handler can run between a
 * failing system call in main() and the err() call that reports it.
 */
void
catch_sigchld(int sig)
{
	int saved_errno = errno;
	int status;

	(void)sig;
	while (waitpid(-1, &status, WNOHANG) > 0)
		continue;
	errno = saved_errno;
}
341
342void
343child()
344{
345#if MSG_SIZE >= 0
346	struct iovec iov;
347#endif
348	struct msghdr msg;
349	char fname[16];
350	struct cmsghdr *cmp;
351	void *fdcm;
352	int i, fd, sock, nfd, *files;
353	struct sockaddr_un sun;
354	int spair[2];
355
356	fdcm = malloc(CMSG_SPACE(FDCM_DATASIZE));
357	if (fdcm == NULL)
358		err(1, "unable to malloc fd control message");
359	memset(fdcm, 0, CMSG_SPACE(FDCM_DATASIZE));
360
361	cmp = fdcm;
362	files = (int *)CMSG_DATA(fdcm);
363
364	/*
365	 * Create socket and connect to the receiver.
366	 */
367	if ((sock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
368		errx(1, "child socket");
369
370	(void) memset(&sun, 0, sizeof(sun));
371	sun.sun_family = AF_LOCAL;
372	(void) strcpy(sun.sun_path, SOCK_NAME);
373	sun.sun_len = SUN_LEN(&sun);
374
375	if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
376		err(1, "child connect");
377
378	nfd = NFILES;
379	i = 0;
380
381	if (pass_sock) {
382		files[i++] = sock;
383	}
384
385	if (pass_dir)
386		nfd--;
387
388	/*
389	 * Open the files again, and pass them to the child
390	 * over the socket.
391	 */
392
393	for (; i < nfd; i++) {
394		(void) sprintf(fname, "file%d", i + 1);
395		if ((fd = open(fname, O_RDONLY, 0666)) == -1)
396			err(1, "child open %s", fname);
397		files[i] = fd;
398	}
399
400	if (pass_dir) {
401		char *dirname = pass_root_dir ? "/" : ".";
402
403
404		if ((fd = open(dirname, O_RDONLY, 0)) == -1) {
405			err(1, "child open directory %s", dirname);
406		}
407		files[i] = fd;
408	}
409
410	(void) memset(&msg, 0, sizeof(msg));
411	msg.msg_control = fdcm;
412	msg.msg_controllen = CMSG_LEN(FDCM_DATASIZE);
413#if MSG_SIZE >= 0
414	iov.iov_base = buf;
415	iov.iov_len = MSG_SIZE;
416	msg.msg_iov = &iov;
417	msg.msg_iovlen = 1;
418#endif
419
420	cmp = CMSG_FIRSTHDR(&msg);
421	cmp->cmsg_len = CMSG_LEN(FDCM_DATASIZE);
422	cmp->cmsg_level = SOL_SOCKET;
423	cmp->cmsg_type = SCM_RIGHTS;
424
425	while (make_pretzel > 0) {
426		if (socketpair(PF_LOCAL, SOCK_STREAM, 0, spair) < 0)
427			err(1, "socketpair");
428
429		printf("send pretzel\n");
430		if (sendmsg(spair[0], &msg, 0) < 0)
431			err(1, "child prezel sendmsg");
432
433		close(files[0]);
434		close(files[1]);
435		files[0] = spair[0];
436		files[1] = spair[1];
437		make_pretzel--;
438	}
439
440	if (sendmsg(sock, &msg, 0) == -1)
441		err(1, "child sendmsg");
442
443	/*
444	 * All done!
445	 */
446	exit(0);
447}
448