1/*-
2 * Copyright (c) 2011 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 *
29 * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#2 $
30 */
31
32#include <config/config.h>
33
34#include <sys/param.h>	/* MAXHOSTNAMELEN */
35#include <sys/socket.h>
36
37#include <arpa/inet.h>
38
39#include <netinet/in.h>
40#include <netinet/tcp.h>
41
42#include <errno.h>
43#include <fcntl.h>
44#include <netdb.h>
45#include <signal.h>
46#include <stdbool.h>
47#include <stdint.h>
48#include <stdio.h>
49#include <string.h>
50#include <unistd.h>
51
52#include <openssl/err.h>
53#include <openssl/ssl.h>
54
55#include <compat/compat.h>
56#ifndef HAVE_CLOSEFROM
57#include <compat/closefrom.h>
58#endif
59#ifndef HAVE_STRLCPY
60#include <compat/strlcpy.h>
61#endif
62
63#include "pjdlog.h"
64#include "proto_impl.h"
65#include "sandbox.h"
66#include "subr.h"
67
/* Value kept in tls_magic while a context is valid. */
#define	TLS_CTX_MAGIC	0x715c7

/* Role of a struct tls_ctx, stored in its tls_side field. */
enum {
	TLS_SIDE_CLIENT = 0,		/* tls_connect() / client tls_wrap() */
	TLS_SIDE_SERVER_LISTEN = 1,	/* tls_server() listening context */
	TLS_SIDE_SERVER_WORK = 2	/* tls_accept() / server tls_wrap() */
};

/*
 * Per-connection state kept by the process that uses the "tls" protocol.
 * The actual TLS work happens in a separate sandbox process; this side
 * only talks to it over a socketpair.
 */
struct tls_ctx {
	int		tls_magic;	/* TLS_CTX_MAGIC while valid */
	struct proto_conn *tls_sock;	/* socketpair to the TLS sandbox */
	struct proto_conn *tls_tcp;	/* listening TCP socket (LISTEN only) */
	char		tls_laddr[256];	/* local "tls://host" (SERVER_WORK) */
	char		tls_raddr[256];	/* remote "tls://host" (SERVER_WORK) */
	int		tls_side;	/* one of TLS_SIDE_* */
	bool		tls_wait_called; /* ready for send/recv/address calls */
};

/* Timeout used when the caller passes -1 (presumably seconds). */
#define	TLS_DEFAULT_TIMEOUT	30

static int tls_connect_wait(void *ctx, int timeout);
static void tls_close(void *ctx);
86
87static void
88block(int fd)
89{
90	int flags;
91
92	flags = fcntl(fd, F_GETFL);
93	if (flags == -1)
94		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
95	flags &= ~O_NONBLOCK;
96	if (fcntl(fd, F_SETFL, flags) == -1)
97		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
98}
99
100static void
101nonblock(int fd)
102{
103	int flags;
104
105	flags = fcntl(fd, F_GETFL);
106	if (flags == -1)
107		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
108	flags |= O_NONBLOCK;
109	if (fcntl(fd, F_SETFL, flags) == -1)
110		pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
111}
112
113static int
114wait_for_fd(int fd, int timeout)
115{
116	struct timeval tv;
117	fd_set fdset;
118	int error, ret;
119
120	error = 0;
121
122	for (;;) {
123		FD_ZERO(&fdset);
124		FD_SET(fd, &fdset);
125
126		tv.tv_sec = timeout;
127		tv.tv_usec = 0;
128
129		ret = select(fd + 1, NULL, &fdset, NULL,
130		    timeout == -1 ? NULL : &tv);
131		if (ret == 0) {
132			error = ETIMEDOUT;
133			break;
134		} else if (ret == -1) {
135			if (errno == EINTR)
136				continue;
137			error = errno;
138			break;
139		}
140		PJDLOG_ASSERT(ret > 0);
141		PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
142		break;
143	}
144
145	return (error);
146}
147
/*
 * Drain OpenSSL's per-thread error queue, logging every entry with
 * pjdlog_error().
 *
 * ERR_error_string_n() formats into a local buffer. ERR_error_string()
 * with a NULL buffer would use a shared static buffer, which is not
 * thread-safe; OpenSSL recommends the _n variant.
 */
