proto_tls.c revision 243730
1/*-
2 * Copyright (c) 2011 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 *
29 * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#1 $
30 */
31
32#ifdef HAVE_CONFIG_H
33#include "config.h"
34#endif
35
36#include <sys/param.h>	/* MAXHOSTNAMELEN */
37#include <sys/socket.h>
38
39#include <arpa/inet.h>
40
41#include <netinet/in.h>
42#include <netinet/tcp.h>
43
44#include <errno.h>
45#include <fcntl.h>
46#include <netdb.h>
47#include <signal.h>
48#include <stdbool.h>
49#include <stdint.h>
50#include <stdio.h>
51#include <string.h>
52#include <unistd.h>
53
54#include <openssl/err.h>
55#include <openssl/ssl.h>
56
57#include <compat/compat.h>
58#ifndef HAVE_CLOSEFROM
59#include <compat/closefrom.h>
60#endif
61#ifndef HAVE_STRLCPY
62#include <compat/strlcpy.h>
63#endif
64
65#include "pjdlog.h"
66#include "proto_impl.h"
67#include "sandbox.h"
68#include "subr.h"
69
/*
 * Per-connection state of the "tls" protocol.  The TLS work itself happens
 * in a separate sandbox process (see tls_call_exec_client() and
 * tls_call_exec_server()); this structure only tracks the descriptors used
 * to talk to it or, for listeners, the TCP listening socket.
 */
enum {
	TLS_CTX_MAGIC = 0x715c7		/* Marker kept in tls_magic while live. */
};

struct tls_ctx {
	int		tls_magic;	/* TLS_CTX_MAGIC; cleared by tls_close(). */
	struct proto_conn *tls_sock;	/* Socketpair end to the sandbox (NULL for listeners). */
	struct proto_conn *tls_tcp;	/* TCP listening socket (listeners only). */
	char		tls_laddr[256];	/* "tls://" local address, set by tls_accept(). */
	char		tls_raddr[256];	/* "tls://" remote address, set by tls_accept(). */
	int		tls_side;	/* One of the TLS_SIDE_* values below. */
	bool		tls_wait_called;	/* Set once send/recv may be used. */
};

/* Values stored in tls_ctx.tls_side. */
enum {
	TLS_SIDE_CLIENT = 0,		/* Outgoing (connecting) side. */
	TLS_SIDE_SERVER_LISTEN = 1,	/* Listening socket. */
	TLS_SIDE_SERVER_WORK = 2	/* Accepted connection. */
};

/* Timeout (seconds) handed to the client sandbox when the caller gave -1. */
enum {
	TLS_DEFAULT_TIMEOUT = 30
};

static void tls_close(void *ctx);
static int tls_connect_wait(void *ctx, int timeout);
88
89static void
90block(int fd)
91{
92	int flags;
93
94	flags = fcntl(fd, F_GETFL);
95	if (flags == -1)
96		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
97	flags &= ~O_NONBLOCK;
98	if (fcntl(fd, F_SETFL, flags) == -1)
99		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
100}
101
102static void
103nonblock(int fd)
104{
105	int flags;
106
107	flags = fcntl(fd, F_GETFL);
108	if (flags == -1)
109		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
110	flags |= O_NONBLOCK;
111	if (fcntl(fd, F_SETFL, flags) == -1)
112		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
113}
114
/*
 * Wait until fd becomes writable.
 *
 * timeout is in seconds; -1 waits forever.  An interrupted select(2) is
 * restarted with the full timeout.  Returns 0 when fd is writable,
 * ETIMEDOUT when the timeout expired, or the errno of a failed select(2).
 */
