1/*
2 * Copyright (c) 1995-2003 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * 3. Neither the name of the Institute nor the names of its contributors
18 *    may be used to endorse or promote products derived from this software
19 *    without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31 * SUCH DAMAGE.
32 */
33
34#include "kx.h"
35
36RCSID("$Id$");
37
/*
 * State shared between the main loops and the signal handlers.
 * Both are written from signal handlers (childhandler, usr1handler,
 * usr2handler) as well as from normal code, so they are declared
 * volatile sig_atomic_t rather than plain int.
 *
 * nchild: number of forked session children believed to be alive.
 * donep:  set once SIGUSR1/SIGUSR2 asked us to exit when no children remain.
 */
static volatile sig_atomic_t nchild;
static volatile sig_atomic_t donep;
40
41/*
42 * Signal handler that justs waits for the children when they die.
43 */
44
45static RETSIGTYPE
46childhandler (int sig)
47{
48     pid_t pid;
49     int status;
50
51     do {
52	 pid = waitpid (-1, &status, WNOHANG|WUNTRACED);
53	 if (pid > 0 && (WIFEXITED(status) || WIFSIGNALED(status)))
54	     if (--nchild == 0 && donep)
55		 exit (0);
56     } while(pid > 0);
57     signal (SIGCHLD, childhandler);
58     SIGRETURN(0);
59}
60
61/*
62 * Handler for SIGUSR1.
63 * This signal means that we should wait until there are no children
64 * left and then exit.
65 */
66
67static RETSIGTYPE
68usr1handler (int sig)
69{
70    donep = 1;
71
72    SIGRETURN(0);
73}
74
75/*
76 * Almost the same as for SIGUSR1, except we should exit immediately
77 * if there are no active children.
78 */
79
80static RETSIGTYPE
81usr2handler (int sig)
82{
83    donep = 1;
84    if (nchild == 0)
85	exit (0);
86
87    SIGRETURN(0);
88}
89
90/*
91 * Establish authenticated connection.  Return socket or -1.
92 */
93
94static int
95connect_host (kx_context *kc)
96{
97    struct addrinfo *ai, *a;
98    struct addrinfo hints;
99    int error;
100    char portstr[NI_MAXSERV];
101    socklen_t addrlen;
102    int s = -1;
103    struct sockaddr_storage thisaddr_ss;
104    struct sockaddr *thisaddr = (struct sockaddr *)&thisaddr_ss;
105
106    memset (&hints, 0, sizeof(hints));
107    hints.ai_socktype = SOCK_STREAM;
108    hints.ai_protocol = IPPROTO_TCP;
109
110    snprintf (portstr, sizeof(portstr), "%u", ntohs(kc->port));
111
112    error = getaddrinfo (kc->host, portstr, &hints, &ai);
113    if (error) {
114	warnx ("%s: %s", kc->host, gai_strerror(error));
115	return -1;
116    }
117
118    for (a = ai; a != NULL; a = a->ai_next) {
119	s = socket (a->ai_family, a->ai_socktype, a->ai_protocol);
120	if (s < 0)
121	    continue;
122	if (connect (s, a->ai_addr, a->ai_addrlen) < 0) {
123	    warn ("connect(%s)", kc->host);
124	    close (s);
125	    continue;
126	}
127	break;
128    }
129
130    if (a == NULL) {
131	freeaddrinfo (ai);
132	return -1;
133    }
134
135    addrlen = sizeof(thisaddr_ss);
136    if (getsockname (s, thisaddr, &addrlen) < 0 ||
137	addrlen != a->ai_addrlen)
138	err(1, "getsockname(%s)", kc->host);
139    memcpy (&kc->__ss_this, thisaddr, sizeof(kc->__ss_this));
140    kc->thisaddr_len = addrlen;
141    memcpy (&kc->__ss_that, a->ai_addr, sizeof(kc->__ss_that));
142    kc->thataddr_len = a->ai_addrlen;
143    freeaddrinfo (ai);
144    if ((*kc->authenticate)(kc, s))
145	return -1;
146    return s;
147}
148
149/*
150 * Get rid of the cookie that we were sent and get the correct one
151 * from our own cookie file instead and then just copy data in both
152 * directions.
153 */
154
155static int
156passive_session (int xserver, int fd, kx_context *kc)
157{
158    if (replace_cookie (xserver, fd, XauFileName(), 1))
159	return 1;
160    else
161	return copy_encrypted (kc, xserver, fd);
162}
163
164static int
165active_session (int xserver, int fd, kx_context *kc)
166{
167    if (verify_and_remove_cookies (xserver, fd, 1))
168	return 1;
169    else
170	return copy_encrypted (kc, xserver, fd);
171}
172
173/*
174 * fork (unless debugp) and print the output that will be used by the
175 * script to capture the display, xauth cookie and pid.
176 */
177
178static void
179status_output (int debugp)
180{
181    if(debugp)
182	printf ("%u\t%s\t%s\n", (unsigned)getpid(), display, xauthfile);
183    else {
184	pid_t pid;
185
186	pid = fork();
187	if (pid < 0) {
188	    err(1, "fork");
189	} else if (pid > 0) {
190	    printf ("%u\t%s\t%s\n", (unsigned)pid, display, xauthfile);
191	    exit (0);
192	} else {
193	    fclose(stdout);
194	}
195    }
196}
197
198/*
199 * Obtain an authenticated connection on `kc'.  Send a kx message
200 * saying we are `kc->user' and want to use passive mode.  Wait for
201 * answer on that connection and fork of a child for every new
202 * connection we have to make.
203 */
204
205static int
206doit_passive (kx_context *kc)
207{
208     int otherside;
209     u_char msg[1024], *p;
210     int len;
211     uint32_t tmp;
212     const char *host = kc->host;
213
214     otherside = connect_host (kc);
215
216     if (otherside < 0)
217	 return 1;
218#if defined(SO_KEEPALIVE) && defined(HAVE_SETSOCKOPT)
219     if (kc->keepalive_flag) {
220	 int one = 1;
221
222	 setsockopt (otherside, SOL_SOCKET, SO_KEEPALIVE, (void *)&one,
223		     sizeof(one));
224     }
225#endif
226
227     p = msg;
228     *p++ = INIT;
229     len = strlen(kc->user);
230     p += kx_put_int (len, p, sizeof(msg) - 1, 4);
231     memcpy(p, kc->user, len);
232     p += len;
233     *p++ = PASSIVE | (kc->keepalive_flag ? KEEP_ALIVE : 0);
234     if (kx_write (kc, otherside, msg, p - msg) != p - msg)
235	 err (1, "write to %s", host);
236     len = kx_read (kc, otherside, msg, sizeof(msg));
237     if (len <= 0)
238	 errx (1,
239	       "error reading initial message from %s: "
240	       "this probably means it's using an old version.",
241	       host);
242     p = (u_char *)msg;
243     if (*p == ERROR) {
244	 p++;
245	 p += kx_get_int (p, &tmp, 4, 0);
246	 errx (1, "%s: %.*s", host, (int)tmp, p);
247     } else if (*p != ACK) {
248	 errx (1, "%s: strange msg %d", host, *p);
249     } else
250	 p++;
251     p += kx_get_int (p, &tmp, 4, 0);
252     memcpy(display, p, tmp);
253     display[tmp] = '\0';
254     p += tmp;
255
256     p += kx_get_int (p, &tmp, 4, 0);
257     memcpy(xauthfile, p, tmp);
258     xauthfile[tmp] = '\0';
259     p += tmp;
260
261     status_output (kc->debug_flag);
262     for (;;) {
263	 pid_t child;
264
265	 len = kx_read (kc, otherside, msg, sizeof(msg));
266	 if (len < 0)
267	     err (1, "read from %s", host);
268	 else if (len == 0)
269	     return 0;
270
271	 p = (u_char *)msg;
272	 if (*p == ERROR) {
273	     p++;
274	     p += kx_get_int (p, &tmp, 4, 0);
275	     errx (1, "%s: %.*s", host, (int)tmp, p);
276	 } else if(*p != NEW_CONN) {
277	     errx (1, "%s: strange msg %d", host, *p);
278	 } else {
279	     p++;
280	     p += kx_get_int (p, &tmp, 4, 0);
281	 }
282
283	 ++nchild;
284	 child = fork ();
285	 if (child < 0) {
286	     warn("fork");
287	     continue;
288	 } else if (child == 0) {
289	     int fd;
290	     int xserver;
291
292	     close (otherside);
293
294	     socket_set_port(kc->thataddr, htons(tmp));
295
296	     fd = socket (kc->thataddr->sa_family, SOCK_STREAM, 0);
297	     if (fd < 0)
298		 err(1, "socket");
299#if defined(TCP_NODELAY) && defined(HAVE_SETSOCKOPT)
300	     {
301		 int one = 1;
302
303		 setsockopt (fd, IPPROTO_TCP, TCP_NODELAY, (void *)&one,
304			     sizeof(one));
305	     }
306#endif
307#if defined(SO_KEEPALIVE) && defined(HAVE_SETSOCKOPT)
308	     if (kc->keepalive_flag) {
309		 int one = 1;
310
311		 setsockopt (fd, SOL_SOCKET, SO_KEEPALIVE, (void *)&one,
312			     sizeof(one));
313	     }
314#endif
315
316	     if (connect (fd, kc->thataddr, kc->thataddr_len) < 0)
317		 err(1, "connect(%s)", host);
318	     {
319		 int d = 0;
320		 char *s;
321
322		 s = getenv ("DISPLAY");
323		 if (s != NULL) {
324		     s = strchr (s, ':');
325		     if (s != NULL)
326			 d = atoi (s + 1);
327		 }
328
329		 xserver = connect_local_xsocket (d);
330		 if (xserver < 0)
331		     return 1;
332	     }
333	     return passive_session (xserver, fd, kc);
334	 } else {
335	 }
336     }
337}
338
339/*
340 * Allocate a local pseudo-xserver and wait for connections
341 */
342
343static int
344doit_active (kx_context *kc)
345{
346    int otherside;
347    int nsockets;
348    struct x_socket *sockets;
349    u_char msg[1024], *p;
350    int len;
351    int tmp, tmp2;
352    char *str;
353    int i;
354    size_t rem;
355    uint32_t other_port;
356    int error;
357    const char *host = kc->host;
358
359    otherside = connect_host (kc);
360    if (otherside < 0)
361	return 1;
362#if defined(SO_KEEPALIVE) && defined(HAVE_SETSOCKOPT)
363    if (kc->keepalive_flag) {
364	int one = 1;
365
366	setsockopt (otherside, SOL_SOCKET, SO_KEEPALIVE, (void *)&one,
367		    sizeof(one));
368    }
369#endif
370    p = msg;
371    rem = sizeof(msg);
372    *p++ = INIT;
373    --rem;
374    len = strlen(kc->user);
375    tmp = kx_put_int (len, p, rem, 4);
376    if (tmp < 0)
377	return 1;
378    p += tmp;
379    rem -= tmp;
380    memcpy(p, kc->user, len);
381    p += len;
382    rem -= len;
383    *p++ = (kc->keepalive_flag ? KEEP_ALIVE : 0);
384    --rem;
385
386    str = getenv("DISPLAY");
387    if (str == NULL || (str = strchr(str, ':')) == NULL)
388	str = ":0";
389    len = strlen (str);
390    tmp = kx_put_int (len, p, rem, 4);
391    if (tmp < 0)
392	return 1;
393    rem -= tmp;
394    p += tmp;
395    memcpy (p, str, len);
396    p += len;
397    rem -= len;
398
399    str = getenv("XAUTHORITY");
400    if (str == NULL)
401	str = "";
402    len = strlen (str);
403    tmp = kx_put_int (len, p, rem, 4);
404    if (tmp < 0)
405	return 1;
406    p += len;
407    rem -= len;
408    memcpy (p, str, len);
409    p += len;
410    rem -= len;
411
412    if (kx_write (kc, otherside, msg, p - msg) != p - msg)
413	err (1, "write to %s", host);
414
415    len = kx_read (kc, otherside, msg, sizeof(msg));
416    if (len < 0)
417	err (1, "read from %s", host);
418    p = (u_char *)msg;
419    if (*p == ERROR) {
420	uint32_t u32;
421
422	p++;
423	p += kx_get_int (p, &u32, 4, 0);
424	errx (1, "%s: %.*s", host, (int)u32, p);
425    } else if (*p != ACK) {
426	errx (1, "%s: strange msg %d", host, *p);
427    }
428
429    tmp2 = get_xsockets (&nsockets, &sockets, kc->tcp_flag);
430    if (tmp2 < 0)
431	errx(1, "Failed to open sockets");
432    display_num = tmp2;
433    if (kc->tcp_flag)
434	snprintf (display, display_size, "localhost:%u", display_num);
435    else
436	snprintf (display, display_size, ":%u", display_num);
437    error = create_and_write_cookie (xauthfile, xauthfile_size,
438				     cookie, cookie_len);
439    if (error)
440	errx(1, "failed creating cookie file: %s", strerror(error));
441
442    status_output (kc->debug_flag);
443    for (;;) {
444	fd_set fdset;
445	pid_t child;
446	int fd, thisfd = -1;
447	socklen_t zero = 0;
448
449	FD_ZERO(&fdset);
450	for (i = 0; i < nsockets; ++i) {
451	    if (sockets[i].fd >= FD_SETSIZE)
452		errx (1, "fd too large");
453	    FD_SET(sockets[i].fd, &fdset);
454	}
455	if (select(FD_SETSIZE, &fdset, NULL, NULL, NULL) <= 0)
456	    continue;
457	for (i = 0; i < nsockets; ++i)
458	    if (FD_ISSET(sockets[i].fd, &fdset)) {
459		thisfd = sockets[i].fd;
460		break;
461	    }
462	fd = accept (thisfd, NULL, &zero);
463	if (fd < 0) {
464	    if (errno == EINTR)
465		continue;
466	    else
467		err(1, "accept");
468	}
469
470	p = msg;
471	*p++ = NEW_CONN;
472	if (kx_write (kc, otherside, msg, p - msg) != p - msg)
473	    err (1, "write to %s", host);
474	len = kx_read (kc, otherside, msg, sizeof(msg));
475	if (len < 0)
476	    err (1, "read from %s", host);
477	p = (u_char *)msg;
478	if (*p == ERROR) {
479	    uint32_t val;
480
481	    p++;
482	    p += kx_get_int (p, &val, 4, 0);
483	    errx (1, "%s: %.*s", host, (int)val, p);
484	} else if (*p != NEW_CONN) {
485	    errx (1, "%s: strange msg %d", host, *p);
486	} else {
487	    p++;
488	    p += kx_get_int (p, &other_port, 4, 0);
489	}
490
491	++nchild;
492	child = fork ();
493	if (child < 0) {
494	    warn("fork");
495	    continue;
496	} else if (child == 0) {
497	    int s;
498
499	    for (i = 0; i < nsockets; ++i)
500		close (sockets[i].fd);
501
502	    close (otherside);
503
504	    socket_set_port(kc->thataddr, htons(tmp));
505
506	    s = socket (kc->thataddr->sa_family, SOCK_STREAM, 0);
507	    if (s < 0)
508		err(1, "socket");
509#if defined(TCP_NODELAY) && defined(HAVE_SETSOCKOPT)
510	    {
511		int one = 1;
512
513		setsockopt (s, IPPROTO_TCP, TCP_NODELAY, (void *)&one,
514			    sizeof(one));
515	    }
516#endif
517#if defined(SO_KEEPALIVE) && defined(HAVE_SETSOCKOPT)
518	    if (kc->keepalive_flag) {
519		int one = 1;
520
521		setsockopt (s, SOL_SOCKET, SO_KEEPALIVE, (void *)&one,
522			    sizeof(one));
523	    }
524#endif
525
526	    if (connect (s, kc->thataddr, kc->thataddr_len) < 0)
527		err(1, "connect");
528
529	    return active_session (fd, s, kc);
530	} else {
531	    close (fd);
532	}
533    }
534}
535
536/*
537 * Should we interpret `disp' as this being a passive call?
538 */
539
540static int
541check_for_passive (const char *disp)
542{
543    char local_hostname[MaxHostNameLen];
544
545    gethostname (local_hostname, sizeof(local_hostname));
546
547    return disp != NULL &&
548	(*disp == ':'
549	 || strncmp(disp, "unix", 4) == 0
550	 || strncmp(disp, "localhost", 9) == 0
551	 || strncmp(disp, local_hostname, strlen(local_hostname)) == 0);
552}
553
554/*
555 * Set up signal handlers and then call the functions.
556 */
557
558static int
559doit (kx_context *kc, int passive_flag)
560{
561    signal (SIGCHLD, childhandler);
562    signal (SIGUSR1, usr1handler);
563    signal (SIGUSR2, usr2handler);
564    if (passive_flag)
565	return doit_passive (kc);
566    else
567	return doit_active  (kc);
568}
569
#ifdef KRB5

