net.c revision 1.1
1/*	$OpenBSD: net.c,v 1.1 2005/03/30 18:44:49 ho Exp $	*/
2
3/*
4 * Copyright (c) 2005 Håkan Olsson.  All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28/*
29 * This code was written under funding by Multicom Security AB.
30 */
31
32
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "sasyncd.h"
#include "net.h"
45
/*
 * One outgoing message.  A single struct msg may sit on the send queue of
 * several peers at once; refcnt counts those queue entries and the message
 * is released when it drops below one.
 */
struct msg {
	u_int8_t	*buf;		/* Start of the data to transmit. */
	u_int8_t	*obuf;		/* Original buf w/o offset; the pointer handed to free(). */
	u_int32_t	 len;		/* Number of bytes to send from buf. */
	u_int32_t	 type;		/* Message type, sent ahead of the data. */
	int		 refcnt;	/* Queue entries still pointing here. */
};
53
/*
 * Entry on a peer's send queue.  It only links to the shared struct msg,
 * so one message can be queued for several peers.
 */
struct qmsg {
	SIMPLEQ_ENTRY(qmsg)	 next;	/* Queue linkage. */
	struct msg		*msg;	/* The message to transmit. */
};
58
/* Descriptor of the listening socket; opened in net_init(). */
int	listen_socket;

/* Prototypes for helpers that are private to this file. */
static u_int8_t	*net_read(struct syncpeer *p, u_int32_t *msgtype,
		    u_int32_t *msglen);
