net.c revision 1.2
1/*	$OpenBSD: net.c,v 1.2 2005/05/22 20:35:48 ho Exp $	*/
2
3/*
4 * Copyright (c) 2005 Håkan Olsson.  All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28/*
29 * This code was written under funding by Multicom Security AB.
30 */
31
32
33#include <sys/types.h>
34#include <sys/socket.h>
35#include <sys/time.h>
36#include <netinet/in.h>
37#include <arpa/inet.h>
38
39#include <openssl/aes.h>
40#include <openssl/sha.h>
41
42#include <errno.h>
43#include <stdio.h>
44#include <stdlib.h>
45#include <string.h>
46#include <unistd.h>
47
48#include "sasyncd.h"
49#include "net.h"
50
/*
 * A finished, wire-ready packet.  The same instance can be referenced
 * from the send queues of several peers; refcnt counts those references
 * and the packet is freed when it drops to zero.
 */
struct msg {
	u_int8_t *buf;		/* packet bytes as written to the socket */
	u_int32_t len;		/* number of bytes in buf */
	int refcnt;		/* queue entries currently pointing here */
};
56
/*
 * One slot in a peer's send queue.  It only links to the shared packet;
 * the packet itself lives in struct msg.
 */
struct qmsg {
	SIMPLEQ_ENTRY(qmsg) next;	/* linkage within the peer's queue */
	struct msg *msg;		/* packet to transmit */
};
61
62int	listen_socket;
63AES_KEY	aes_key[2];
64#define AES_IV_LEN AES_BLOCK_SIZE
65
66/* Local prototypes. */
67static u_int8_t *net_read(struct syncpeer *, u_int32_t *, u_int32_t *);
68static int	 net_set_sa(struct sockaddr *, char *, in_port_t);
69static void	 net_check_peers(void *);
70
71static void
72dump_buf(int lvl, u_int8_t *b, u_int32_t len, char *title)
73{
74	u_int32_t	i, off, blen = len*2 + 3 + strlen(title);
75	u_int8_t	*buf = calloc(1, blen);
76
77	if (!buf || cfgstate.verboselevel < lvl)
78		return;
79
80	snprintf(buf, blen, "%s:\n", title);
81	off = strlen(buf);
82	for (i = 0; i < len; i++, off+=2)
83		snprintf(buf + off, blen - off, "%02x", b[i]);
84	log_msg(lvl, "%s", buf);
85	free(buf);
86}
87
88int
89net_init(void)
90{
91	struct sockaddr_storage sa_storage;
92	struct sockaddr *sa = (struct sockaddr *)&sa_storage;
93	struct syncpeer *p;
94	int		 r;
95
96	/* The shared key needs to be 128, 192 or 256 bits */
97	r = (strlen(cfgstate.sharedkey) - 1) << 3;
98	if (r != 128 && r != 192 && r != 256) {
99		fprintf(stderr, "Bad shared key length (%d bits), "
100		    "should be 128, 192 or 256\n", r);
101		return -1;
102	}
103
104	if (AES_set_encrypt_key(cfgstate.sharedkey, r, &aes_key[0]) ||
105	    AES_set_decrypt_key(cfgstate.sharedkey, r, &aes_key[1])) {
106		fprintf(stderr, "Bad AES shared key\n");
107		return -1;
108	}
109
110	/* Setup listening socket.  */
111	memset(&sa_storage, 0, sizeof sa_storage);
112	if (net_set_sa(sa, cfgstate.listen_on, cfgstate.listen_port)) {
113		perror("inet_pton");
114		return -1;
115	}
116	listen_socket = socket(sa->sa_family, SOCK_STREAM, 0);
117	if (listen_socket < 0) {
118		perror("socket()");
119		close(listen_socket);
120		return -1;
121	}
122	r = 1;
123	if (setsockopt(listen_socket, SOL_SOCKET,
124	    cfgstate.listen_on ? SO_REUSEADDR : SO_REUSEPORT, (void *)&r,
125	    sizeof r)) {
126		perror("setsockopt()");
127		close(listen_socket);
128		return -1;
129	}
130	if (bind(listen_socket, sa, sizeof(struct sockaddr_in))) {
131		perror("bind()");
132		close(listen_socket);
133		return -1;
134	}
135	if (listen(listen_socket, 10)) {
136		perror("listen()");
137		close(listen_socket);
138		return -1;
139	}
140	log_msg(2, "listening on port %u fd %d", cfgstate.listen_port,
141	    listen_socket);
142
143	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
144		p->socket = -1;
145		SIMPLEQ_INIT(&p->msgs);
146	}
147
148	net_check_peers(0);
149	return 0;
150}
151
152static void
153net_enqueue(struct syncpeer *p, struct msg *m)
154{
155	struct qmsg	*qm;
156
157	if (p->socket < 0)
158		return;
159
160	qm = (struct qmsg *)malloc(sizeof *qm);
161	if (!qm) {
162		log_err("malloc()");
163		return;
164	}
165
166	memset(qm, 0, sizeof *qm);
167	qm->msg = m;
168	m->refcnt++;
169
170	SIMPLEQ_INSERT_TAIL(&p->msgs, qm, next);
171	return;
172}
173
174/*
175 * Queue a message for transmission to a particular peer,
176 * or to all peers if no peer is specified.
177 */
178int
179net_queue(struct syncpeer *p0, u_int32_t msgtype, u_int8_t *buf, u_int32_t len)
180{
181	struct syncpeer *p = p0;
182	struct msg	*m;
183	SHA_CTX		 ctx;
184	u_int8_t	 hash[SHA_DIGEST_LENGTH];
185	u_int8_t	 iv[AES_IV_LEN], tmp_iv[AES_IV_LEN];
186	u_int32_t	 v, padlen = 0;
187	int		 i, offset;
188
189	m = (struct msg *)calloc(1, sizeof *m);
190	if (!m) {
191		log_err("calloc()");
192		free(buf);
193		return -1;
194	}
195
196	/* Generate hash */
197	SHA1_Init(&ctx);
198	SHA1_Update(&ctx, buf, len);
199	SHA1_Final(hash, &ctx);
200	dump_buf(5, hash, sizeof hash, "Hash");
201
202	/* Padding required? */
203	i = len % AES_IV_LEN;
204	if (i) {
205		u_int8_t *pbuf;
206		i = AES_IV_LEN - i;
207		pbuf = realloc(buf, len + i);
208		if (!pbuf) {
209			log_err("net_queue: realloc()");
210			free(buf);
211			free(m);
212			return -1;
213		}
214		padlen = i;
215		while (i > 0)
216			pbuf[len++] = (u_int8_t)i--;
217		buf = pbuf;
218	}
219
220	/* Get random IV */
221	for (i = 0; i <= sizeof iv - sizeof v; i += sizeof v) {
222		v = arc4random();
223		memcpy(&iv[i], &v, sizeof v);
224	}
225	dump_buf(5, iv, sizeof iv, "IV");
226	memcpy(tmp_iv, iv, sizeof tmp_iv);
227
228	/* Encrypt */
229	dump_buf(5, buf, len, "Pre-enc");
230	AES_cbc_encrypt(buf, buf, len, &aes_key[0], tmp_iv, AES_ENCRYPT);
231	dump_buf(5, buf, len, "Post-enc");
232
233	/* Allocate send buffer */
234	m->len = len + sizeof iv + sizeof hash + 3 * sizeof(u_int32_t);
235	m->buf = (u_int8_t *)malloc(m->len);
236	if (!m->buf) {
237		free(m);
238		free(buf);
239		log_err("net_queue: calloc()");
240		return -1;
241	}
242	offset = 0;
243
244	/* Fill it (order must match parsing code in net_read()) */
245	v = htonl(m->len - sizeof(u_int32_t));
246	memcpy(m->buf + offset, &v, sizeof v);
247	offset += sizeof v;
248	v = htonl(msgtype);
249	memcpy(m->buf + offset, &v, sizeof v);
250	offset += sizeof v;
251	v = htonl(padlen);
252	memcpy(m->buf + offset, &v, sizeof v);
253	offset += sizeof v;
254	memcpy(m->buf + offset, hash, sizeof hash);
255	offset += sizeof hash;
256	memcpy(m->buf + offset, iv, sizeof iv);
257	offset += sizeof iv;
258	memcpy(m->buf + offset, buf, len);
259	free(buf);
260
261	if (p)
262		net_enqueue(p, m);
263	else
264		for (p = LIST_FIRST(&cfgstate.peerlist); p;
265		     p = LIST_NEXT(p, link))
266			net_enqueue(p, m);
267
268	if (!m->refcnt) {
269		free(m->buf);
270		free(m);
271	}
272
273	return 0;
274}
275
276/* Set all write pending filedescriptors. */
277int
278net_set_pending_wfds(fd_set *fds)
279{
280	struct syncpeer *p;
281	int		max_fd = -1;
282
283	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link))
284		if (p->socket > -1 && SIMPLEQ_FIRST(&p->msgs)) {
285			FD_SET(p->socket, fds);
286			if (p->socket > max_fd)
287				max_fd = p->socket;
288		}
289	return max_fd + 1;
290}
291
292/*
293 * Set readable filedescriptors. They are basically the same as for write,
294 * plus the listening socket.
295 */
296int
297net_set_rfds(fd_set *fds)
298{
299	struct syncpeer *p;
300	int		max_fd = -1;
301
302	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
303		if (p->socket > -1)
304			FD_SET(p->socket, fds);
305		if (p->socket > max_fd)
306			max_fd = p->socket;
307	}
308	FD_SET(listen_socket, fds);
309	if (listen_socket > max_fd)
310		max_fd = listen_socket;
311	return max_fd + 1;
312}
313
314void
315net_handle_messages(fd_set *fds)
316{
317	struct sockaddr_storage	sa_storage, sa_storage2;
318	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
319	struct sockaddr	*sa2 = (struct sockaddr *)&sa_storage2;
320	socklen_t	socklen;
321	struct syncpeer *p;
322	u_int8_t	*msg;
323	u_int32_t	 msgtype, msglen;
324	int		 newsock, found;
325
326	if (FD_ISSET(listen_socket, fds)) {
327		/* Accept a new incoming connection */
328		socklen = sizeof sa_storage;
329		newsock = accept(listen_socket, sa, &socklen);
330		if (newsock > -1) {
331			/* Setup the syncpeer structure */
332			found = 0;
333			for (p = LIST_FIRST(&cfgstate.peerlist); p && !found;
334			     p = LIST_NEXT(p, link)) {
335				struct sockaddr_in *sin, *sin2;
336				struct sockaddr_in6 *sin6, *sin62;
337
338				/* Match? */
339				if (net_set_sa(sa2, p->name, 0))
340					continue;
341				if (sa->sa_family != sa2->sa_family)
342					continue;
343				if (sa->sa_family == AF_INET) {
344					sin = (struct sockaddr_in *)sa;
345					sin2 = (struct sockaddr_in *)sa2;
346					if (memcmp(&sin->sin_addr,
347					    &sin2->sin_addr,
348					    sizeof(struct in_addr)))
349						continue;
350				} else {
351					sin6 = (struct sockaddr_in6 *)sa;
352					sin62 = (struct sockaddr_in6 *)sa2;
353					if (memcmp(&sin6->sin6_addr,
354					    &sin62->sin6_addr,
355					    sizeof(struct in6_addr)))
356						continue;
357				}
358				/* Match! */
359				found++;
360				p->socket = newsock;
361				log_msg(1, "peer \"%s\" connected", p->name);
362			}
363			if (!found) {
364				log_msg(1, "Found no matching peer for "
365				    "accepted socket, closing.");
366				close(newsock);
367			}
368		} else
369			log_err("accept()");
370	}
371
372	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
373		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
374			continue;
375		msg = net_read(p, &msgtype, &msglen);
376		if (!msg)
377			continue;
378
379		/* XXX check message validity. */
380
381		log_msg(4, "net_handle_messages: got msg type %u len %u from "
382		    "peer %s", msgtype, msglen, p->name);
383
384		switch (msgtype) {
385		case MSG_SYNCCTL:
386			net_ctl_handle_msg(p, msg, msglen);
387			free(msg);
388			break;
389
390		case MSG_PFKEYDATA:
391			if (p->runstate != MASTER ||
392			    cfgstate.runstate == MASTER) {
393				log_msg(0, "got PFKEY message from non-MASTER "
394				    "peer");
395				free(msg);
396				if (cfgstate.runstate == MASTER)
397					net_ctl_send_state(p);
398				else
399					net_ctl_send_error(p, 0);
400			} else if (pfkey_queue_message(msg, msglen))
401				free(msg);
402			break;
403
404		default:
405			log_msg(0, "Got unknown message type %u len %u from "
406			    "peer %s", msgtype, msglen, p->name);
407			free(msg);
408			net_ctl_send_error(p, 0);
409		}
410	}
411}
412
413void
414net_send_messages(fd_set *fds)
415{
416	struct syncpeer *p;
417	struct qmsg	*qm;
418	struct msg	*m;
419	ssize_t		 r;
420
421	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
422		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
423			continue;
424
425		qm = SIMPLEQ_FIRST(&p->msgs);
426		if (!qm) {
427			/* XXX Log */
428			continue;
429		}
430		m = qm->msg;
431
432		log_msg(4, "sending msg %p len %d ref %d to peer %s", m,
433		    m->len, m->refcnt, p->name);
434
435		/* write message */
436		r = write(p->socket, m->buf, m->len);
437		if (r == -1)
438			log_err("net_send_messages: write()");
439		else if (r < (ssize_t)m->len) {
440			/* XXX retransmit? */
441			continue;
442		}
443
444		/* cleanup */
445		SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
446		free(qm);
447
448		if (--m->refcnt < 1) {
449			log_msg(4, "freeing msg %p", m);
450			free(m->buf);
451			free(m);
452		}
453	}
454	return;
455}
456
457void
458net_disconnect_peer(struct syncpeer *p)
459{
460	if (p->socket > -1)
461		close(p->socket);
462	p->socket = -1;
463}
464
465void
466net_shutdown(void)
467{
468	struct syncpeer *p;
469	struct qmsg	*qm;
470	struct msg	*m;
471
472	while ((p = LIST_FIRST(&cfgstate.peerlist))) {
473		while ((qm = SIMPLEQ_FIRST(&p->msgs))) {
474			SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
475			m = qm->msg;
476			if (--m->refcnt < 1) {
477				free(m->buf);
478				free(m);
479			}
480			free(qm);
481		}
482		net_disconnect_peer(p);
483		if (p->name)
484			free(p->name);
485		LIST_REMOVE(p, link);
486		free(p);
487	}
488
489	if (listen_socket > -1)
490		close(listen_socket);
491}
492
493/*
494 * Helper functions (local) below here.
495 */
496
497static u_int8_t *
498net_read(struct syncpeer *p, u_int32_t *msgtype, u_int32_t *msglen)
499{
500	u_int8_t	*msg, *blob, *rhash, *iv, hash[SHA_DIGEST_LENGTH];
501	u_int32_t	 v, blob_len;
502	int		 padlen = 0, offset = 0, r;
503	SHA_CTX		 ctx;
504
505	/* Read blob length */
506	if (read(p->socket, &v, sizeof v) != (ssize_t)sizeof v)
507		return NULL;
508	blob_len = ntohl(v);
509	if (blob_len < sizeof hash + AES_IV_LEN + 2 * sizeof(u_int32_t))
510		return NULL;
511	*msglen = blob_len - sizeof hash - AES_IV_LEN - 2 * sizeof(u_int32_t);
512
513	/* Read message blob */
514	blob = (u_int8_t *)malloc(blob_len);
515	if (!blob) {
516		log_err("net_read: malloc()");
517		return NULL;
518	}
519	r = read(p->socket, blob, blob_len);
520	if (r == -1) {
521		free(blob);
522		return NULL;
523	} else if (r < (ssize_t)blob_len) {
524		/* XXX wait and read more? */
525		fprintf(stderr, "net_read: wanted %d, got %d\n", blob_len, r);
526		free(blob);
527		return NULL;
528	}
529
530	offset = 0;
531	memcpy(&v, blob + offset, sizeof v);
532	*msgtype = ntohl(v);
533	offset += sizeof v;
534
535	if (*msgtype > MSG_MAXTYPE) {
536		free(blob);
537		return NULL;
538	}
539
540	memcpy(&v, blob + offset, sizeof v);
541	padlen = ntohl(v);
542	offset += sizeof v;
543
544	rhash = blob + offset;
545	iv    = rhash + sizeof hash;
546	msg = (u_int8_t *)malloc(*msglen);
547	if (!msg) {
548		free(blob);
549		return NULL;
550	}
551	memcpy(msg, iv + AES_IV_LEN, *msglen);
552
553	dump_buf(5, rhash, sizeof hash, "Recv hash");
554	dump_buf(5, iv, sizeof iv, "Recv IV");
555	dump_buf(5, msg, *msglen, "Pre-decrypt");
556	AES_cbc_encrypt(msg, msg, *msglen, &aes_key[1], iv, AES_DECRYPT);
557	dump_buf(5, msg, *msglen, "Post-decrypt");
558	*msglen -= padlen;
559
560	SHA1_Init(&ctx);
561	SHA1_Update(&ctx, msg, *msglen);
562	SHA1_Final(hash, &ctx);
563	dump_buf(5, hash, sizeof hash, "Local hash");
564
565	if (memcmp(hash, rhash, sizeof hash) != 0) {
566		free(blob);
567		log_msg(0, "net_read: bad msg hash (shared key typo?)");
568		return NULL;
569	}
570	free(blob);
571	return msg;
572}
573
574static int
575net_set_sa(struct sockaddr *sa, char *name, in_port_t port)
576{
577	struct sockaddr_in	*sin = (struct sockaddr_in *)sa;
578	struct sockaddr_in6	*sin6 = (struct sockaddr_in6 *)sa;
579
580	if (name) {
581		if (inet_pton(AF_INET, name, &sin->sin_addr) == 1) {
582			sa->sa_family = AF_INET;
583			sin->sin_port = htons(port);
584			sin->sin_len = sizeof *sin;
585			return 0;
586		}
587
588		if (inet_pton(AF_INET6, name, &sin6->sin6_addr) == 1) {
589			sa->sa_family = AF_INET6;
590			sin6->sin6_port = htons(port);
591			sin6->sin6_len = sizeof *sin6;
592			return 0;
593		}
594	} else {
595		/* XXX Assume IPv4 */
596		sa->sa_family = AF_INET;
597		sin->sin_port = htons(port);
598		sin->sin_len = sizeof *sin;
599		return 0;
600	}
601
602	return 1;
603}
604
/*
 * SIGALRM handler that deliberately does nothing.  net_connect_peers()
 * installs it while its interval timer is armed.
 */
