net.c revision 1.3
1/*	$OpenBSD: net.c,v 1.3 2005/05/23 17:35:01 ho Exp $	*/
2
3/*
4 * Copyright (c) 2005 Håkan Olsson.  All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28/*
29 * This code was written under funding by Multicom Security AB.
30 */
31
32
33#include <sys/types.h>
34#include <sys/socket.h>
35#include <sys/time.h>
36#include <netinet/in.h>
37#include <arpa/inet.h>
38#include <ifaddrs.h>
39#include <netdb.h>
40
41#include <openssl/aes.h>
42#include <openssl/sha.h>
43
44#include <errno.h>
45#include <stdio.h>
46#include <stdlib.h>
47#include <string.h>
48#include <unistd.h>
49
50#include "sasyncd.h"
51#include "net.h"
52
/*
 * A fully encoded, encrypted message as it goes on the wire.  A single
 * msg may sit in several peers' queues at once; it is freed when the
 * last queue entry referencing it is released.
 */
struct msg {
	u_int8_t	*buf;		/* wire bytes, length header included */
	u_int32_t	 len;		/* number of bytes in buf */
	int		 refcnt;	/* queue entries pointing at this msg */
};
58
/* One entry in a peer's transmit queue; it holds one reference on 'msg'. */
struct qmsg {
	SIMPLEQ_ENTRY(qmsg)	 next;	/* link in the peer's msgs queue */
	struct msg		*msg;	/* shared message, see msg.refcnt */
};
63
64int	listen_socket;
65AES_KEY	aes_key[2];
66#define AES_IV_LEN AES_BLOCK_SIZE
67
68/* Local prototypes. */
69static u_int8_t *net_read(struct syncpeer *, u_int32_t *, u_int32_t *);
70static int	 net_set_sa(struct sockaddr *, char *, in_port_t);
71static void	 net_check_peers(void *);
72
73static void
74dump_buf(int lvl, u_int8_t *b, u_int32_t len, char *title)
75{
76	u_int32_t	i, off, blen = len*2 + 3 + strlen(title);
77	u_int8_t	*buf = calloc(1, blen);
78
79	if (!buf || cfgstate.verboselevel < lvl)
80		return;
81
82	snprintf(buf, blen, "%s:\n", title);
83	off = strlen(buf);
84	for (i = 0; i < len; i++, off+=2)
85		snprintf(buf + off, blen - off, "%02x", b[i]);
86	log_msg(lvl, "%s", buf);
87	free(buf);
88}
89
90int
91net_init(void)
92{
93	struct sockaddr_storage sa_storage;
94	struct sockaddr *sa = (struct sockaddr *)&sa_storage;
95	struct syncpeer *p;
96	char		 host[NI_MAXHOST], port[NI_MAXSERV];
97	int		 r;
98
99	/* The shared key needs to be 128, 192 or 256 bits */
100	r = strlen(cfgstate.sharedkey) << 3;
101	if (r != 128 && r != 192 && r != 256) {
102		fprintf(stderr, "Bad shared key length (%d bits), "
103		    "should be 128, 192 or 256\n", r);
104		return -1;
105	}
106
107	if (AES_set_encrypt_key(cfgstate.sharedkey, r, &aes_key[0]) ||
108	    AES_set_decrypt_key(cfgstate.sharedkey, r, &aes_key[1])) {
109		fprintf(stderr, "Bad AES shared key\n");
110		return -1;
111	}
112
113	/* Setup listening socket.  */
114	memset(&sa_storage, 0, sizeof sa_storage);
115	if (net_set_sa(sa, cfgstate.listen_on, cfgstate.listen_port)) {
116		log_msg(0, "net_init: could not find listen address (%s)",
117		    cfgstate.listen_on);
118		return -1;
119	}
120
121	listen_socket = socket(sa->sa_family, SOCK_STREAM, 0);
122	if (listen_socket < 0) {
123		perror("socket()");
124		close(listen_socket);
125		return -1;
126	}
127	r = 1;
128	if (setsockopt(listen_socket, SOL_SOCKET,
129	    cfgstate.listen_on ? SO_REUSEADDR : SO_REUSEPORT, (void *)&r,
130	    sizeof r)) {
131		perror("setsockopt()");
132		close(listen_socket);
133		return -1;
134	}
135	if (bind(listen_socket, sa, sizeof(struct sockaddr_in))) {
136		perror("bind()");
137		close(listen_socket);
138		return -1;
139	}
140	if (listen(listen_socket, 10)) {
141		perror("listen()");
142		close(listen_socket);
143		return -1;
144	}
145
146	if (getnameinfo(sa, sa->sa_len, host, sizeof host, port, sizeof port,
147		NI_NUMERICHOST | NI_NUMERICSERV))
148		log_msg(2, "listening on port %u fd %d", cfgstate.listen_port,
149		    listen_socket);
150	else
151		log_msg(2, "listening on %s port %s fd %d", host, port,
152		    listen_socket);
153
154	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
155		p->socket = -1;
156		SIMPLEQ_INIT(&p->msgs);
157	}
158
159	net_check_peers(0);
160	return 0;
161}
162
163static void
164net_enqueue(struct syncpeer *p, struct msg *m)
165{
166	struct qmsg	*qm;
167
168	if (p->socket < 0)
169		return;
170
171	qm = (struct qmsg *)malloc(sizeof *qm);
172	if (!qm) {
173		log_err("malloc()");
174		return;
175	}
176
177	memset(qm, 0, sizeof *qm);
178	qm->msg = m;
179	m->refcnt++;
180
181	SIMPLEQ_INSERT_TAIL(&p->msgs, qm, next);
182	return;
183}
184
185/*
186 * Queue a message for transmission to a particular peer,
187 * or to all peers if no peer is specified.
188 */
189int
190net_queue(struct syncpeer *p0, u_int32_t msgtype, u_int8_t *buf, u_int32_t len)
191{
192	struct syncpeer *p = p0;
193	struct msg	*m;
194	SHA_CTX		 ctx;
195	u_int8_t	 hash[SHA_DIGEST_LENGTH];
196	u_int8_t	 iv[AES_IV_LEN], tmp_iv[AES_IV_LEN];
197	u_int32_t	 v, padlen = 0;
198	int		 i, offset;
199
200	m = (struct msg *)calloc(1, sizeof *m);
201	if (!m) {
202		log_err("calloc()");
203		free(buf);
204		return -1;
205	}
206
207	/* Generate hash */
208	SHA1_Init(&ctx);
209	SHA1_Update(&ctx, buf, len);
210	SHA1_Final(hash, &ctx);
211	dump_buf(5, hash, sizeof hash, "Hash");
212
213	/* Padding required? */
214	i = len % AES_IV_LEN;
215	if (i) {
216		u_int8_t *pbuf;
217		i = AES_IV_LEN - i;
218		pbuf = realloc(buf, len + i);
219		if (!pbuf) {
220			log_err("net_queue: realloc()");
221			free(buf);
222			free(m);
223			return -1;
224		}
225		padlen = i;
226		while (i > 0)
227			pbuf[len++] = (u_int8_t)i--;
228		buf = pbuf;
229	}
230
231	/* Get random IV */
232	for (i = 0; i <= sizeof iv - sizeof v; i += sizeof v) {
233		v = arc4random();
234		memcpy(&iv[i], &v, sizeof v);
235	}
236	dump_buf(5, iv, sizeof iv, "IV");
237	memcpy(tmp_iv, iv, sizeof tmp_iv);
238
239	/* Encrypt */
240	dump_buf(5, buf, len, "Pre-enc");
241	AES_cbc_encrypt(buf, buf, len, &aes_key[0], tmp_iv, AES_ENCRYPT);
242	dump_buf(5, buf, len, "Post-enc");
243
244	/* Allocate send buffer */
245	m->len = len + sizeof iv + sizeof hash + 3 * sizeof(u_int32_t);
246	m->buf = (u_int8_t *)malloc(m->len);
247	if (!m->buf) {
248		free(m);
249		free(buf);
250		log_err("net_queue: calloc()");
251		return -1;
252	}
253	offset = 0;
254
255	/* Fill it (order must match parsing code in net_read()) */
256	v = htonl(m->len - sizeof(u_int32_t));
257	memcpy(m->buf + offset, &v, sizeof v);
258	offset += sizeof v;
259	v = htonl(msgtype);
260	memcpy(m->buf + offset, &v, sizeof v);
261	offset += sizeof v;
262	v = htonl(padlen);
263	memcpy(m->buf + offset, &v, sizeof v);
264	offset += sizeof v;
265	memcpy(m->buf + offset, hash, sizeof hash);
266	offset += sizeof hash;
267	memcpy(m->buf + offset, iv, sizeof iv);
268	offset += sizeof iv;
269	memcpy(m->buf + offset, buf, len);
270	free(buf);
271
272	if (p)
273		net_enqueue(p, m);
274	else
275		for (p = LIST_FIRST(&cfgstate.peerlist); p;
276		     p = LIST_NEXT(p, link))
277			net_enqueue(p, m);
278
279	if (!m->refcnt) {
280		free(m->buf);
281		free(m);
282	}
283
284	return 0;
285}
286
287/* Set all write pending filedescriptors. */
288int
289net_set_pending_wfds(fd_set *fds)
290{
291	struct syncpeer *p;
292	int		max_fd = -1;
293
294	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link))
295		if (p->socket > -1 && SIMPLEQ_FIRST(&p->msgs)) {
296			FD_SET(p->socket, fds);
297			if (p->socket > max_fd)
298				max_fd = p->socket;
299		}
300	return max_fd + 1;
301}
302
303/*
304 * Set readable filedescriptors. They are basically the same as for write,
305 * plus the listening socket.
306 */
307int
308net_set_rfds(fd_set *fds)
309{
310	struct syncpeer *p;
311	int		max_fd = -1;
312
313	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
314		if (p->socket > -1)
315			FD_SET(p->socket, fds);
316		if (p->socket > max_fd)
317			max_fd = p->socket;
318	}
319	FD_SET(listen_socket, fds);
320	if (listen_socket > max_fd)
321		max_fd = listen_socket;
322	return max_fd + 1;
323}
324
325void
326net_handle_messages(fd_set *fds)
327{
328	struct sockaddr_storage	sa_storage, sa_storage2;
329	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
330	struct sockaddr	*sa2 = (struct sockaddr *)&sa_storage2;
331	socklen_t	socklen;
332	struct syncpeer *p;
333	u_int8_t	*msg;
334	u_int32_t	 msgtype, msglen;
335	int		 newsock, found;
336
337	if (FD_ISSET(listen_socket, fds)) {
338		/* Accept a new incoming connection */
339		socklen = sizeof sa_storage;
340		newsock = accept(listen_socket, sa, &socklen);
341		if (newsock > -1) {
342			/* Setup the syncpeer structure */
343			found = 0;
344			for (p = LIST_FIRST(&cfgstate.peerlist); p && !found;
345			     p = LIST_NEXT(p, link)) {
346				struct sockaddr_in *sin, *sin2;
347				struct sockaddr_in6 *sin6, *sin62;
348
349				/* Match? */
350				if (net_set_sa(sa2, p->name, 0))
351					continue;
352				if (sa->sa_family != sa2->sa_family)
353					continue;
354				if (sa->sa_family == AF_INET) {
355					sin = (struct sockaddr_in *)sa;
356					sin2 = (struct sockaddr_in *)sa2;
357					if (memcmp(&sin->sin_addr,
358					    &sin2->sin_addr,
359					    sizeof(struct in_addr)))
360						continue;
361				} else {
362					sin6 = (struct sockaddr_in6 *)sa;
363					sin62 = (struct sockaddr_in6 *)sa2;
364					if (memcmp(&sin6->sin6_addr,
365					    &sin62->sin6_addr,
366					    sizeof(struct in6_addr)))
367						continue;
368				}
369				/* Match! */
370				found++;
371				p->socket = newsock;
372				log_msg(1, "peer \"%s\" connected", p->name);
373			}
374			if (!found) {
375				log_msg(1, "Found no matching peer for "
376				    "accepted socket, closing.");
377				close(newsock);
378			}
379		} else
380			log_err("accept()");
381	}
382
383	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
384		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
385			continue;
386		msg = net_read(p, &msgtype, &msglen);
387		if (!msg)
388			continue;
389
390		/* XXX check message validity. */
391
392		log_msg(4, "net_handle_messages: got msg type %u len %u from "
393		    "peer %s", msgtype, msglen, p->name);
394
395		switch (msgtype) {
396		case MSG_SYNCCTL:
397			net_ctl_handle_msg(p, msg, msglen);
398			free(msg);
399			break;
400
401		case MSG_PFKEYDATA:
402			if (p->runstate != MASTER ||
403			    cfgstate.runstate == MASTER) {
404				log_msg(0, "got PFKEY message from non-MASTER "
405				    "peer");
406				free(msg);
407				if (cfgstate.runstate == MASTER)
408					net_ctl_send_state(p);
409				else
410					net_ctl_send_error(p, 0);
411			} else if (pfkey_queue_message(msg, msglen))
412				free(msg);
413			break;
414
415		default:
416			log_msg(0, "Got unknown message type %u len %u from "
417			    "peer %s", msgtype, msglen, p->name);
418			free(msg);
419			net_ctl_send_error(p, 0);
420		}
421	}
422}
423
424void
425net_send_messages(fd_set *fds)
426{
427	struct syncpeer *p;
428	struct qmsg	*qm;
429	struct msg	*m;
430	ssize_t		 r;
431
432	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
433		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
434			continue;
435
436		qm = SIMPLEQ_FIRST(&p->msgs);
437		if (!qm) {
438			/* XXX Log */
439			continue;
440		}
441		m = qm->msg;
442
443		log_msg(4, "sending msg %p len %d ref %d to peer %s", m,
444		    m->len, m->refcnt, p->name);
445
446		/* write message */
447		r = write(p->socket, m->buf, m->len);
448		if (r == -1)
449			log_err("net_send_messages: write()");
450		else if (r < (ssize_t)m->len) {
451			/* XXX retransmit? */
452			continue;
453		}
454
455		/* cleanup */
456		SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
457		free(qm);
458
459		if (--m->refcnt < 1) {
460			log_msg(4, "freeing msg %p", m);
461			free(m->buf);
462			free(m);
463		}
464	}
465	return;
466}
467
468void
469net_disconnect_peer(struct syncpeer *p)
470{
471	if (p->socket > -1)
472		close(p->socket);
473	p->socket = -1;
474}
475
476void
477net_shutdown(void)
478{
479	struct syncpeer *p;
480	struct qmsg	*qm;
481	struct msg	*m;
482
483	while ((p = LIST_FIRST(&cfgstate.peerlist))) {
484		while ((qm = SIMPLEQ_FIRST(&p->msgs))) {
485			SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
486			m = qm->msg;
487			if (--m->refcnt < 1) {
488				free(m->buf);
489				free(m);
490			}
491			free(qm);
492		}
493		net_disconnect_peer(p);
494		if (p->name)
495			free(p->name);
496		LIST_REMOVE(p, link);
497		free(p);
498	}
499
500	if (listen_socket > -1)
501		close(listen_socket);
502}
503
504/*
505 * Helper functions (local) below here.
506 */
507
508static u_int8_t *
509net_read(struct syncpeer *p, u_int32_t *msgtype, u_int32_t *msglen)
510{
511	u_int8_t	*msg, *blob, *rhash, *iv, hash[SHA_DIGEST_LENGTH];
512	u_int32_t	 v, blob_len;
513	int		 padlen = 0, offset = 0, r;
514	SHA_CTX		 ctx;
515
516	/* Read blob length */
517	if (read(p->socket, &v, sizeof v) != (ssize_t)sizeof v)
518		return NULL;
519	blob_len = ntohl(v);
520	if (blob_len < sizeof hash + AES_IV_LEN + 2 * sizeof(u_int32_t))
521		return NULL;
522	*msglen = blob_len - sizeof hash - AES_IV_LEN - 2 * sizeof(u_int32_t);
523
524	/* Read message blob */
525	blob = (u_int8_t *)malloc(blob_len);
526	if (!blob) {
527		log_err("net_read: malloc()");
528		return NULL;
529	}
530	r = read(p->socket, blob, blob_len);
531	if (r == -1) {
532		free(blob);
533		return NULL;
534	} else if (r < (ssize_t)blob_len) {
535		/* XXX wait and read more? */
536		fprintf(stderr, "net_read: wanted %d, got %d\n", blob_len, r);
537		free(blob);
538		return NULL;
539	}
540
541	offset = 0;
542	memcpy(&v, blob + offset, sizeof v);
543	*msgtype = ntohl(v);
544	offset += sizeof v;
545
546	if (*msgtype > MSG_MAXTYPE) {
547		free(blob);
548		return NULL;
549	}
550
551	memcpy(&v, blob + offset, sizeof v);
552	padlen = ntohl(v);
553	offset += sizeof v;
554
555	rhash = blob + offset;
556	iv    = rhash + sizeof hash;
557	msg = (u_int8_t *)malloc(*msglen);
558	if (!msg) {
559		free(blob);
560		return NULL;
561	}
562	memcpy(msg, iv + AES_IV_LEN, *msglen);
563
564	dump_buf(5, rhash, sizeof hash, "Recv hash");
565	dump_buf(5, iv, sizeof iv, "Recv IV");
566	dump_buf(5, msg, *msglen, "Pre-decrypt");
567	AES_cbc_encrypt(msg, msg, *msglen, &aes_key[1], iv, AES_DECRYPT);
568	dump_buf(5, msg, *msglen, "Post-decrypt");
569	*msglen -= padlen;
570
571	SHA1_Init(&ctx);
572	SHA1_Update(&ctx, msg, *msglen);
573	SHA1_Final(hash, &ctx);
574	dump_buf(5, hash, sizeof hash, "Local hash");
575
576	if (memcmp(hash, rhash, sizeof hash) != 0) {
577		free(blob);
578		log_msg(0, "net_read: bad msg hash (shared key typo?)");
579		return NULL;
580	}
581	free(blob);
582	return msg;
583}
584
585static int
586net_set_sa(struct sockaddr *sa, char *name, in_port_t port)
587{
588	struct sockaddr_in	*sin = (struct sockaddr_in *)sa;
589	struct sockaddr_in6	*sin6 = (struct sockaddr_in6 *)sa;
590	struct ifaddrs		*ifap, *ifa;
591
592	if (!name) {
593		/* XXX Assume IPv4 */
594		sa->sa_family = AF_INET;
595		sin->sin_port = htons(port);
596		sin->sin_len = sizeof *sin;
597		return 0;
598	}
599
600	if (inet_pton(AF_INET, name, &sin->sin_addr) == 1) {
601		sa->sa_family = AF_INET;
602		sin->sin_port = htons(port);
603		sin->sin_len = sizeof *sin;
604		return 0;
605	}
606
607	if (inet_pton(AF_INET6, name, &sin6->sin6_addr) == 1) {
608		sa->sa_family = AF_INET6;
609		sin6->sin6_port = htons(port);
610		sin6->sin6_len = sizeof *sin6;
611		return 0;
612	}
613
614	/* inet_pton failed. fail here if name is not cfgstate.listen_on */
615	if (strcmp(cfgstate.listen_on, name) != 0)
616		return -1;
617
618	/* Is cfgstate.listen_on the name of one our interfaces? */
619	if (getifaddrs(&ifap) != 0) {
620		perror("getifaddrs()");
621		return -1;
622	}
623	sa->sa_family = AF_UNSPEC;
624	for (ifa = ifap; ifa && sa->sa_family == AF_UNSPEC;
625	     ifa = ifa->ifa_next) {
626		if (!ifa->ifa_name || !ifa->ifa_addr)
627			continue;
628		if (strcmp(ifa->ifa_name, name) != 0)
629			continue;
630
631		switch (ifa->ifa_addr->sa_family) {
632		case AF_INET:
633			sa->sa_family = AF_INET;
634			sin->sin_port = htons(port);
635			sin->sin_len = sizeof *sin;
636			memcpy(&sin->sin_addr,
637			    &((struct sockaddr_in *)ifa->ifa_addr)->sin_addr,
638			    sizeof sin->sin_addr);
639			break;
640
641		case AF_INET6:
642			sa->sa_family = AF_INET6;
643			sin6->sin6_port = htons(port);
644			sin6->sin6_len = sizeof *sin6;
645			memcpy(&sin6->sin6_addr,
646			    &((struct sockaddr_in6 *)ifa->ifa_addr)->sin6_addr,
647			    sizeof sin6->sin6_addr);
648			break;
649		}
650	}
651	freeifaddrs(ifap);
652	return sa->sa_family == AF_UNSPEC ? -1 : 0;
653}
654
655
/*
 * SIGALRM handler that deliberately does nothing.  net_connect_peers()
 * installs it around its connect() calls.  Catching the signal (rather
 * than ignoring it) lets the interval timer interrupt a slow connect().
 */