static int
wait_for_fd(int fd, int timeout)
{
	struct timeval tv, *tvp;
	fd_set wfds;
	int nready;

	for (;;) {
		FD_ZERO(&wfds);
		FD_SET(fd, &wfds);

		tv.tv_sec = timeout;
		tv.tv_usec = 0;
		tvp = (timeout == -1) ? NULL : &tv;

		nready = select(fd + 1, NULL, &wfds, NULL, tvp);
		if (nready == -1 && errno == EINTR)
			continue;
		if (nready == -1)
			return (errno);
		if (nready == 0)
			return (ETIMEDOUT);

		PJDLOG_ASSERT(nready > 0);
		PJDLOG_ASSERT(FD_ISSET(fd, &wfds));
		return (0);
	}
}
149
/*
 * Drain the OpenSSL error queue (ERR_get_error()), logging every entry
 * as an error.
 */
static void
ssl_log_errors(void)
{
	unsigned long code;

	for (code = ERR_get_error(); code != 0; code = ERR_get_error())
		pjdlog_error("SSL error: %s", ERR_error_string(code, NULL));
}
158
159static int
160ssl_check_error(SSL *ssl, int ret)
161{
162	int error;
163
164	error = SSL_get_error(ssl, ret);
165
166	switch (error) {
167	case SSL_ERROR_NONE:
168		return (0);
169	case SSL_ERROR_WANT_READ:
170		pjdlog_debug(2, "SSL_ERROR_WANT_READ");
171		return (-1);
172	case SSL_ERROR_WANT_WRITE:
173		pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
174		return (-1);
175	case SSL_ERROR_ZERO_RETURN:
176		pjdlog_exitx(EX_OK, "Connection closed.");
177	case SSL_ERROR_SYSCALL:
178		ssl_log_errors();
179		pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
180	case SSL_ERROR_SSL:
181		ssl_log_errors();
182		pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
183	default:
184		ssl_log_errors();
185		pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
186	}
187}
188
189static void
190tcp_recv_ssl_send(int recvfd, SSL *sendssl)
191{
192	static unsigned char buf[65536];
193	ssize_t tcpdone;
194	int sendfd, ssldone;
195
196	sendfd = SSL_get_fd(sendssl);
197	PJDLOG_ASSERT(sendfd >= 0);
198	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
199	for (;;) {
200		tcpdone = recv(recvfd, buf, sizeof(buf), 0);
201		pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
202		if (tcpdone == 0) {
203			pjdlog_debug(1, "Connection terminated.");
204			exit(0);
205		} else if (tcpdone == -1) {
206			if (errno == EINTR)
207				continue;
208			else if (errno == EAGAIN)
209				break;
210			pjdlog_exit(EX_TEMPFAIL, "recv() failed");
211		}
212		for (;;) {
213			ssldone = SSL_write(sendssl, buf, (int)tcpdone);
214			pjdlog_debug(2, "%s: send() returned %d", __func__,
215			    ssldone);
216			if (ssl_check_error(sendssl, ssldone) == -1) {
217				(void)wait_for_fd(sendfd, -1);
218				continue;
219			}
220			PJDLOG_ASSERT(ssldone == tcpdone);
221			break;
222		}
223	}
224	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
225}
226
227static void
228ssl_recv_tcp_send(SSL *recvssl, int sendfd)
229{
230	static unsigned char buf[65536];
231	unsigned char *ptr;
232	ssize_t tcpdone;
233	size_t todo;
234	int recvfd, ssldone;
235
236	recvfd = SSL_get_fd(recvssl);
237	PJDLOG_ASSERT(recvfd >= 0);
238	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
239	for (;;) {
240		ssldone = SSL_read(recvssl, buf, sizeof(buf));
241		pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
242		    ssldone);
243		if (ssl_check_error(recvssl, ssldone) == -1)
244			break;
245		todo = (size_t)ssldone;
246		ptr = buf;
247		do {
248			tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
249			pjdlog_debug(2, "%s: send() returned %zd", __func__,
250			    tcpdone);
251			if (tcpdone == 0) {
252				pjdlog_debug(1, "Connection terminated.");
253				exit(0);
254			} else if (tcpdone == -1) {
255				if (errno == EINTR || errno == ENOBUFS)
256					continue;
257				if (errno == EAGAIN) {
258					(void)wait_for_fd(sendfd, -1);
259					continue;
260				}
261				pjdlog_exit(EX_TEMPFAIL, "send() failed");
262			}
263			todo -= tcpdone;
264			ptr += tcpdone;
265		} while (todo > 0);
266	}
267	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
268}
269
270static void
271tls_loop(int sockfd, SSL *tcpssl)
272{
273	fd_set fds;
274	int maxfd, tcpfd;
275
276	tcpfd = SSL_get_fd(tcpssl);
277	PJDLOG_ASSERT(tcpfd >= 0);
278
279	for (;;) {
280		FD_ZERO(&fds);
281		FD_SET(sockfd, &fds);
282		FD_SET(tcpfd, &fds);
283		maxfd = MAX(sockfd, tcpfd);
284
285		PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
286		if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
287			if (errno == EINTR)
288				continue;
289			pjdlog_exit(EX_TEMPFAIL, "select() failed");
290		}
291		if (FD_ISSET(sockfd, &fds))
292			tcp_recv_ssl_send(sockfd, tcpssl);
293		if (FD_ISSET(tcpfd, &fds))
294			ssl_recv_tcp_send(tcpssl, sockfd);
295	}
296}
297
298static void
299tls_certificate_verify(SSL *ssl, const char *fingerprint)
300{
301	unsigned char md[EVP_MAX_MD_SIZE];
302	char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
303	char *mdstrp;
304	unsigned int i, mdsize;
305	X509 *cert;
306
307	if (fingerprint[0] == '\0') {
308		pjdlog_debug(1, "No fingerprint verification requested.");
309		return;
310	}
311
312	cert = SSL_get_peer_certificate(ssl);
313	if (cert == NULL)
314		pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
315
316	if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
317		pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
318	PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
319
320	X509_free(cert);
321
322	(void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
323	mdstrp = mdstr + strlen(mdstr);
324	for (i = 0; i < mdsize; i++) {
325		PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
326		(void)sprintf(mdstrp, "%02hhX:", md[i]);
327		mdstrp += 3;
328	}
329	/* Clear last colon. */
330	mdstrp[-1] = '\0';
331	if (strcasecmp(mdstr, fingerprint) != 0) {
332		pjdlog_exitx(EX_NOPERM,
333		    "Finger print doesn't match. Received \"%s\", expected \"%s\"",
334		    mdstr, fingerprint);
335	}
336}
337
338static void
339tls_exec_client(const char *user, int startfd, const char *srcaddr,
340    const char *dstaddr, const char *fingerprint, const char *defport,
341    int timeout, int debuglevel)
342{
343	struct proto_conn *tcp;
344	char *saddr, *daddr;
345	SSL_CTX *sslctx;
346	SSL *ssl;
347	long ret;
348	int sockfd, tcpfd;
349	uint8_t connected;
350
351	pjdlog_debug_set(debuglevel);
352	pjdlog_prefix_set("[TLS sandbox] (client) ");
353#ifdef HAVE_SETPROCTITLE
354	setproctitle("[TLS sandbox] (client) ");
355#endif
356	proto_set("tcp:port", defport);
357
358	sockfd = startfd;
359
360	/* Change tls:// to tcp://. */
361	if (srcaddr == NULL) {
362		saddr = NULL;
363	} else {
364		saddr = strdup(srcaddr);
365		if (saddr == NULL)
366			pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
367		bcopy("tcp://", saddr, 6);
368	}
369	daddr = strdup(dstaddr);
370	if (daddr == NULL)
371		pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
372	bcopy("tcp://", daddr, 6);
373
374	/* Establish TCP connection. */
375	if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
376		exit(EX_TEMPFAIL);
377
378	SSL_load_error_strings();
379	SSL_library_init();
380
381	/*
382	 * TODO: On FreeBSD we could move this below sandbox() once libc and
383	 *       libcrypto use sysctl kern.arandom to obtain random data
384	 *       instead of /dev/urandom and friends.
385	 */
386	sslctx = SSL_CTX_new(TLSv1_client_method());
387	if (sslctx == NULL)
388		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
389
390	if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
391		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
392	pjdlog_debug(1, "Privileges successfully dropped.");
393
394	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
395
396	/* Load CA certs. */
397	/* TODO */
398	//SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
399
400	ssl = SSL_new(sslctx);
401	if (ssl == NULL)
402		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
403
404	tcpfd = proto_descriptor(tcp);
405
406	block(tcpfd);
407
408	if (SSL_set_fd(ssl, tcpfd) != 1)
409		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
410
411	ret = SSL_connect(ssl);
412	ssl_check_error(ssl, (int)ret);
413
414	nonblock(sockfd);
415	nonblock(tcpfd);
416
417	tls_certificate_verify(ssl, fingerprint);
418
419	/*
420	 * The following byte is send to make proto_connect_wait() to work.
421	 */
422	connected = 1;
423	for (;;) {
424		switch (send(sockfd, &connected, sizeof(connected), 0)) {
425		case -1:
426			if (errno == EINTR || errno == ENOBUFS)
427				continue;
428			if (errno == EAGAIN) {
429				(void)wait_for_fd(sockfd, -1);
430				continue;
431			}
432			pjdlog_exit(EX_TEMPFAIL, "send() failed");
433		case 0:
434			pjdlog_debug(1, "Connection terminated.");
435			exit(0);
436		case 1:
437			break;
438		}
439		break;
440	}
441
442	tls_loop(sockfd, ssl);
443}
444
445static void
446tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
447    const char *dstaddr, int timeout)
448{
449	char *timeoutstr, *startfdstr, *debugstr;
450	int startfd;
451
452	/* Declare that we are receiver. */
453	proto_recv(sock, NULL, 0);
454
455	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
456		startfd = 3;
457	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
458		startfd = 0;
459
460	if (proto_descriptor(sock) != startfd) {
461		/* Move socketpair descriptor to descriptor number startfd. */
462		if (dup2(proto_descriptor(sock), startfd) == -1)
463			pjdlog_exit(EX_OSERR, "dup2() failed");
464		proto_close(sock);
465	} else {
466		/*
467		 * The FD_CLOEXEC is cleared by dup2(2), so when we not
468		 * call it, we have to clear it by hand in case it is set.
469		 */
470		if (fcntl(startfd, F_SETFD, 0) == -1)
471			pjdlog_exit(EX_OSERR, "fcntl() failed");
472	}
473
474	closefrom(startfd + 1);
475
476	if (asprintf(&startfdstr, "%d", startfd) == -1)
477		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
478	if (timeout == -1)
479		timeout = TLS_DEFAULT_TIMEOUT;
480	if (asprintf(&timeoutstr, "%d", timeout) == -1)
481		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
482	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
483		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
484
485	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
486	    proto_get("user"), "client", startfdstr,
487	    srcaddr == NULL ? "" : srcaddr, dstaddr,
488	    proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
489	    debugstr, NULL);
490	pjdlog_exit(EX_SOFTWARE, "execl() failed");
491}
492
493static int
494tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
495{
496	struct tls_ctx *tlsctx;
497	struct proto_conn *sock;
498	pid_t pid;
499	int error;
500
501	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
502	PJDLOG_ASSERT(dstaddr != NULL);
503	PJDLOG_ASSERT(timeout >= -1);
504	PJDLOG_ASSERT(ctxp != NULL);
505
506	if (strncmp(dstaddr, "tls://", 6) != 0)
507		return (-1);
508	if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
509		return (-1);
510
511	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
512		return (errno);
513
514#if 0
515	/*
516	 * We use rfork() with the following flags to disable SIGCHLD
517	 * delivery upon the sandbox process exit.
518	 */
519	pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
520#else
521	/*
522	 * We don't use rfork() to be able to log information about sandbox
523	 * process exiting.
524	 */
525	pid = fork();
526#endif
527	switch (pid) {
528	case -1:
529		/* Failure. */
530		error = errno;
531		proto_close(sock);
532		return (error);
533	case 0:
534		/* Child. */
535		pjdlog_prefix_set("[TLS sandbox] (client) ");
536#ifdef HAVE_SETPROCTITLE
537		setproctitle("[TLS sandbox] (client) ");
538#endif
539		tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
540		/* NOTREACHED */
541	default:
542		/* Parent. */
543		tlsctx = calloc(1, sizeof(*tlsctx));
544		if (tlsctx == NULL) {
545			error = errno;
546			proto_close(sock);
547			(void)kill(pid, SIGKILL);
548			return (error);
549		}
550		proto_send(sock, NULL, 0);
551		tlsctx->tls_sock = sock;
552		tlsctx->tls_tcp = NULL;
553		tlsctx->tls_side = TLS_SIDE_CLIENT;
554		tlsctx->tls_wait_called = false;
555		tlsctx->tls_magic = TLS_CTX_MAGIC;
556		if (timeout >= 0) {
557			error = tls_connect_wait(tlsctx, timeout);
558			if (error != 0) {
559				(void)kill(pid, SIGKILL);
560				tls_close(tlsctx);
561				return (error);
562			}
563		}
564		*ctxp = tlsctx;
565		return (0);
566	}
567}
568
569static int
570tls_connect_wait(void *ctx, int timeout)
571{
572	struct tls_ctx *tlsctx = ctx;
573	int error, sockfd;
574	uint8_t connected;
575
576	PJDLOG_ASSERT(tlsctx != NULL);
577	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
578	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
579	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
580	PJDLOG_ASSERT(!tlsctx->tls_wait_called);
581	PJDLOG_ASSERT(timeout >= 0);
582
583	sockfd = proto_descriptor(tlsctx->tls_sock);
584	error = wait_for_fd(sockfd, timeout);
585	if (error != 0)
586		return (error);
587
588	for (;;) {
589		switch (recv(sockfd, &connected, sizeof(connected),
590		    MSG_WAITALL)) {
591		case -1:
592			if (errno == EINTR || errno == ENOBUFS)
593				continue;
594			error = errno;
595			break;
596		case 0:
597			pjdlog_debug(1, "Connection terminated.");
598			error = ENOTCONN;
599			break;
600		case 1:
601			tlsctx->tls_wait_called = true;
602			break;
603		}
604		break;
605	}
606
607	return (error);
608}
609
610static int
611tls_server(const char *lstaddr, void **ctxp)
612{
613	struct proto_conn *tcp;
614	struct tls_ctx *tlsctx;
615	char *laddr;
616	int error;
617
618	if (strncmp(lstaddr, "tls://", 6) != 0)
619		return (-1);
620
621	tlsctx = malloc(sizeof(*tlsctx));
622	if (tlsctx == NULL) {
623		pjdlog_warning("Unable to allocate memory.");
624		return (ENOMEM);
625	}
626
627	laddr = strdup(lstaddr);
628	if (laddr == NULL) {
629		free(tlsctx);
630		pjdlog_warning("Unable to allocate memory.");
631		return (ENOMEM);
632	}
633	bcopy("tcp://", laddr, 6);
634
635	if (proto_server(laddr, &tcp) == -1) {
636		error = errno;
637		free(tlsctx);
638		free(laddr);
639		return (error);
640	}
641	free(laddr);
642
643	tlsctx->tls_sock = NULL;
644	tlsctx->tls_tcp = tcp;
645	tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
646	tlsctx->tls_wait_called = true;
647	tlsctx->tls_magic = TLS_CTX_MAGIC;
648	*ctxp = tlsctx;
649
650	return (0);
651}
652
653static void
654tls_exec_server(const char *user, int startfd, const char *privkey,
655    const char *cert, int debuglevel)
656{
657	SSL_CTX *sslctx;
658	SSL *ssl;
659	int sockfd, tcpfd, ret;
660
661	pjdlog_debug_set(debuglevel);
662	pjdlog_prefix_set("[TLS sandbox] (server) ");
663#ifdef HAVE_SETPROCTITLE
664	setproctitle("[TLS sandbox] (server) ");
665#endif
666
667	sockfd = startfd;
668	tcpfd = startfd + 1;
669
670	SSL_load_error_strings();
671	SSL_library_init();
672
673	sslctx = SSL_CTX_new(TLSv1_server_method());
674	if (sslctx == NULL)
675		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
676
677	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
678
679	ssl = SSL_new(sslctx);
680	if (ssl == NULL)
681		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
682
683	if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
684		ssl_log_errors();
685		pjdlog_exitx(EX_CONFIG,
686		    "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
687	}
688
689	if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
690		ssl_log_errors();
691		pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
692		    cert);
693	}
694
695	if (sandbox(user, true, "proto_tls server") != 0)
696		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
697	pjdlog_debug(1, "Privileges successfully dropped.");
698
699	nonblock(sockfd);
700	nonblock(tcpfd);
701
702	if (SSL_set_fd(ssl, tcpfd) != 1)
703		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
704
705	ret = SSL_accept(ssl);
706	ssl_check_error(ssl, ret);
707
708	tls_loop(sockfd, ssl);
709}
710
711static void
712tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
713{
714	int startfd, sockfd, tcpfd, safefd;
715	char *startfdstr, *debugstr;
716
717	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
718		startfd = 3;
719	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
720		startfd = 0;
721
722	/* Declare that we are receiver. */
723	proto_send(sock, NULL, 0);
724
725	sockfd = proto_descriptor(sock);
726	tcpfd = proto_descriptor(tcp);
727
728	safefd = MAX(sockfd, tcpfd);
729	safefd = MAX(safefd, startfd);
730	safefd++;
731
732	/* Move sockfd and tcpfd to safe numbers first. */
733	if (dup2(sockfd, safefd) == -1)
734		pjdlog_exit(EX_OSERR, "dup2() failed");
735	proto_close(sock);
736	sockfd = safefd;
737	if (dup2(tcpfd, safefd + 1) == -1)
738		pjdlog_exit(EX_OSERR, "dup2() failed");
739	proto_close(tcp);
740	tcpfd = safefd + 1;
741
742	/* Move socketpair descriptor to descriptor number startfd. */
743	if (dup2(sockfd, startfd) == -1)
744		pjdlog_exit(EX_OSERR, "dup2() failed");
745	(void)close(sockfd);
746	/* Move tcp descriptor to descriptor number startfd + 1. */
747	if (dup2(tcpfd, startfd + 1) == -1)
748		pjdlog_exit(EX_OSERR, "dup2() failed");
749	(void)close(tcpfd);
750
751	closefrom(startfd + 2);
752
753	/*
754	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
755	 * have been cleared on dup2(), but better be safe than sorry.
756	 */
757	if (fcntl(startfd, F_SETFD, 0) == -1)
758		pjdlog_exit(EX_OSERR, "fcntl() failed");
759	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
760		pjdlog_exit(EX_OSERR, "fcntl() failed");
761
762	if (asprintf(&startfdstr, "%d", startfd) == -1)
763		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
764	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
765		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
766
767	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
768	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
769	    proto_get("tls:certfile"), debugstr, NULL);
770	pjdlog_exit(EX_SOFTWARE, "execl() failed");
771}
772
773static int
774tls_accept(void *ctx, void **newctxp)
775{
776	struct tls_ctx *tlsctx = ctx;
777	struct tls_ctx *newtlsctx;
778	struct proto_conn *sock, *tcp;
779	pid_t pid;
780	int error;
781
782	PJDLOG_ASSERT(tlsctx != NULL);
783	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
784	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
785
786	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
787		return (errno);
788
789	/* Accept TCP connection. */
790	if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
791		error = errno;
792		proto_close(sock);
793		return (error);
794	}
795
796	pid = fork();
797	switch (pid) {
798	case -1:
799		/* Failure. */
800		error = errno;
801		proto_close(sock);
802		return (error);
803	case 0:
804		/* Child. */
805		pjdlog_prefix_set("[TLS sandbox] (server) ");
806#ifdef HAVE_SETPROCTITLE
807		setproctitle("[TLS sandbox] (server) ");
808#endif
809		/* Close listen socket. */
810		proto_close(tlsctx->tls_tcp);
811		tls_call_exec_server(sock, tcp);
812		/* NOTREACHED */
813		PJDLOG_ABORT("Unreachable.");
814	default:
815		/* Parent. */
816		newtlsctx = calloc(1, sizeof(*tlsctx));
817		if (newtlsctx == NULL) {
818			error = errno;
819			proto_close(sock);
820			proto_close(tcp);
821			(void)kill(pid, SIGKILL);
822			return (error);
823		}
824		proto_local_address(tcp, newtlsctx->tls_laddr,
825		    sizeof(newtlsctx->tls_laddr));
826		PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
827		bcopy("tls://", newtlsctx->tls_laddr, 6);
828		*strrchr(newtlsctx->tls_laddr, ':') = '\0';
829		proto_remote_address(tcp, newtlsctx->tls_raddr,
830		    sizeof(newtlsctx->tls_raddr));
831		PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
832		bcopy("tls://", newtlsctx->tls_raddr, 6);
833		*strrchr(newtlsctx->tls_raddr, ':') = '\0';
834		proto_close(tcp);
835		proto_recv(sock, NULL, 0);
836		newtlsctx->tls_sock = sock;
837		newtlsctx->tls_tcp = NULL;
838		newtlsctx->tls_wait_called = true;
839		newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
840		newtlsctx->tls_magic = TLS_CTX_MAGIC;
841		*newctxp = newtlsctx;
842		return (0);
843	}
844}
845
846static int
847tls_wrap(int fd, bool client, void **ctxp)
848{
849	struct tls_ctx *tlsctx;
850	struct proto_conn *sock;
851	int error;
852
853	tlsctx = calloc(1, sizeof(*tlsctx));
854	if (tlsctx == NULL)
855		return (errno);
856
857	if (proto_wrap("socketpair", client, fd, &sock) == -1) {
858		error = errno;
859		free(tlsctx);
860		return (error);
861	}
862
863	tlsctx->tls_sock = sock;
864	tlsctx->tls_tcp = NULL;
865	tlsctx->tls_wait_called = (client ? false : true);
866	tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
867	tlsctx->tls_magic = TLS_CTX_MAGIC;
868	*ctxp = tlsctx;
869
870	return (0);
871}
872
873static int
874tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
875{
876	struct tls_ctx *tlsctx = ctx;
877
878	PJDLOG_ASSERT(tlsctx != NULL);
879	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
880	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
881	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
882	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
883	PJDLOG_ASSERT(tlsctx->tls_wait_called);
884	PJDLOG_ASSERT(fd == -1);
885
886	if (proto_send(tlsctx->tls_sock, data, size) == -1)
887		return (errno);
888
889	return (0);
890}
891
892static int
893tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
894{
895	struct tls_ctx *tlsctx = ctx;
896
897	PJDLOG_ASSERT(tlsctx != NULL);
898	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
899	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
900	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
901	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
902	PJDLOG_ASSERT(tlsctx->tls_wait_called);
903	PJDLOG_ASSERT(fdp == NULL);
904
905	if (proto_recv(tlsctx->tls_sock, data, size) == -1)
906		return (errno);
907
908	return (0);
909}
910
911static int
912tls_descriptor(const void *ctx)
913{
914	const struct tls_ctx *tlsctx = ctx;
915
916	PJDLOG_ASSERT(tlsctx != NULL);
917	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
918
919	switch (tlsctx->tls_side) {
920	case TLS_SIDE_CLIENT:
921	case TLS_SIDE_SERVER_WORK:
922		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
923
924		return (proto_descriptor(tlsctx->tls_sock));
925	case TLS_SIDE_SERVER_LISTEN:
926		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
927
928		return (proto_descriptor(tlsctx->tls_tcp));
929	default:
930		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
931	}
932}
933
934static bool
935tcp_address_match(const void *ctx, const char *addr)
936{
937	const struct tls_ctx *tlsctx = ctx;
938
939	PJDLOG_ASSERT(tlsctx != NULL);
940	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
941
942	return (strcmp(tlsctx->tls_raddr, addr) == 0);
943}
944
945static void
946tls_local_address(const void *ctx, char *addr, size_t size)
947{
948	const struct tls_ctx *tlsctx = ctx;
949
950	PJDLOG_ASSERT(tlsctx != NULL);
951	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
952	PJDLOG_ASSERT(tlsctx->tls_wait_called);
953
954	switch (tlsctx->tls_side) {
955	case TLS_SIDE_CLIENT:
956		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
957
958		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
959		break;
960	case TLS_SIDE_SERVER_WORK:
961		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
962
963		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
964		break;
965	case TLS_SIDE_SERVER_LISTEN:
966		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
967
968		proto_local_address(tlsctx->tls_tcp, addr, size);
969		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
970		/* Replace tcp:// prefix with tls:// */
971		bcopy("tls://", addr, 6);
972		break;
973	default:
974		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
975	}
976}
977
978static void
979tls_remote_address(const void *ctx, char *addr, size_t size)
980{
981	const struct tls_ctx *tlsctx = ctx;
982
983	PJDLOG_ASSERT(tlsctx != NULL);
984	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
985	PJDLOG_ASSERT(tlsctx->tls_wait_called);
986
987	switch (tlsctx->tls_side) {
988	case TLS_SIDE_CLIENT:
989		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
990
991		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
992		break;
993	case TLS_SIDE_SERVER_WORK:
994		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
995
996		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
997		break;
998	case TLS_SIDE_SERVER_LISTEN:
999		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
1000
1001		proto_remote_address(tlsctx->tls_tcp, addr, size);
1002		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
1003		/* Replace tcp:// prefix with tls:// */
1004		bcopy("tls://", addr, 6);
1005		break;
1006	default:
1007		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1008	}
1009}
1010
1011static void
1012tls_close(void *ctx)
1013{
1014	struct tls_ctx *tlsctx = ctx;
1015
1016	PJDLOG_ASSERT(tlsctx != NULL);
1017	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1018
1019	if (tlsctx->tls_sock != NULL) {
1020		proto_close(tlsctx->tls_sock);
1021		tlsctx->tls_sock = NULL;
1022	}
1023	if (tlsctx->tls_tcp != NULL) {
1024		proto_close(tlsctx->tls_tcp);
1025		tlsctx->tls_tcp = NULL;
1026	}
1027	tlsctx->tls_side = 0;
1028	tlsctx->tls_magic = 0;
1029	free(tlsctx);
1030}
1031
1032static int
1033tls_exec(int argc, char *argv[])
1034{
1035
1036	PJDLOG_ASSERT(argc > 3);
1037	PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1038
1039	pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1040
1041	if (strcmp(argv[2], "client") == 0) {
1042		if (argc != 10)
1043			return (EINVAL);
1044		tls_exec_client(argv[1], atoi(argv[3]),
1045		    argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1046		    argv[7], atoi(argv[8]), atoi(argv[9]));
1047	} else if (strcmp(argv[2], "server") == 0) {
1048		if (argc != 7)
1049			return (EINVAL);
1050		tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1051		    atoi(argv[6]));
1052	}
1053	return (EINVAL);
1054}
1055
1056static struct proto tls_proto = {
1057	.prt_name = "tls",
1058	.prt_connect = tls_connect,
1059	.prt_connect_wait = tls_connect_wait,
1060	.prt_server = tls_server,
1061	.prt_accept = tls_accept,
1062	.prt_wrap = tls_wrap,
1063	.prt_send = tls_send,
1064	.prt_recv = tls_recv,
1065	.prt_descriptor = tls_descriptor,
1066	.prt_address_match = tcp_address_match,
1067	.prt_local_address = tls_local_address,
1068	.prt_remote_address = tls_remote_address,
1069	.prt_close = tls_close,
1070	.prt_exec = tls_exec
1071};
1072
1073static __constructor void
1074tls_ctor(void)
1075{
1076
1077	proto_register(&tls_proto, false);
1078}
1079