static int	 net_set_sa(struct sockaddr *sa, char *name, in_port_t port);
static void	 net_check_peers(void *arg);
65
66int
67net_init(void)
68{
69	struct sockaddr_storage sa_storage;
70	struct sockaddr *sa = (struct sockaddr *)&sa_storage;
71	struct syncpeer *p;
72	int		 r;
73
74	if (net_SSL_init())
75		return -1;
76
77	/* Setup listening socket.  */
78	memset(&sa_storage, 0, sizeof sa_storage);
79	if (net_set_sa(sa, cfgstate.listen_on, cfgstate.listen_port)) {
80		perror("inet_pton");
81		return -1;
82	}
83	listen_socket = socket(sa->sa_family, SOCK_STREAM, 0);
84	if (listen_socket < 0) {
85		perror("socket()");
86		close(listen_socket);
87		return -1;
88	}
89	r = 1;
90	if (setsockopt(listen_socket, SOL_SOCKET,
91	    cfgstate.listen_on ? SO_REUSEADDR : SO_REUSEPORT, (void *)&r,
92	    sizeof r)) {
93		perror("setsockopt()");
94		close(listen_socket);
95		return -1;
96	}
97	if (bind(listen_socket, sa, sizeof(struct sockaddr_in))) {
98		perror("bind()");
99		close(listen_socket);
100		return -1;
101	}
102	if (listen(listen_socket, 10)) {
103		perror("listen()");
104		close(listen_socket);
105		return -1;
106	}
107	log_msg(2, "listening on port %u fd %d", cfgstate.listen_port,
108	    listen_socket);
109
110	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
111		p->socket = -1;
112		SIMPLEQ_INIT(&p->msgs);
113	}
114
115	net_check_peers(0);
116	return 0;
117}
118
119static void
120net_enqueue(struct syncpeer *p, struct msg *m)
121{
122	struct qmsg	*qm;
123
124	if (p->socket < 0)
125		return;
126
127	if (!p->ssl)
128		if (net_SSL_connect(p))
129			return;
130
131	qm = (struct qmsg *)malloc(sizeof *qm);
132	if (!qm) {
133		log_err("malloc()");
134		return;
135	}
136
137	memset(qm, 0, sizeof *qm);
138	qm->msg = m;
139	m->refcnt++;
140
141	SIMPLEQ_INSERT_TAIL(&p->msgs, qm, next);
142	return;
143}
144
145/*
146 * Queue a message for transmission to a particular peer,
147 * or to all peers if no peer is specified.
148 */
149int
150net_queue(struct syncpeer *p0, u_int32_t msgtype, u_int8_t *buf,
151    u_int32_t offset, u_int32_t len)
152{
153	struct syncpeer *p = p0;
154	struct msg	*m;
155
156	m = (struct msg *)malloc(sizeof *m);
157	if (!m) {
158		log_err("malloc()");
159		free(buf);
160		return -1;
161	}
162	memset(m, 0, sizeof *m);
163	m->obuf = buf;
164	m->buf = buf + offset;
165	m->len = len;
166	m->type = msgtype;
167
168	if (p)
169		net_enqueue(p, m);
170	else
171		for (p = LIST_FIRST(&cfgstate.peerlist); p;
172		     p = LIST_NEXT(p, link))
173			net_enqueue(p, m);
174
175	if (!m->refcnt) {
176		free(m->obuf);
177		free(m);
178	}
179
180	return 0;
181}
182
183/* Set all write pending filedescriptors. */
184int
185net_set_pending_wfds(fd_set *fds)
186{
187	struct syncpeer *p;
188	int		max_fd = -1;
189
190	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link))
191		if (p->socket > -1 && SIMPLEQ_FIRST(&p->msgs)) {
192			FD_SET(p->socket, fds);
193			if (p->socket > max_fd)
194				max_fd = p->socket;
195		}
196	return max_fd + 1;
197}
198
199/*
200 * Set readable filedescriptors. They are basically the same as for write,
201 * plus the listening socket.
202 */
203int
204net_set_rfds(fd_set *fds)
205{
206	struct syncpeer *p;
207	int		max_fd = -1;
208
209	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
210		if (p->socket > -1)
211			FD_SET(p->socket, fds);
212		if (p->socket > max_fd)
213			max_fd = p->socket;
214	}
215	FD_SET(listen_socket, fds);
216	if (listen_socket > max_fd)
217		max_fd = listen_socket;
218	return max_fd + 1;
219}
220
221void
222net_handle_messages(fd_set *fds)
223{
224	struct sockaddr_storage	sa_storage, sa_storage2;
225	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
226	struct sockaddr	*sa2 = (struct sockaddr *)&sa_storage2;
227	socklen_t	socklen;
228	struct syncpeer *p;
229	u_int8_t	*msg;
230	u_int32_t	 msgtype, msglen;
231	int		 newsock, found;
232
233	if (FD_ISSET(listen_socket, fds)) {
234		/* Accept a new incoming connection */
235		socklen = sizeof sa_storage;
236		newsock = accept(listen_socket, sa, &socklen);
237		if (newsock > -1) {
238			/* Setup the syncpeer structure */
239			found = 0;
240			for (p = LIST_FIRST(&cfgstate.peerlist); p && !found;
241			     p = LIST_NEXT(p, link)) {
242				struct sockaddr_in *sin, *sin2;
243				struct sockaddr_in6 *sin6, *sin62;
244
245				/* Match? */
246				if (net_set_sa(sa2, p->name, 0))
247					continue;
248				if (sa->sa_family != sa2->sa_family)
249					continue;
250				if (sa->sa_family == AF_INET) {
251					sin = (struct sockaddr_in *)sa;
252					sin2 = (struct sockaddr_in *)sa2;
253					if (memcmp(&sin->sin_addr,
254					    &sin2->sin_addr,
255					    sizeof(struct in_addr)))
256						continue;
257				} else {
258					sin6 = (struct sockaddr_in6 *)sa;
259					sin62 = (struct sockaddr_in6 *)sa2;
260					if (memcmp(&sin6->sin6_addr,
261					    &sin62->sin6_addr,
262					    sizeof(struct in6_addr)))
263						continue;
264				}
265				/* Match! */
266				found++;
267				p->socket = newsock;
268				p->ssl = NULL;
269				log_msg(1, "peer \"%s\" connected", p->name);
270			}
271			if (!found) {
272				log_msg(1, "Found no matching peer for "
273				    "accepted socket, closing.");
274				close(newsock);
275			}
276		} else
277			log_err("accept()");
278	}
279
280	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
281		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
282			continue;
283		msg = net_read(p, &msgtype, &msglen);
284		if (!msg)
285			continue;
286
287		/* XXX check message validity. */
288
289		log_msg(4, "net_handle_messages: got msg type %u len %u from "
290		    "peer %s", msgtype, msglen, p->name);
291
292		switch (msgtype) {
293		case MSG_SYNCCTL:
294			net_ctl_handle_msg(p, msg, msglen);
295			free(msg);
296			break;
297
298		case MSG_PFKEYDATA:
299			if (p->runstate != MASTER ||
300			    cfgstate.runstate == MASTER) {
301				log_msg(0, "got PFKEY message from non-MASTER "
302				    "peer");
303				free(msg);
304				if (cfgstate.runstate == MASTER)
305					net_ctl_send_state(p);
306				else
307					net_ctl_send_error(p, 0);
308			} else if (pfkey_queue_message(msg, msglen))
309				free(msg);
310			break;
311
312		default:
313			log_msg(0, "Got unknown message type %u len %u from "
314			    "peer %s", msgtype, msglen, p->name);
315			free(msg);
316			net_ctl_send_error(p, 0);
317		}
318	}
319}
320
321void
322net_send_messages(fd_set *fds)
323{
324	struct syncpeer *p;
325	struct qmsg	*qm;
326	struct msg	*m;
327	u_int32_t	 v;
328
329	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
330		if (p->socket < 0 || !FD_ISSET(p->socket, fds))
331			continue;
332
333		qm = SIMPLEQ_FIRST(&p->msgs);
334		if (!qm) {
335			/* XXX Log */
336			continue;
337		}
338		m = qm->msg;
339
340		log_msg(4, "sending msg %p (qm %p ref %d) to peer %s", m, qm,
341		    m->refcnt, p->name);
342
343		/* Send the message. */
344		v = htonl(m->type);
345		if (net_SSL_write(p, &v, sizeof v))
346			continue;
347
348		v = htonl(m->len);
349		if (net_SSL_write(p, &v, sizeof v))
350			continue;
351
352		(void)net_SSL_write(p, m->buf, m->len);
353
354		/* Cleanup. */
355		SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
356		free(qm);
357
358		if (--m->refcnt < 1) {
359			log_msg(4, "freeing msg %p", m);
360			free(m->obuf);
361			free(m);
362		}
363	}
364	return;
365}
366
367void
368net_disconnect_peer(struct syncpeer *p)
369{
370	net_SSL_disconnect(p);
371	if (p->socket > -1)
372		close(p->socket);
373	p->socket = -1;
374}
375
376void
377net_shutdown(void)
378{
379	struct syncpeer *p;
380	struct qmsg	*qm;
381	struct msg	*m;
382
383	while ((p = LIST_FIRST(&cfgstate.peerlist))) {
384		while ((qm = SIMPLEQ_FIRST(&p->msgs))) {
385			SIMPLEQ_REMOVE_HEAD(&p->msgs, next);
386			m = qm->msg;
387			if (--m->refcnt < 1) {
388				free(m->obuf);
389				free(m);
390			}
391			free(qm);
392		}
393		net_disconnect_peer(p);
394		if (p->name)
395			free(p->name);
396		LIST_REMOVE(p, link);
397		free(p);
398	}
399
400	if (listen_socket > -1)
401		close(listen_socket);
402	net_SSL_shutdown();
403}
404
405/*
406 * Helper functions (local) below here.
407 */
408
409static u_int8_t *
410net_read(struct syncpeer *p, u_int32_t *msgtype, u_int32_t *msglen)
411{
412	u_int8_t	*msg;
413	u_int32_t	 v;
414
415	if (net_SSL_read(p, &v, sizeof v))
416		return NULL;
417	*msgtype = ntohl(v);
418
419	if (*msgtype > MSG_MAXTYPE)
420		return NULL;
421
422	if (net_SSL_read(p, &v, sizeof v))
423		return NULL;
424	*msglen = ntohl(v);
425
426	/* XXX msglen sanity */
427
428	msg = (u_int8_t *)malloc(*msglen);
429	memset(msg, 0, *msglen);
430	if (net_SSL_read(p, msg, *msglen)) {
431		free(msg);
432		return NULL;
433	}
434
435	return msg;
436}
437
438static int
439net_set_sa(struct sockaddr *sa, char *name, in_port_t port)
440{
441	struct sockaddr_in	*sin = (struct sockaddr_in *)sa;
442	struct sockaddr_in6	*sin6 = (struct sockaddr_in6 *)sa;
443
444	if (name) {
445		if (inet_pton(AF_INET, name, &sin->sin_addr) == 1) {
446			sa->sa_family = AF_INET;
447			sin->sin_port = htons(port);
448			sin->sin_len = sizeof *sin;
449			return 0;
450		}
451
452		if (inet_pton(AF_INET6, name, &sin6->sin6_addr) == 1) {
453			sa->sa_family = AF_INET6;
454			sin6->sin6_port = htons(port);
455			sin6->sin6_len = sizeof *sin6;
456			return 0;
457		}
458	} else {
459		/* XXX Assume IPv4 */
460		sa->sa_family = AF_INET;
461		sin->sin_port = htons(port);
462		sin->sin_len = sizeof *sin;
463		return 0;
464	}
465
466	return 1;
467}
468
/*
 * SIGALRM handler installed by net_connect_peers().  It intentionally has
 * no body; s is the signal number and is not used.
 */