/*
 * Start a Kerberos V5 authenticated kx connection to `host'.
 *
 * A kx_context is prepared with krb5_make_context() and context_set(),
 * run through doit(), and released with context_destroy().  Returns
 * doit()'s result.
 */

static int
doit_v5 (const char *host, int port, const char *user,
	 int passive_flag, int debug_flag, int keepalive_flag, int tcp_flag)
{
    kx_context ctx;
    int result;

    krb5_make_context (&ctx);
    context_set (&ctx, host, user, port,
		 debug_flag, keepalive_flag, tcp_flag);
    result = doit (&ctx, passive_flag);
    context_destroy (&ctx);
    return result;
}
#endif /* KRB5 */
592
593/*
594 * Variables set from the arguments
595 */
596
597#ifdef KRB5
598static int use_v5		= -1;
599#endif
600static char *port_str		= NULL;
601static const char *user		= NULL;
602static int tcp_flag		= 0;
603static int passive_flag		= 0;
604static int keepalive_flag	= 1;
605static int debug_flag		= 0;
606static int version_flag		= 0;
607static int help_flag		= 0;
608
609struct getargs args[] = {
610#ifdef KRB5
611    { "krb5",	'5', arg_flag,		&use_v5,	"Use Kerberos V5",
612      NULL },
613#endif
614    { "port",	'p', arg_string,	&port_str,	"Use this port",
615      "number-of-service" },
616    { "user",	'l', arg_string,	&user,		"Run as this user",
617      NULL },
618    { "tcp",	't', arg_flag,		&tcp_flag,
619      "Use a TCP connection for X11" },
620    { "passive", 'P', arg_flag,		&passive_flag,
621      "Force a passive connection" },
622    { "keepalive", 'k', arg_negative_flag, &keepalive_flag,
623      "disable keep-alives" },
624    { "debug",	'd',	arg_flag,	&debug_flag,
625      "Enable debug information" },
626    { "version", 0,  arg_flag,		&version_flag,	"Print version",
627      NULL },
628    { "help",	 0,  arg_flag,		&help_flag,	NULL,
629      NULL }
630};
631
632static void
633usage(int ret)
634{
635    arg_printusage (args,
636		    sizeof(args) / sizeof(args[0]),
637		    NULL,
638		    "host");
639    exit (ret);
640}
641
642/*
643 * kx - forward an x-connection over a kerberos-encrypted channel.
644 */
645
646int
647main(int argc, char **argv)
648{
649    int port	= 0;
650    int optidx	= 0;
651    int ret	= 1;
652    char *host	= NULL;
653
654    setprogname (argv[0]);
655
656    if (getarg (args, sizeof(args) / sizeof(args[0]), argc, argv,
657		&optidx))
658	usage (1);
659
660    if (help_flag)
661	usage (0);
662
663    if (version_flag) {
664	print_version (NULL);
665	return 0;
666    }
667
668    if (optidx != argc - 1)
669	usage (1);
670
671    host = argv[optidx];
672
673    if (port_str) {
674	struct servent *s = roken_getservbyname (port_str, "tcp");
675
676	if (s)
677	    port = s->s_port;
678	else {
679	    char *ptr;
680
681	    port = strtol (port_str, &ptr, 10);
682	    if (port == 0 && ptr == port_str)
683		errx (1, "Bad port `%s'", port_str);
684	    port = htons(port);
685	}
686    }
687
688    if (user == NULL) {
689	user = get_default_username ();
690	if (user == NULL)
691	    errx (1, "who are you?");
692    }
693
694    if (!passive_flag)
695	passive_flag = check_for_passive (getenv("DISPLAY"));
696
697#if defined(HAVE_KERNEL_ENABLE_DEBUG)
698    if (krb_debug_flag)
699	krb_enable_debug ();
700#endif
701
702#ifdef KRB5
703    if (ret && use_v5) {
704	if (port == 0)
705	    port = krb5_getportbyname(NULL, "kx", "tcp", KX_PORT);
706	ret = doit_v5 (host, port, user,
707		       passive_flag, debug_flag, keepalive_flag, tcp_flag);
708    }
709#endif
710    return ret;
711}
712