1/*	$NetBSD: bl.c,v 1.28 2016/07/29 17:13:09 christos Exp $	*/
2
3/*-
4 * Copyright (c) 2014 The NetBSD Foundation, Inc.
5 * All rights reserved.
6 *
7 * This code is derived from software contributed to The NetBSD Foundation
8 * by Christos Zoulas.
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
20 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
21 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
22 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
23 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 * POSSIBILITY OF SUCH DAMAGE.
30 */
31#ifdef HAVE_CONFIG_H
32#include "config.h"
33#endif
34
35#include <sys/cdefs.h>
36__RCSID("$NetBSD: bl.c,v 1.28 2016/07/29 17:13:09 christos Exp $");
37
38#include <sys/param.h>
39#include <sys/types.h>
40#include <sys/socket.h>
41#include <sys/stat.h>
42#include <sys/un.h>
43
44#include <stdio.h>
45#include <string.h>
46#include <syslog.h>
47#include <signal.h>
48#include <fcntl.h>
49#include <stdlib.h>
50#include <unistd.h>
51#include <stdint.h>
52#include <stdbool.h>
53#include <errno.h>
54#include <stdarg.h>
55#include <netinet/in.h>
56#ifdef _REENTRANT
57#include <pthread.h>
58#endif
59
60#include "bl.h"
61
/*
 * Datagram exchanged between bl_send() (client) and bl_recv() (server).
 * Both ends use this exact layout, so fields must not be reordered or
 * resized without bumping BL_VERSION.
 */
typedef struct {
	uint32_t bl_len;	/* bytes sent: sizeof(bl_message_t) + context length */
	uint32_t bl_version;	/* BL_VERSION of the sender */
	uint32_t bl_type;	/* bl_type_t event code passed to bl_send() */
	uint32_t bl_salen;	/* caller-supplied address length; 0 if none */
	struct sockaddr_storage bl_ss;	/* peer address; all zeros when bl_salen is 0 */
	char bl_data[];		/* context string (at most 128 bytes), NOT NUL-terminated */
} bl_message_t;
70
71struct blacklist {
72#ifdef _REENTRANT
73	pthread_mutex_t b_mutex;
74# define BL_INIT(b)	pthread_mutex_init(&b->b_mutex, NULL)
75# define BL_LOCK(b)	pthread_mutex_lock(&b->b_mutex)
76# define BL_UNLOCK(b)	pthread_mutex_unlock(&b->b_mutex)
77#else
78# define BL_INIT(b)	do {} while(/*CONSTCOND*/0)
79# define BL_LOCK(b)	BL_INIT(b)
80# define BL_UNLOCK(b)	BL_INIT(b)
81#endif
82	int b_fd;
83	int b_connected;
84	struct sockaddr_un b_sun;
85	void (*b_fun)(int, const char *, va_list);
86	bl_info_t b_info;
87};
88
89#define BL_VERSION	1
90
91bool
92bl_isconnected(bl_t b)
93{
94	return b->b_connected == 0;
95}
96
97int
98bl_getfd(bl_t b)
99{
100	return b->b_fd;
101}
102
103static void
104bl_reset(bl_t b, bool locked)
105{
106	int serrno = errno;
107	if (!locked)
108		BL_LOCK(b);
109	close(b->b_fd);
110	errno = serrno;
111	b->b_fd = -1;
112	b->b_connected = -1;
113	if (!locked)
114		BL_UNLOCK(b);
115}
116
/*
 * printf-style front end for a vsyslog()-compatible sink.  errno is saved
 * on entry and restored on exit, so callers may log between a failing
 * call and their own inspection of errno.
 */