static void
got_sigalrm(int s)
{
	(void)s;
}
474
475void
476net_connect_peers(void)
477{
478	struct sockaddr_storage sa_storage;
479	struct itimerval	iv;
480	struct sockaddr	*sa = (struct sockaddr *)&sa_storage;
481	struct syncpeer	*p;
482
483	signal(SIGALRM, got_sigalrm);
484	memset(&iv, 0, sizeof iv);
485	iv.it_value.tv_sec = 5;
486	iv.it_interval.tv_sec = 5;
487	setitimer(ITIMER_REAL, &iv, NULL);
488
489	for (p = LIST_FIRST(&cfgstate.peerlist); p; p = LIST_NEXT(p, link)) {
490		if (p->ssl || p->socket > -1)
491			continue;
492
493		memset(sa, 0, sizeof sa_storage);
494		if (net_set_sa(sa, p->name, cfgstate.listen_port))
495			continue;
496		p->socket = socket(sa->sa_family, SOCK_STREAM, 0);
497		if (p->socket < 0) {
498			log_err("peer \"%s\": socket()", p->name);
499			continue;
500		}
501		if (connect(p->socket, sa, sa->sa_len)) {
502			log_msg(1, "peer \"%s\" not ready yet", p->name);
503			net_disconnect_peer(p);
504			continue;
505		}
506		if (net_ctl_send_state(p)) {
507			log_msg(0, "peer \"%s\" failed", p->name);
508			net_disconnect_peer(p);
509			continue;
510		}
511		log_msg(1, "peer \"%s\" connected", p->name);
512	}
513
514	timerclear(&iv.it_value);
515	timerclear(&iv.it_interval);
516	setitimer(ITIMER_REAL, &iv, NULL);
517	signal(SIGALRM, SIG_IGN);
518
519	return;
520}
521
/*
 * Timer callback: run a round of connection attempts and then re-register
 * itself with timer_add() under the name "peer recheck" with a value of
 * 600 (units are defined by timer_add(); presumably seconds).  arg is not
 * used.
 */
static void
net_check_peers(void *arg)
{
	net_connect_peers();

	/* Re-arm; the result of timer_add() is deliberately ignored. */
	(void)timer_add("peer recheck", 600, net_check_peers, NULL);
}
529
530