1/*	$NetBSD: unfdpass.c,v 1.9 2008/02/29 16:28:12 ad Exp $	*/
2
3/*-
4 * Copyright (c) 1998 The NetBSD Foundation, Inc.
5 * All rights reserved.
6 *
7 * This code is derived from software contributed to The NetBSD Foundation
8 * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
9 * NASA Ames Research Center.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 *    notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 *    notice, this list of conditions and the following disclaimer in the
18 *    documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
21 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
23 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
24 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30 * POSSIBILITY OF SUCH DAMAGE.
31 */
32
33/*
34 * Test passing of file descriptors and credentials over Unix domain sockets.
35 */
36
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
52
/* Rendezvous path of the Unix domain socket, relative to the cwd. */
#define	SOCK_NAME	"test-sock"

void	child(void);
void	catch_sigchld(int);
void	usage(char *progname);

/* Size of the buffer used to write, and later read back, each test file. */
#define	FILE_SIZE	128
/*
 * Length of the data payload sent alongside the control messages.  A
 * negative value (the default) compiles out the iovec setup guarded by
 * "#if MSG_SIZE >= 0", so only ancillary data is transferred.  The value
 * is parenthesized so the negative constant expands safely anywhere.
 */
#define	MSG_SIZE	(-1)
/* Number of descriptors carried by the SCM_RIGHTS message. */
#define	NFILES		24

/* Payload sizes of the SCM_RIGHTS and SCM_CREDS control messages. */
#define	FDCM_DATASIZE	(sizeof(int) * NFILES)
#define	CRCM_DATASIZE	(SOCKCREDSIZE(NGROUPS))

/* Receive buffer size: room for both control messages, each padded. */
#define	MESSAGE_SIZE	(CMSG_SPACE(FDCM_DATASIZE) +			\
			 CMSG_SPACE(CRCM_DATASIZE))