static void
bl_log(void (*fun)(int, const char *, va_list), int level,
    const char *fmt, ...)
{
	const int saved_errno = errno;
	va_list args;

	va_start(args, fmt);
	fun(level, fmt, args);
	va_end(args);

	errno = saved_errno;
}
129
130static int
131bl_init(bl_t b, bool srv)
132{
133	static int one = 1;
134	/* AF_UNIX address of local logger */
135	mode_t om;
136	int rv, serrno;
137	struct sockaddr_un *sun = &b->b_sun;
138
139#ifndef SOCK_NONBLOCK
140#define SOCK_NONBLOCK 0
141#endif
142#ifndef SOCK_CLOEXEC
143#define SOCK_CLOEXEC 0
144#endif
145#ifndef SOCK_NOSIGPIPE
146#define SOCK_NOSIGPIPE 0
147#endif
148
149	BL_LOCK(b);
150
151	if (b->b_fd == -1) {
152		b->b_fd = socket(PF_LOCAL,
153		    SOCK_DGRAM|SOCK_CLOEXEC|SOCK_NONBLOCK|SOCK_NOSIGPIPE, 0);
154		if (b->b_fd == -1) {
155			bl_log(b->b_fun, LOG_ERR, "%s: socket failed (%s)",
156			    __func__, strerror(errno));
157			BL_UNLOCK(b);
158			return -1;
159		}
160#if SOCK_CLOEXEC == 0
161		fcntl(b->b_fd, F_SETFD, FD_CLOEXEC);
162#endif
163#if SOCK_NONBLOCK == 0
164		fcntl(b->b_fd, F_SETFL, fcntl(b->b_fd, F_GETFL) | O_NONBLOCK);
165#endif
166#if SOCK_NOSIGPIPE == 0
167#ifdef SO_NOSIGPIPE
168		int o = 1;
169		setsockopt(b->b_fd, SOL_SOCKET, SO_NOSIGPIPE, &o, sizeof(o));
170#else
171		signal(SIGPIPE, SIG_IGN);
172#endif
173#endif
174	}
175
176	if (bl_isconnected(b)) {
177		BL_UNLOCK(b);
178		return 0;
179	}
180
181	/*
182	 * We try to connect anyway even when we are a server to verify
183	 * that no other server is listening to the socket. If we succeed
184	 * to connect and we are a server, someone else owns it.
185	 */
186	rv = connect(b->b_fd, (const void *)sun, (socklen_t)sizeof(*sun));
187	if (rv == 0) {
188		if (srv) {
189			bl_log(b->b_fun, LOG_ERR,
190			    "%s: another daemon is handling `%s'",
191			    __func__, sun->sun_path);
192			goto out;
193		}
194	} else {
195		if (!srv) {
196			/*
197			 * If the daemon is not running, we just try a
198			 * connect, so leave the socket alone until it does
199			 * and only log once.
200			 */
201			if (b->b_connected != 1) {
202				bl_log(b->b_fun, LOG_DEBUG,
203				    "%s: connect failed for `%s' (%s)",
204				    __func__, sun->sun_path, strerror(errno));
205				b->b_connected = 1;
206			}
207			BL_UNLOCK(b);
208			return -1;
209		}
210		bl_log(b->b_fun, LOG_DEBUG, "Connected to blacklist server",
211		    __func__);
212	}
213
214	if (srv) {
215		(void)unlink(sun->sun_path);
216		om = umask(0);
217		rv = bind(b->b_fd, (const void *)sun, (socklen_t)sizeof(*sun));
218		serrno = errno;
219		(void)umask(om);
220		errno = serrno;
221		if (rv == -1) {
222			bl_log(b->b_fun, LOG_ERR,
223			    "%s: bind failed for `%s' (%s)",
224			    __func__, sun->sun_path, strerror(errno));
225			goto out;
226		}
227	}
228
229	b->b_connected = 0;
230#define GOT_FD		1
231#if defined(LOCAL_CREDS)
232#define CRED_LEVEL	0
233#define	CRED_NAME	LOCAL_CREDS
234#define CRED_SC_UID	sc_euid
235#define CRED_SC_GID	sc_egid
236#define CRED_MESSAGE	SCM_CREDS
237#define CRED_SIZE	SOCKCREDSIZE(NGROUPS_MAX)
238#define CRED_TYPE	struct sockcred
239#define GOT_CRED	2
240#elif defined(SO_PASSCRED)
241#define CRED_LEVEL	SOL_SOCKET
242#define	CRED_NAME	SO_PASSCRED
243#define CRED_SC_UID	uid
244#define CRED_SC_GID	gid
245#define CRED_MESSAGE	SCM_CREDENTIALS
246#define CRED_SIZE	sizeof(struct ucred)
247#define CRED_TYPE	struct ucred
248#define GOT_CRED	2
249#else
250#define GOT_CRED	0
251/*
252 * getpeereid() and LOCAL_PEERCRED don't help here
253 * because we are not a stream socket!
254 */
255#define	CRED_SIZE	0
256#define CRED_TYPE	void * __unused
257#endif
258
259#ifdef CRED_LEVEL
260	if (setsockopt(b->b_fd, CRED_LEVEL, CRED_NAME,
261	    &one, (socklen_t)sizeof(one)) == -1) {
262		bl_log(b->b_fun, LOG_ERR, "%s: setsockopt %s "
263		    "failed (%s)", __func__, __STRING(CRED_NAME),
264		    strerror(errno));
265		goto out;
266	}
267#endif
268
269	BL_UNLOCK(b);
270	return 0;
271out:
272	bl_reset(b, true);
273	BL_UNLOCK(b);
274	return -1;
275}
276
277bl_t
278bl_create(bool srv, const char *path, void (*fun)(int, const char *, va_list))
279{
280	bl_t b = calloc(1, sizeof(*b));
281	if (b == NULL)
282		goto out;
283	b->b_fun = fun == NULL ? vsyslog : fun;
284	b->b_fd = -1;
285	b->b_connected = -1;
286	BL_INIT(b);
287
288	memset(&b->b_sun, 0, sizeof(b->b_sun));
289	b->b_sun.sun_family = AF_LOCAL;
290#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
291	b->b_sun.sun_len = sizeof(b->b_sun);
292#endif
293	strlcpy(b->b_sun.sun_path,
294	    path ? path : _PATH_BLSOCK, sizeof(b->b_sun.sun_path));
295
296	bl_init(b, srv);
297	return b;
298out:
299	free(b);
300	bl_log(fun, LOG_ERR, "%s: malloc failed (%s)", __func__,
301	    strerror(errno));
302	return NULL;
303}
304
305void
306bl_destroy(bl_t b)
307{
308	bl_reset(b, false);
309	free(b);
310}
311
312static int
313bl_getsock(bl_t b, struct sockaddr_storage *ss, const struct sockaddr *sa,
314    socklen_t slen, const char *ctx)
315{
316	uint8_t family;
317
318	memset(ss, 0, sizeof(*ss));
319
320	switch (slen) {
321	case 0:
322		return 0;
323	case sizeof(struct sockaddr_in):
324		family = AF_INET;
325		break;
326	case sizeof(struct sockaddr_in6):
327		family = AF_INET6;
328		break;
329	default:
330		bl_log(b->b_fun, LOG_ERR, "%s: invalid socket len %u (%s)",
331		    __func__, (unsigned)slen, ctx);
332		errno = EINVAL;
333		return -1;
334	}
335
336	memcpy(ss, sa, slen);
337
338	if (ss->ss_family != family) {
339		bl_log(b->b_fun, LOG_INFO,
340		    "%s: correcting socket family %d to %d (%s)",
341		    __func__, ss->ss_family, family, ctx);
342		ss->ss_family = family;
343	}
344
345#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
346	if (ss->ss_len != slen) {
347		bl_log(b->b_fun, LOG_INFO,
348		    "%s: correcting socket len %u to %u (%s)",
349		    __func__, ss->ss_len, (unsigned)slen, ctx);
350		ss->ss_len = (uint8_t)slen;
351	}
352#endif
353	return 0;
354}
355
356int
357bl_send(bl_t b, bl_type_t e, int pfd, const struct sockaddr *sa,
358    socklen_t slen, const char *ctx)
359{
360	struct msghdr   msg;
361	struct iovec    iov;
362	union {
363		char ctrl[CMSG_SPACE(sizeof(int))];
364		uint32_t fd;
365	} ua;
366	struct cmsghdr *cmsg;
367	union {
368		bl_message_t bl;
369		char buf[512];
370	} ub;
371	size_t ctxlen, tried;
372#define NTRIES	5
373
374	ctxlen = strlen(ctx);
375	if (ctxlen > 128)
376		ctxlen = 128;
377
378	iov.iov_base = ub.buf;
379	iov.iov_len = sizeof(bl_message_t) + ctxlen;
380	ub.bl.bl_len = (uint32_t)iov.iov_len;
381	ub.bl.bl_version = BL_VERSION;
382	ub.bl.bl_type = (uint32_t)e;
383
384	if (bl_getsock(b, &ub.bl.bl_ss, sa, slen, ctx) == -1)
385		return -1;
386
387
388	ub.bl.bl_salen = slen;
389	memcpy(ub.bl.bl_data, ctx, ctxlen);
390
391	msg.msg_name = NULL;
392	msg.msg_namelen = 0;
393	msg.msg_iov = &iov;
394	msg.msg_iovlen = 1;
395	msg.msg_flags = 0;
396
397	msg.msg_control = ua.ctrl;
398	msg.msg_controllen = sizeof(ua.ctrl);
399
400	cmsg = CMSG_FIRSTHDR(&msg);
401	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
402	cmsg->cmsg_level = SOL_SOCKET;
403	cmsg->cmsg_type = SCM_RIGHTS;
404
405	memcpy(CMSG_DATA(cmsg), &pfd, sizeof(pfd));
406
407	tried = 0;
408again:
409	if (bl_init(b, false) == -1)
410		return -1;
411
412	if ((sendmsg(b->b_fd, &msg, 0) == -1) && tried++ < NTRIES) {
413		bl_reset(b, false);
414		goto again;
415	}
416	return tried >= NTRIES ? -1 : 0;
417}
418
419bl_info_t *
420bl_recv(bl_t b)
421{
422        struct msghdr   msg;
423        struct iovec    iov;
424	union {
425		char ctrl[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(CRED_SIZE)];
426		uint32_t fd;
427		CRED_TYPE sc;
428	} ua;
429	struct cmsghdr *cmsg;
430	CRED_TYPE *sc;
431	union {
432		bl_message_t bl;
433		char buf[512];
434	} ub;
435	int got;
436	ssize_t rlen;
437	bl_info_t *bi = &b->b_info;
438
439	got = 0;
440	memset(bi, 0, sizeof(*bi));
441
442	iov.iov_base = ub.buf;
443	iov.iov_len = sizeof(ub);
444
445	msg.msg_name = NULL;
446	msg.msg_namelen = 0;
447	msg.msg_iov = &iov;
448	msg.msg_iovlen = 1;
449	msg.msg_flags = 0;
450
451	msg.msg_control = ua.ctrl;
452	msg.msg_controllen = sizeof(ua.ctrl) + 100;
453
454        rlen = recvmsg(b->b_fd, &msg, 0);
455        if (rlen == -1) {
456		bl_log(b->b_fun, LOG_ERR, "%s: recvmsg failed (%s)", __func__,
457		    strerror(errno));
458		return NULL;
459        }
460
461	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
462		if (cmsg->cmsg_level != SOL_SOCKET) {
463			bl_log(b->b_fun, LOG_ERR,
464			    "%s: unexpected cmsg_level %d",
465			    __func__, cmsg->cmsg_level);
466			continue;
467		}
468		switch (cmsg->cmsg_type) {
469		case SCM_RIGHTS:
470			if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) {
471				bl_log(b->b_fun, LOG_ERR,
472				    "%s: unexpected cmsg_len %d != %zu",
473				    __func__, cmsg->cmsg_len,
474				    CMSG_LEN(2 * sizeof(int)));
475				continue;
476			}
477			memcpy(&bi->bi_fd, CMSG_DATA(cmsg), sizeof(bi->bi_fd));
478			got |= GOT_FD;
479			break;
480#ifdef CRED_MESSAGE
481		case CRED_MESSAGE:
482			sc = (void *)CMSG_DATA(cmsg);
483			bi->bi_uid = sc->CRED_SC_UID;
484			bi->bi_gid = sc->CRED_SC_GID;
485			got |= GOT_CRED;
486			break;
487#endif
488		default:
489			bl_log(b->b_fun, LOG_ERR,
490			    "%s: unexpected cmsg_type %d",
491			    __func__, cmsg->cmsg_type);
492			continue;
493		}
494
495	}
496
497	if (got != (GOT_CRED|GOT_FD)) {
498		bl_log(b->b_fun, LOG_ERR, "message missing %s %s",
499#if GOT_CRED != 0
500		    (got & GOT_CRED) == 0 ? "cred" :
501#endif
502		    "", (got & GOT_FD) == 0 ? "fd" : "");
503
504		return NULL;
505	}
506
507	if ((size_t)rlen <= sizeof(ub.bl)) {
508		bl_log(b->b_fun, LOG_ERR, "message too short %zd", rlen);
509		return NULL;
510	}
511
512	if (ub.bl.bl_version != BL_VERSION) {
513		bl_log(b->b_fun, LOG_ERR, "bad version %d", ub.bl.bl_version);
514		return NULL;
515	}
516
517	bi->bi_type = ub.bl.bl_type;
518	bi->bi_slen = ub.bl.bl_salen;
519	bi->bi_ss = ub.bl.bl_ss;
520#ifndef CRED_MESSAGE
521	bi->bi_uid = -1;
522	bi->bi_gid = -1;
523#endif
524	strlcpy(bi->bi_msg, ub.bl.bl_data, MIN(sizeof(bi->bi_msg),
525	    ((size_t)rlen - sizeof(ub.bl) + 1)));
526	return bi;
527}
528