static void
got_sigalrm(int s)
{
	(void)s;
}
661
662void
663net_connect_peers(void)
664{
665	struct sockaddr_storage sa_storage;
666	struct itimerval	iv;
667	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
668	struct syncpeer	*p;
669
670	signal(SIGALRM, got_sigalrm);
671	memset(&iv, 0, sizeof iv);
672	iv.it_value.tv_sec = 5;
673	iv.it_interval.tv_sec = 5;
674	setitimer(ITIMER_REAL, &iv, NULL);
675
676	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
677		if (p->socket > -1)
678			continue;
679
680		memset(sa, 0, sizeof sa_storage);
681		if (net_set_sa(sa, p->name, cfgstate.listen_port))
682			continue;
683		p->socket = socket(sa->sa_family, SOCK_STREAM, 0);
684		if (p->socket < 0) {
685			log_err("peer \"%s\": socket()", p->name);
686			continue;
687		}
688		if (connect(p->socket, sa, sa->sa_len)) {
689			log_msg(1, "peer \"%s\" not ready yet", p->name);
690			net_disconnect_peer(p);
691			continue;
692		}
693		if (net_ctl_send_state(p)) {
694			log_msg(0, "peer \"%s\" failed", p->name);
695			net_disconnect_peer(p);
696			continue;
697		}
698		log_msg(1, "peer \"%s\" connected", p->name);
699	}
700
701	timerclear(&iv.it_value);
702	timerclear(&iv.it_interval);
703	setitimer(ITIMER_REAL, &iv, NULL);
704	signal(SIGALRM, SIG_IGN);
705
706	return;
707}
708
/*
 * Timer callback.  It tries to (re)connect any disconnected peers, then
 * schedules itself to run again after 600 timer units (presumably
 * seconds; confirm against timer_add()).  'arg' is unused, and the
 * return value of timer_add() is ignored.
 */
static void
net_check_peers(void *arg)
{
	(void)arg;

	net_connect_peers();
	(void)timer_add("peer recheck", 600, net_check_peers, 0);
}
716
717