/* Test switches, set from the command line in main(). */
int chroot_rcvr = 0;	/* -r: receiver chroot(".")s before accepting */
int pass_dir = 0;	/* -d/-D: also pass a directory descriptor */
int pass_root_dir = 0;	/* -D: that directory is "/" instead of "." */
int exit_early = 0;	/* -e: receiver exits right after forking */
int exit_later = 0;	/* -E: receiver exits after accept, before recvmsg */
int pass_sock = 0;	/* -S: sender passes its own connected socket */
int make_pretzel = 0;	/* -p: number of extra socketpair sends in child() */
77
78/* ARGSUSED */
79int
80main(argc, argv)
81	int argc;
82	char *argv[];
83{
84#if MSG_SIZE >= 0
85	struct iovec iov;
86#endif
87	char *progname=argv[0];
88	struct msghdr msg;
89	int listensock, sock, fd, i;
90	char fname[16], buf[FILE_SIZE];
91	struct cmsghdr *cmp;
92	void *message;
93	int *files = NULL;
94	struct sockcred *sc = NULL;
95	struct sockaddr_un sun, csun;
96	socklen_t csunlen;
97	pid_t pid;
98	int ch;
99
100	message = malloc(CMSG_SPACE(MESSAGE_SIZE));
101	if (message == NULL)
102		err(1, "unable to malloc message buffer");
103	memset(message, 0, CMSG_SPACE(MESSAGE_SIZE));
104
105	while ((ch = getopt(argc, argv, "DESdepr")) != -1) {
106		switch(ch) {
107
108		case 'e':
109			exit_early++; /* test early GC */
110			break;
111
112		case 'E':
113			exit_later++; /* test later GC */
114			break;
115
116		case 'd':
117			pass_dir++;
118			break;
119
120		case 'D':
121			pass_dir++;
122			pass_root_dir++;
123			break;
124
125		case 'S':
126			pass_sock++;
127			break;
128
129		case 'r':
130			chroot_rcvr++;
131			break;
132
133		case 'p':
134			make_pretzel++;
135			break;
136
137		case '?':
138		default:
139			usage(progname);
140		}
141	}
142
143
144	/*
145	 * Create the test files.
146	 */
147	for (i = 0; i < NFILES; i++) {
148		(void) sprintf(fname, "file%d", i + 1);
149		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
150			err(1, "open %s", fname);
151		(void) sprintf(buf, "This is file %d.\n", i + 1);
152		if (write(fd, buf, strlen(buf)) != strlen(buf))
153			err(1, "write %s", fname);
154		(void) close(fd);
155	}
156
157	/*
158	 * Create the listen socket.
159	 */
160	if ((listensock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
161		err(1, "socket");
162
163	(void) unlink(SOCK_NAME);
164	(void) memset(&sun, 0, sizeof(sun));
165	sun.sun_family = AF_LOCAL;
166	(void) strcpy(sun.sun_path, SOCK_NAME);
167	sun.sun_len = SUN_LEN(&sun);
168
169	i = 1;
170	if (setsockopt(listensock, 0, LOCAL_CREDS, &i, sizeof(i)) == -1)
171		err(1, "setsockopt");
172
173	if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
174		err(1, "bind");
175
176	if (listen(listensock, 1) == -1)
177		err(1, "listen");
178
179	/*
180	 * Create the sender.
181	 */
182	(void) signal(SIGCHLD, catch_sigchld);
183	pid = fork();
184	switch (pid) {
185	case -1:
186		err(1, "fork");
187		/* NOTREACHED */
188
189	case 0:
190		child();
191		/* NOTREACHED */
192	}
193
194	if (exit_early)
195		exit(0);
196
197	if (chroot_rcvr &&
198	    ((chroot(".") < 0)))
199		err(1, "chroot");
200
201	/*
202	 * Wait for the sender to connect.
203	 */
204	csunlen = sizeof(csun);
205	if ((sock = accept(listensock, (struct sockaddr *)&csun,
206	    &csunlen)) == -1)
207		err(1, "accept");
208
209	/*
210	 * Give sender a chance to run.  We will get going again
211	 * once the SIGCHLD arrives.
212	 */
213	(void) sleep(10);
214
215	if (exit_later)
216		exit(0);
217
218	/*
219	 * Grab the descriptors and credentials passed to us.
220	 */
221
222	/* Expect 2 messages; descriptors and creds. */
223	do {
224		(void) memset(&msg, 0, sizeof(msg));
225		msg.msg_control = message;
226		msg.msg_controllen = MESSAGE_SIZE;
227#if MSG_SIZE >= 0
228		iov.iov_base = buf;
229		iov.iov_len = MSG_SIZE;
230		msg.msg_iov = &iov;
231		msg.msg_iovlen = 1;
232#endif
233
234		if (recvmsg(sock, &msg, 0) == -1)
235			err(1, "recvmsg");
236
237		(void) close(sock);
238		sock = -1;
239
240		if (msg.msg_controllen == 0)
241			errx(1, "no control messages received");
242
243		if (msg.msg_flags & MSG_CTRUNC)
244			errx(1, "lost control message data");
245
246		for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
247		     cmp = CMSG_NXTHDR(&msg, cmp)) {
248			if (cmp->cmsg_level != SOL_SOCKET)
249				errx(1, "bad control message level %d",
250				    cmp->cmsg_level);
251
252			switch (cmp->cmsg_type) {
253			case SCM_RIGHTS:
254				if (cmp->cmsg_len != CMSG_LEN(FDCM_DATASIZE))
255					errx(1, "bad fd control message "
256					    "length %d", cmp->cmsg_len);
257
258				files = (int *)CMSG_DATA(cmp);
259				break;
260
261			case SCM_CREDS:
262				if (cmp->cmsg_len < CMSG_LEN(SOCKCREDSIZE(1)))
263					errx(1, "bad cred control message "
264					    "length %d", cmp->cmsg_len);
265
266				sc = (struct sockcred *)CMSG_DATA(cmp);
267				break;
268
269			default:
270				errx(1, "unexpected control message");
271				/* NOTREACHED */
272			}
273		}
274
275		/*
276		 * Read the files and print their contents.
277		 */
278		if (files == NULL)
279			warnx("didn't get fd control message");
280		else {
281			for (i = 0; i < NFILES; i++) {
282				struct stat st;
283				(void) memset(buf, 0, sizeof(buf));
284				fstat(files[i], &st);
285				if (S_ISDIR(st.st_mode)) {
286					printf("file %d is a directory\n", i+1);
287				} else if (S_ISSOCK(st.st_mode)) {
288					printf("file %d is a socket\n", i+1);
289					sock = files[i];
290				} else {
291					int c;
292					c = read (files[i], buf, sizeof(buf));
293					if (c < 0)
294						err(1, "read file %d", i + 1);
295					else if (c == 0)
296						printf("[eof on %d]\n", i + 1);
297					else
298						printf("%s", buf);
299				}
300			}
301		}
302		/*
303		 * Double-check credentials.
304		 */
305		if (sc == NULL)
306			warnx("didn't get cred control message");
307		else {
308			if (sc->sc_uid == getuid() &&
309			    sc->sc_euid == geteuid() &&
310			    sc->sc_gid == getgid() &&
311			    sc->sc_egid == getegid())
312				printf("Credentials match.\n");
313			else
314				printf("Credentials do NOT match.\n");
315		}
316	} while (sock != -1);
317
318	/*
319	 * All done!
320	 */
321	exit(0);
322}
323
/*
 * Print a usage message and exit with status 1.  The option list matches
 * every flag accepted by main()'s getopt string "DESdepr", including -p,
 * which the previous message omitted.
 */
void
usage(char *progname)
{
	fprintf(stderr, "usage: %s [-deprDES]\n", progname);
	exit(1);
}
331
/*
 * SIGCHLD handler: reap every child that has already exited, without
 * blocking.  A blocking wait() could hang the handler if SIGCHLD was
 * raised for a stopped child rather than an exited one.
 *
 * errno is saved and restored because the handler can run between a
 * failing system call in main() and the err() call that reports it.
 * waitpid() would otherwise overwrite that errno, for example with ECHILD.
 */
void
catch_sigchld(int sig)
{
	int saved_errno = errno;
	int status;

	(void) sig;
	while (waitpid(-1, &status, WNOHANG) > 0)
		continue;
	errno = saved_errno;
}
340
341void
342child()
343{
344#if MSG_SIZE >= 0
345	struct iovec iov;
346#endif
347	struct msghdr msg;
348	char fname[16];
349	struct cmsghdr *cmp;
350	void *fdcm;
351	int i, fd, sock, nfd, *files;
352	struct sockaddr_un sun;
353	int spair[2];
354
355	fdcm = malloc(CMSG_SPACE(FDCM_DATASIZE));
356	if (fdcm == NULL)
357		err(1, "unable to malloc fd control message");
358	memset(fdcm, 0, CMSG_SPACE(FDCM_DATASIZE));
359
360	cmp = fdcm;
361	files = (int *)CMSG_DATA(fdcm);
362
363	/*
364	 * Create socket and connect to the receiver.
365	 */
366	if ((sock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
367		errx(1, "child socket");
368
369	(void) memset(&sun, 0, sizeof(sun));
370	sun.sun_family = AF_LOCAL;
371	(void) strcpy(sun.sun_path, SOCK_NAME);
372	sun.sun_len = SUN_LEN(&sun);
373
374	if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
375		err(1, "child connect");
376
377	nfd = NFILES;
378	i = 0;
379
380	if (pass_sock) {
381		files[i++] = sock;
382	}
383
384	if (pass_dir)
385		nfd--;
386
387	/*
388	 * Open the files again, and pass them to the child
389	 * over the socket.
390	 */
391
392	for (; i < nfd; i++) {
393		(void) sprintf(fname, "file%d", i + 1);
394		if ((fd = open(fname, O_RDONLY, 0666)) == -1)
395			err(1, "child open %s", fname);
396		files[i] = fd;
397	}
398
399	if (pass_dir) {
400		char *dirname = pass_root_dir ? "/" : ".";
401
402
403		if ((fd = open(dirname, O_RDONLY, 0)) == -1) {
404			err(1, "child open directory %s", dirname);
405		}
406		files[i] = fd;
407	}
408
409	(void) memset(&msg, 0, sizeof(msg));
410	msg.msg_control = fdcm;
411	msg.msg_controllen = CMSG_LEN(FDCM_DATASIZE);
412#if MSG_SIZE >= 0
413	iov.iov_base = buf;
414	iov.iov_len = MSG_SIZE;
415	msg.msg_iov = &iov;
416	msg.msg_iovlen = 1;
417#endif
418
419	cmp = CMSG_FIRSTHDR(&msg);
420	cmp->cmsg_len = CMSG_LEN(FDCM_DATASIZE);
421	cmp->cmsg_level = SOL_SOCKET;
422	cmp->cmsg_type = SCM_RIGHTS;
423
424	while (make_pretzel > 0) {
425		if (socketpair(PF_LOCAL, SOCK_STREAM, 0, spair) < 0)
426			err(1, "socketpair");
427
428		printf("send pretzel\n");
429		if (sendmsg(spair[0], &msg, 0) < 0)
430			err(1, "child prezel sendmsg");
431
432		close(files[0]);
433		close(files[1]);
434		files[0] = spair[0];
435		files[1] = spair[1];
436		make_pretzel--;
437	}
438
439	if (sendmsg(sock, &msg, 0) == -1)
440		err(1, "child sendmsg");
441
442	/*
443	 * All done!
444	 */
445	exit(0);
446}
447