static void
got_sigalrm(int s)
{
	(void)s;
}
610
611void
612net_connect_peers(void)
613{
614	struct sockaddr_storage sa_storage;
615	struct itimerval	iv;
616	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
617	struct syncpeer	*p;
618
619	signal(SIGALRM, got_sigalrm);
620	memset(&iv, 0, sizeof iv);
621	iv.it_value.tv_sec = 5;
622	iv.it_interval.tv_sec = 5;
623	setitimer(ITIMER_REAL, &iv, NULL);
624
625	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
626		if (p->socket > -1)
627			continue;
628
629		memset(sa, 0, sizeof sa_storage);
630		if (net_set_sa(sa, p->name, cfgstate.listen_port))
631			continue;
632		p->socket = socket(sa->sa_family, SOCK_STREAM, 0);
633		if (p->socket < 0) {
634			log_err("peer \"%s\": socket()", p->name);
635			continue;
636		}
637		if (connect(p->socket, sa, sa->sa_len)) {
638			log_msg(1, "peer \"%s\" not ready yet", p->name);
639			net_disconnect_peer(p);
640			continue;
641		}
642		if (net_ctl_send_state(p)) {
643			log_msg(0, "peer \"%s\" failed", p->name);
644			net_disconnect_peer(p);
645			continue;
646		}
647		log_msg(1, "peer \"%s\" connected", p->name);
648	}
649
650	timerclear(&iv.it_value);
651	timerclear(&iv.it_interval);
652	setitimer(ITIMER_REAL, &iv, NULL);
653	signal(SIGALRM, SIG_IGN);
654
655	return;
656}
657
/*
 * Timer callback: try to connect to the peers that are not connected,
 * then re-register itself under the name "peer recheck" so the attempt
 * repeats.  The argument is not used.
 * NOTE(review): 600 is the delay handed to timer_add(); its unit is
 * presumably seconds -- confirm against timer_add().
 */
static void
net_check_peers(void *arg)
{
	(void)arg;

	net_connect_peers();
	(void)timer_add("peer recheck", 600, net_check_peers, 0);
}
665
666