static void
ssl_log_errors(void)
{
	char errstr[256];
	unsigned long error;

	while ((error = ERR_get_error()) != 0) {
		ERR_error_string_n(error, errstr, sizeof(errstr));
		pjdlog_error("SSL error: %s", errstr);
	}
}
156
157static int
158ssl_check_error(SSL *ssl, int ret)
159{
160	int error;
161
162	error = SSL_get_error(ssl, ret);
163
164	switch (error) {
165	case SSL_ERROR_NONE:
166		return (0);
167	case SSL_ERROR_WANT_READ:
168		pjdlog_debug(2, "SSL_ERROR_WANT_READ");
169		return (-1);
170	case SSL_ERROR_WANT_WRITE:
171		pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
172		return (-1);
173	case SSL_ERROR_ZERO_RETURN:
174		pjdlog_exitx(EX_OK, "Connection closed.");
175	case SSL_ERROR_SYSCALL:
176		ssl_log_errors();
177		pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
178	case SSL_ERROR_SSL:
179		ssl_log_errors();
180		pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
181	default:
182		ssl_log_errors();
183		pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
184	}
185}
186
187static void
188tcp_recv_ssl_send(int recvfd, SSL *sendssl)
189{
190	static unsigned char buf[65536];
191	ssize_t tcpdone;
192	int sendfd, ssldone;
193
194	sendfd = SSL_get_fd(sendssl);
195	PJDLOG_ASSERT(sendfd >= 0);
196	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
197	for (;;) {
198		tcpdone = recv(recvfd, buf, sizeof(buf), 0);
199		pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
200		if (tcpdone == 0) {
201			pjdlog_debug(1, "Connection terminated.");
202			exit(0);
203		} else if (tcpdone == -1) {
204			if (errno == EINTR)
205				continue;
206			else if (errno == EAGAIN)
207				break;
208			pjdlog_exit(EX_TEMPFAIL, "recv() failed");
209		}
210		for (;;) {
211			ssldone = SSL_write(sendssl, buf, (int)tcpdone);
212			pjdlog_debug(2, "%s: send() returned %d", __func__,
213			    ssldone);
214			if (ssl_check_error(sendssl, ssldone) == -1) {
215				(void)wait_for_fd(sendfd, -1);
216				continue;
217			}
218			PJDLOG_ASSERT(ssldone == tcpdone);
219			break;
220		}
221	}
222	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
223}
224
225static void
226ssl_recv_tcp_send(SSL *recvssl, int sendfd)
227{
228	static unsigned char buf[65536];
229	unsigned char *ptr;
230	ssize_t tcpdone;
231	size_t todo;
232	int recvfd, ssldone;
233
234	recvfd = SSL_get_fd(recvssl);
235	PJDLOG_ASSERT(recvfd >= 0);
236	pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
237	for (;;) {
238		ssldone = SSL_read(recvssl, buf, sizeof(buf));
239		pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
240		    ssldone);
241		if (ssl_check_error(recvssl, ssldone) == -1)
242			break;
243		todo = (size_t)ssldone;
244		ptr = buf;
245		do {
246			tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
247			pjdlog_debug(2, "%s: send() returned %zd", __func__,
248			    tcpdone);
249			if (tcpdone == 0) {
250				pjdlog_debug(1, "Connection terminated.");
251				exit(0);
252			} else if (tcpdone == -1) {
253				if (errno == EINTR || errno == ENOBUFS)
254					continue;
255				if (errno == EAGAIN) {
256					(void)wait_for_fd(sendfd, -1);
257					continue;
258				}
259				pjdlog_exit(EX_TEMPFAIL, "send() failed");
260			}
261			todo -= tcpdone;
262			ptr += tcpdone;
263		} while (todo > 0);
264	}
265	pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
266}
267
268static void
269tls_loop(int sockfd, SSL *tcpssl)
270{
271	fd_set fds;
272	int maxfd, tcpfd;
273
274	tcpfd = SSL_get_fd(tcpssl);
275	PJDLOG_ASSERT(tcpfd >= 0);
276
277	for (;;) {
278		FD_ZERO(&fds);
279		FD_SET(sockfd, &fds);
280		FD_SET(tcpfd, &fds);
281		maxfd = MAX(sockfd, tcpfd);
282
283		PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
284		if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
285			if (errno == EINTR)
286				continue;
287			pjdlog_exit(EX_TEMPFAIL, "select() failed");
288		}
289		if (FD_ISSET(sockfd, &fds))
290			tcp_recv_ssl_send(sockfd, tcpssl);
291		if (FD_ISSET(tcpfd, &fds))
292			ssl_recv_tcp_send(tcpssl, sockfd);
293	}
294}
295
296static void
297tls_certificate_verify(SSL *ssl, const char *fingerprint)
298{
299	unsigned char md[EVP_MAX_MD_SIZE];
300	char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
301	char *mdstrp;
302	unsigned int i, mdsize;
303	X509 *cert;
304
305	if (fingerprint[0] == '\0') {
306		pjdlog_debug(1, "No fingerprint verification requested.");
307		return;
308	}
309
310	cert = SSL_get_peer_certificate(ssl);
311	if (cert == NULL)
312		pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
313
314	if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
315		pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
316	PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
317
318	X509_free(cert);
319
320	(void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
321	mdstrp = mdstr + strlen(mdstr);
322	for (i = 0; i < mdsize; i++) {
323		PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
324		(void)sprintf(mdstrp, "%02hhX:", md[i]);
325		mdstrp += 3;
326	}
327	/* Clear last colon. */
328	mdstrp[-1] = '\0';
329	if (strcasecmp(mdstr, fingerprint) != 0) {
330		pjdlog_exitx(EX_NOPERM,
331		    "Finger print doesn't match. Received \"%s\", expected \"%s\"",
332		    mdstr, fingerprint);
333	}
334}
335
336static void
337tls_exec_client(const char *user, int startfd, const char *srcaddr,
338    const char *dstaddr, const char *fingerprint, const char *defport,
339    int timeout, int debuglevel)
340{
341	struct proto_conn *tcp;
342	char *saddr, *daddr;
343	SSL_CTX *sslctx;
344	SSL *ssl;
345	long ret;
346	int sockfd, tcpfd;
347	uint8_t connected;
348
349	pjdlog_debug_set(debuglevel);
350	pjdlog_prefix_set("[TLS sandbox] (client) ");
351#ifdef HAVE_SETPROCTITLE
352	setproctitle("[TLS sandbox] (client) ");
353#endif
354	proto_set("tcp:port", defport);
355
356	sockfd = startfd;
357
358	/* Change tls:// to tcp://. */
359	if (srcaddr == NULL) {
360		saddr = NULL;
361	} else {
362		saddr = strdup(srcaddr);
363		if (saddr == NULL)
364			pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
365		bcopy("tcp://", saddr, 6);
366	}
367	daddr = strdup(dstaddr);
368	if (daddr == NULL)
369		pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
370	bcopy("tcp://", daddr, 6);
371
372	/* Establish TCP connection. */
373	if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
374		exit(EX_TEMPFAIL);
375
376	SSL_load_error_strings();
377	SSL_library_init();
378
379	/*
380	 * TODO: On FreeBSD we could move this below sandbox() once libc and
381	 *       libcrypto use sysctl kern.arandom to obtain random data
382	 *       instead of /dev/urandom and friends.
383	 */
384	sslctx = SSL_CTX_new(TLSv1_client_method());
385	if (sslctx == NULL)
386		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
387
388	if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
389		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
390	pjdlog_debug(1, "Privileges successfully dropped.");
391
392	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
393
394	/* Load CA certs. */
395	/* TODO */
396	//SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
397
398	ssl = SSL_new(sslctx);
399	if (ssl == NULL)
400		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
401
402	tcpfd = proto_descriptor(tcp);
403
404	block(tcpfd);
405
406	if (SSL_set_fd(ssl, tcpfd) != 1)
407		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
408
409	ret = SSL_connect(ssl);
410	ssl_check_error(ssl, (int)ret);
411
412	nonblock(sockfd);
413	nonblock(tcpfd);
414
415	tls_certificate_verify(ssl, fingerprint);
416
417	/*
418	 * The following byte is send to make proto_connect_wait() to work.
419	 */
420	connected = 1;
421	for (;;) {
422		switch (send(sockfd, &connected, sizeof(connected), 0)) {
423		case -1:
424			if (errno == EINTR || errno == ENOBUFS)
425				continue;
426			if (errno == EAGAIN) {
427				(void)wait_for_fd(sockfd, -1);
428				continue;
429			}
430			pjdlog_exit(EX_TEMPFAIL, "send() failed");
431		case 0:
432			pjdlog_debug(1, "Connection terminated.");
433			exit(0);
434		case 1:
435			break;
436		}
437		break;
438	}
439
440	tls_loop(sockfd, ssl);
441}
442
443static void
444tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
445    const char *dstaddr, int timeout)
446{
447	char *timeoutstr, *startfdstr, *debugstr;
448	int startfd;
449
450	/* Declare that we are receiver. */
451	proto_recv(sock, NULL, 0);
452
453	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
454		startfd = 3;
455	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
456		startfd = 0;
457
458	if (proto_descriptor(sock) != startfd) {
459		/* Move socketpair descriptor to descriptor number startfd. */
460		if (dup2(proto_descriptor(sock), startfd) == -1)
461			pjdlog_exit(EX_OSERR, "dup2() failed");
462		proto_close(sock);
463	} else {
464		/*
465		 * The FD_CLOEXEC is cleared by dup2(2), so when we not
466		 * call it, we have to clear it by hand in case it is set.
467		 */
468		if (fcntl(startfd, F_SETFD, 0) == -1)
469			pjdlog_exit(EX_OSERR, "fcntl() failed");
470	}
471
472	closefrom(startfd + 1);
473
474	if (asprintf(&startfdstr, "%d", startfd) == -1)
475		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
476	if (timeout == -1)
477		timeout = TLS_DEFAULT_TIMEOUT;
478	if (asprintf(&timeoutstr, "%d", timeout) == -1)
479		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
481		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
482
483	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
484	    proto_get("user"), "client", startfdstr,
485	    srcaddr == NULL ? "" : srcaddr, dstaddr,
486	    proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
487	    debugstr, NULL);
488	pjdlog_exit(EX_SOFTWARE, "execl() failed");
489}
490
491static int
492tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
493{
494	struct tls_ctx *tlsctx;
495	struct proto_conn *sock;
496	pid_t pid;
497	int error;
498
499	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
500	PJDLOG_ASSERT(dstaddr != NULL);
501	PJDLOG_ASSERT(timeout >= -1);
502	PJDLOG_ASSERT(ctxp != NULL);
503
504	if (strncmp(dstaddr, "tls://", 6) != 0)
505		return (-1);
506	if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
507		return (-1);
508
509	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
510		return (errno);
511
512#if 0
513	/*
514	 * We use rfork() with the following flags to disable SIGCHLD
515	 * delivery upon the sandbox process exit.
516	 */
517	pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
518#else
519	/*
520	 * We don't use rfork() to be able to log information about sandbox
521	 * process exiting.
522	 */
523	pid = fork();
524#endif
525	switch (pid) {
526	case -1:
527		/* Failure. */
528		error = errno;
529		proto_close(sock);
530		return (error);
531	case 0:
532		/* Child. */
533		pjdlog_prefix_set("[TLS sandbox] (client) ");
534#ifdef HAVE_SETPROCTITLE
535		setproctitle("[TLS sandbox] (client) ");
536#endif
537		tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
538		/* NOTREACHED */
539	default:
540		/* Parent. */
541		tlsctx = calloc(1, sizeof(*tlsctx));
542		if (tlsctx == NULL) {
543			error = errno;
544			proto_close(sock);
545			(void)kill(pid, SIGKILL);
546			return (error);
547		}
548		proto_send(sock, NULL, 0);
549		tlsctx->tls_sock = sock;
550		tlsctx->tls_tcp = NULL;
551		tlsctx->tls_side = TLS_SIDE_CLIENT;
552		tlsctx->tls_wait_called = false;
553		tlsctx->tls_magic = TLS_CTX_MAGIC;
554		if (timeout >= 0) {
555			error = tls_connect_wait(tlsctx, timeout);
556			if (error != 0) {
557				(void)kill(pid, SIGKILL);
558				tls_close(tlsctx);
559				return (error);
560			}
561		}
562		*ctxp = tlsctx;
563		return (0);
564	}
565}
566
567static int
568tls_connect_wait(void *ctx, int timeout)
569{
570	struct tls_ctx *tlsctx = ctx;
571	int error, sockfd;
572	uint8_t connected;
573
574	PJDLOG_ASSERT(tlsctx != NULL);
575	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
576	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
577	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
578	PJDLOG_ASSERT(!tlsctx->tls_wait_called);
579	PJDLOG_ASSERT(timeout >= 0);
580
581	sockfd = proto_descriptor(tlsctx->tls_sock);
582	error = wait_for_fd(sockfd, timeout);
583	if (error != 0)
584		return (error);
585
586	for (;;) {
587		switch (recv(sockfd, &connected, sizeof(connected),
588		    MSG_WAITALL)) {
589		case -1:
590			if (errno == EINTR || errno == ENOBUFS)
591				continue;
592			error = errno;
593			break;
594		case 0:
595			pjdlog_debug(1, "Connection terminated.");
596			error = ENOTCONN;
597			break;
598		case 1:
599			tlsctx->tls_wait_called = true;
600			break;
601		}
602		break;
603	}
604
605	return (error);
606}
607
608static int
609tls_server(const char *lstaddr, void **ctxp)
610{
611	struct proto_conn *tcp;
612	struct tls_ctx *tlsctx;
613	char *laddr;
614	int error;
615
616	if (strncmp(lstaddr, "tls://", 6) != 0)
617		return (-1);
618
619	tlsctx = malloc(sizeof(*tlsctx));
620	if (tlsctx == NULL) {
621		pjdlog_warning("Unable to allocate memory.");
622		return (ENOMEM);
623	}
624
625	laddr = strdup(lstaddr);
626	if (laddr == NULL) {
627		free(tlsctx);
628		pjdlog_warning("Unable to allocate memory.");
629		return (ENOMEM);
630	}
631	bcopy("tcp://", laddr, 6);
632
633	if (proto_server(laddr, &tcp) == -1) {
634		error = errno;
635		free(tlsctx);
636		free(laddr);
637		return (error);
638	}
639	free(laddr);
640
641	tlsctx->tls_sock = NULL;
642	tlsctx->tls_tcp = tcp;
643	tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
644	tlsctx->tls_wait_called = true;
645	tlsctx->tls_magic = TLS_CTX_MAGIC;
646	*ctxp = tlsctx;
647
648	return (0);
649}
650
651static void
652tls_exec_server(const char *user, int startfd, const char *privkey,
653    const char *cert, int debuglevel)
654{
655	SSL_CTX *sslctx;
656	SSL *ssl;
657	int sockfd, tcpfd, ret;
658
659	pjdlog_debug_set(debuglevel);
660	pjdlog_prefix_set("[TLS sandbox] (server) ");
661#ifdef HAVE_SETPROCTITLE
662	setproctitle("[TLS sandbox] (server) ");
663#endif
664
665	sockfd = startfd;
666	tcpfd = startfd + 1;
667
668	SSL_load_error_strings();
669	SSL_library_init();
670
671	sslctx = SSL_CTX_new(TLSv1_server_method());
672	if (sslctx == NULL)
673		pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
674
675	SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
676
677	ssl = SSL_new(sslctx);
678	if (ssl == NULL)
679		pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
680
681	if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
682		ssl_log_errors();
683		pjdlog_exitx(EX_CONFIG,
684		    "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
685	}
686
687	if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
688		ssl_log_errors();
689		pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
690		    cert);
691	}
692
693	if (sandbox(user, true, "proto_tls server") != 0)
694		pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
695	pjdlog_debug(1, "Privileges successfully dropped.");
696
697	nonblock(sockfd);
698	nonblock(tcpfd);
699
700	if (SSL_set_fd(ssl, tcpfd) != 1)
701		pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
702
703	ret = SSL_accept(ssl);
704	ssl_check_error(ssl, ret);
705
706	tls_loop(sockfd, ssl);
707}
708
709static void
710tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
711{
712	int startfd, sockfd, tcpfd, safefd;
713	char *startfdstr, *debugstr;
714
715	if (pjdlog_mode_get() == PJDLOG_MODE_STD)
716		startfd = 3;
717	else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
718		startfd = 0;
719
720	/* Declare that we are receiver. */
721	proto_send(sock, NULL, 0);
722
723	sockfd = proto_descriptor(sock);
724	tcpfd = proto_descriptor(tcp);
725
726	safefd = MAX(sockfd, tcpfd);
727	safefd = MAX(safefd, startfd);
728	safefd++;
729
730	/* Move sockfd and tcpfd to safe numbers first. */
731	if (dup2(sockfd, safefd) == -1)
732		pjdlog_exit(EX_OSERR, "dup2() failed");
733	proto_close(sock);
734	sockfd = safefd;
735	if (dup2(tcpfd, safefd + 1) == -1)
736		pjdlog_exit(EX_OSERR, "dup2() failed");
737	proto_close(tcp);
738	tcpfd = safefd + 1;
739
740	/* Move socketpair descriptor to descriptor number startfd. */
741	if (dup2(sockfd, startfd) == -1)
742		pjdlog_exit(EX_OSERR, "dup2() failed");
743	(void)close(sockfd);
744	/* Move tcp descriptor to descriptor number startfd + 1. */
745	if (dup2(tcpfd, startfd + 1) == -1)
746		pjdlog_exit(EX_OSERR, "dup2() failed");
747	(void)close(tcpfd);
748
749	closefrom(startfd + 2);
750
751	/*
752	 * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
753	 * have been cleared on dup2(), but better be safe than sorry.
754	 */
755	if (fcntl(startfd, F_SETFD, 0) == -1)
756		pjdlog_exit(EX_OSERR, "fcntl() failed");
757	if (fcntl(startfd + 1, F_SETFD, 0) == -1)
758		pjdlog_exit(EX_OSERR, "fcntl() failed");
759
760	if (asprintf(&startfdstr, "%d", startfd) == -1)
761		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
762	if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
763		pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
764
765	execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
766	    proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
767	    proto_get("tls:certfile"), debugstr, NULL);
768	pjdlog_exit(EX_SOFTWARE, "execl() failed");
769}
770
771static int
772tls_accept(void *ctx, void **newctxp)
773{
774	struct tls_ctx *tlsctx = ctx;
775	struct tls_ctx *newtlsctx;
776	struct proto_conn *sock, *tcp;
777	pid_t pid;
778	int error;
779
780	PJDLOG_ASSERT(tlsctx != NULL);
781	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
782	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
783
784	if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
785		return (errno);
786
787	/* Accept TCP connection. */
788	if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
789		error = errno;
790		proto_close(sock);
791		return (error);
792	}
793
794	pid = fork();
795	switch (pid) {
796	case -1:
797		/* Failure. */
798		error = errno;
799		proto_close(sock);
800		return (error);
801	case 0:
802		/* Child. */
803		pjdlog_prefix_set("[TLS sandbox] (server) ");
804#ifdef HAVE_SETPROCTITLE
805		setproctitle("[TLS sandbox] (server) ");
806#endif
807		/* Close listen socket. */
808		proto_close(tlsctx->tls_tcp);
809		tls_call_exec_server(sock, tcp);
810		/* NOTREACHED */
811		PJDLOG_ABORT("Unreachable.");
812	default:
813		/* Parent. */
814		newtlsctx = calloc(1, sizeof(*tlsctx));
815		if (newtlsctx == NULL) {
816			error = errno;
817			proto_close(sock);
818			proto_close(tcp);
819			(void)kill(pid, SIGKILL);
820			return (error);
821		}
822		proto_local_address(tcp, newtlsctx->tls_laddr,
823		    sizeof(newtlsctx->tls_laddr));
824		PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
825		bcopy("tls://", newtlsctx->tls_laddr, 6);
826		*strrchr(newtlsctx->tls_laddr, ':') = '\0';
827		proto_remote_address(tcp, newtlsctx->tls_raddr,
828		    sizeof(newtlsctx->tls_raddr));
829		PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
830		bcopy("tls://", newtlsctx->tls_raddr, 6);
831		*strrchr(newtlsctx->tls_raddr, ':') = '\0';
832		proto_close(tcp);
833		proto_recv(sock, NULL, 0);
834		newtlsctx->tls_sock = sock;
835		newtlsctx->tls_tcp = NULL;
836		newtlsctx->tls_wait_called = true;
837		newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
838		newtlsctx->tls_magic = TLS_CTX_MAGIC;
839		*newctxp = newtlsctx;
840		return (0);
841	}
842}
843
844static int
845tls_wrap(int fd, bool client, void **ctxp)
846{
847	struct tls_ctx *tlsctx;
848	struct proto_conn *sock;
849	int error;
850
851	tlsctx = calloc(1, sizeof(*tlsctx));
852	if (tlsctx == NULL)
853		return (errno);
854
855	if (proto_wrap("socketpair", client, fd, &sock) == -1) {
856		error = errno;
857		free(tlsctx);
858		return (error);
859	}
860
861	tlsctx->tls_sock = sock;
862	tlsctx->tls_tcp = NULL;
863	tlsctx->tls_wait_called = (client ? false : true);
864	tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
865	tlsctx->tls_magic = TLS_CTX_MAGIC;
866	*ctxp = tlsctx;
867
868	return (0);
869}
870
871static int
872tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
873{
874	struct tls_ctx *tlsctx = ctx;
875
876	PJDLOG_ASSERT(tlsctx != NULL);
877	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
878	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
879	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
880	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
881	PJDLOG_ASSERT(tlsctx->tls_wait_called);
882	PJDLOG_ASSERT(fd == -1);
883
884	if (proto_send(tlsctx->tls_sock, data, size) == -1)
885		return (errno);
886
887	return (0);
888}
889
890static int
891tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
892{
893	struct tls_ctx *tlsctx = ctx;
894
895	PJDLOG_ASSERT(tlsctx != NULL);
896	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
897	PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
898	    tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
899	PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
900	PJDLOG_ASSERT(tlsctx->tls_wait_called);
901	PJDLOG_ASSERT(fdp == NULL);
902
903	if (proto_recv(tlsctx->tls_sock, data, size) == -1)
904		return (errno);
905
906	return (0);
907}
908
909static int
910tls_descriptor(const void *ctx)
911{
912	const struct tls_ctx *tlsctx = ctx;
913
914	PJDLOG_ASSERT(tlsctx != NULL);
915	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
916
917	switch (tlsctx->tls_side) {
918	case TLS_SIDE_CLIENT:
919	case TLS_SIDE_SERVER_WORK:
920		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
921
922		return (proto_descriptor(tlsctx->tls_sock));
923	case TLS_SIDE_SERVER_LISTEN:
924		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
925
926		return (proto_descriptor(tlsctx->tls_tcp));
927	default:
928		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
929	}
930}
931
932static bool
933tcp_address_match(const void *ctx, const char *addr)
934{
935	const struct tls_ctx *tlsctx = ctx;
936
937	PJDLOG_ASSERT(tlsctx != NULL);
938	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
939
940	return (strcmp(tlsctx->tls_raddr, addr) == 0);
941}
942
943static void
944tls_local_address(const void *ctx, char *addr, size_t size)
945{
946	const struct tls_ctx *tlsctx = ctx;
947
948	PJDLOG_ASSERT(tlsctx != NULL);
949	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
950	PJDLOG_ASSERT(tlsctx->tls_wait_called);
951
952	switch (tlsctx->tls_side) {
953	case TLS_SIDE_CLIENT:
954		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
955
956		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
957		break;
958	case TLS_SIDE_SERVER_WORK:
959		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
960
961		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
962		break;
963	case TLS_SIDE_SERVER_LISTEN:
964		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
965
966		proto_local_address(tlsctx->tls_tcp, addr, size);
967		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
968		/* Replace tcp:// prefix with tls:// */
969		bcopy("tls://", addr, 6);
970		break;
971	default:
972		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
973	}
974}
975
976static void
977tls_remote_address(const void *ctx, char *addr, size_t size)
978{
979	const struct tls_ctx *tlsctx = ctx;
980
981	PJDLOG_ASSERT(tlsctx != NULL);
982	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
983	PJDLOG_ASSERT(tlsctx->tls_wait_called);
984
985	switch (tlsctx->tls_side) {
986	case TLS_SIDE_CLIENT:
987		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
988
989		PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
990		break;
991	case TLS_SIDE_SERVER_WORK:
992		PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
993
994		PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
995		break;
996	case TLS_SIDE_SERVER_LISTEN:
997		PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
998
999		proto_remote_address(tlsctx->tls_tcp, addr, size);
1000		PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
1001		/* Replace tcp:// prefix with tls:// */
1002		bcopy("tls://", addr, 6);
1003		break;
1004	default:
1005		PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1006	}
1007}
1008
1009static void
1010tls_close(void *ctx)
1011{
1012	struct tls_ctx *tlsctx = ctx;
1013
1014	PJDLOG_ASSERT(tlsctx != NULL);
1015	PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1016
1017	if (tlsctx->tls_sock != NULL) {
1018		proto_close(tlsctx->tls_sock);
1019		tlsctx->tls_sock = NULL;
1020	}
1021	if (tlsctx->tls_tcp != NULL) {
1022		proto_close(tlsctx->tls_tcp);
1023		tlsctx->tls_tcp = NULL;
1024	}
1025	tlsctx->tls_side = 0;
1026	tlsctx->tls_magic = 0;
1027	free(tlsctx);
1028}
1029
1030static int
1031tls_exec(int argc, char *argv[])
1032{
1033
1034	PJDLOG_ASSERT(argc > 3);
1035	PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1036
1037	pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1038
1039	if (strcmp(argv[2], "client") == 0) {
1040		if (argc != 10)
1041			return (EINVAL);
1042		tls_exec_client(argv[1], atoi(argv[3]),
1043		    argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1044		    argv[7], atoi(argv[8]), atoi(argv[9]));
1045	} else if (strcmp(argv[2], "server") == 0) {
1046		if (argc != 7)
1047			return (EINVAL);
1048		tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1049		    atoi(argv[6]));
1050	}
1051	return (EINVAL);
1052}
1053
1054static struct proto tls_proto = {
1055	.prt_name = "tls",
1056	.prt_connect = tls_connect,
1057	.prt_connect_wait = tls_connect_wait,
1058	.prt_server = tls_server,
1059	.prt_accept = tls_accept,
1060	.prt_wrap = tls_wrap,
1061	.prt_send = tls_send,
1062	.prt_recv = tls_recv,
1063	.prt_descriptor = tls_descriptor,
1064	.prt_address_match = tcp_address_match,
1065	.prt_local_address = tls_local_address,
1066	.prt_remote_address = tls_remote_address,
1067	.prt_close = tls_close,
1068	.prt_exec = tls_exec
1069};
1070
1071static __constructor void
1072tls_ctor(void)
1073{
1074
1075	proto_register(&tls_proto, false);
1076}
1077