1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 *
29 * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto.c#1 $
30 */
31
32#include <sys/types.h>
33#include <sys/queue.h>
34#include <sys/socket.h>
35
36#include <errno.h>
37#include <stdint.h>
38#include <string.h>
39#include <strings.h>
40
41#include "pjdlog.h"
42#include "proto.h"
43#include "proto_impl.h"
44
/* Value kept in pc_magic of every live proto_conn; checked by assertions. */
#define	PROTO_CONN_MAGIC	0x907041c
/*
 * Generic connection handle: binds a registered protocol implementation
 * (pc_proto) to that protocol's private per-connection context (pc_ctx).
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid; zeroed by proto_free() */
	struct proto	*pc_proto;	/* protocol whose methods operate on pc_ctx */
	void		*pc_ctx;	/* opaque context owned by pc_proto */
	int		 pc_side;	/* one of the PROTO_SIDE_* values below */
#define	PROTO_SIDE_CLIENT		0	/* proto_connect() or client-side proto_wrap() */
#define	PROTO_SIDE_SERVER_LISTEN	1	/* proto_server() */
#define	PROTO_SIDE_SERVER_WORK		2	/* proto_accept() or server-side proto_wrap() */
};

/*
 * Registered protocols.  Address-based lookup walks this list from the
 * head, so list order is lookup order (see proto_register()).
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
57
58void
59proto_register(struct proto *proto, bool isdefault)
60{
61	static bool seen_default = false;
62
63	if (!isdefault)
64		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
65	else {
66		PJDLOG_ASSERT(!seen_default);
67		seen_default = true;
68		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
69	}
70}
71
72static struct proto_conn *
73proto_alloc(struct proto *proto, int side)
74{
75	struct proto_conn *conn;
76
77	PJDLOG_ASSERT(proto != NULL);
78	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
79	    side == PROTO_SIDE_SERVER_LISTEN ||
80	    side == PROTO_SIDE_SERVER_WORK);
81
82	conn = malloc(sizeof(*conn));
83	if (conn != NULL) {
84		conn->pc_proto = proto;
85		conn->pc_side = side;
86		conn->pc_magic = PROTO_CONN_MAGIC;
87	}
88	return (conn);
89}
90
91static void
92proto_free(struct proto_conn *conn)
93{
94
95	PJDLOG_ASSERT(conn != NULL);
96	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
97	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
98	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
99	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
100	PJDLOG_ASSERT(conn->pc_proto != NULL);
101
102	bzero(conn, sizeof(*conn));
103	free(conn);
104}
105
106static int
107proto_common_setup(const char *srcaddr, const char *dstaddr, int timeout,
108    int side, struct proto_conn **connp)
109{
110	struct proto *proto;
111	struct proto_conn *conn;
112	void *ctx;
113	int ret;
114
115	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
116	    side == PROTO_SIDE_SERVER_LISTEN);
117
118	TAILQ_FOREACH(proto, &protos, prt_next) {
119		if (side == PROTO_SIDE_CLIENT) {
120			if (proto->prt_connect == NULL) {
121				ret = -1;
122			} else {
123				ret = proto->prt_connect(srcaddr, dstaddr,
124				    timeout, &ctx);
125			}
126		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
127			if (proto->prt_server == NULL)
128				ret = -1;
129			else
130				ret = proto->prt_server(dstaddr, &ctx);
131		}
132		/*
133		 * ret == 0  - success
134		 * ret == -1 - dstaddr is not for this protocol
135		 * ret > 0   - right protocol, but an error occured
136		 */
137		if (ret >= 0)
138			break;
139	}
140	if (proto == NULL) {
141		/* Unrecognized address. */
142		errno = EINVAL;
143		return (-1);
144	}
145	if (ret > 0) {
146		/* An error occured. */
147		errno = ret;
148		return (-1);
149	}
150	conn = proto_alloc(proto, side);
151	if (conn == NULL) {
152		if (proto->prt_close != NULL)
153			proto->prt_close(ctx);
154		errno = ENOMEM;
155		return (-1);
156	}
157	conn->pc_ctx = ctx;
158	*connp = conn;
159
160	return (0);
161}
162
163int
164proto_connect(const char *srcaddr, const char *dstaddr, int timeout,
165    struct proto_conn **connp)
166{
167
168	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
169	PJDLOG_ASSERT(dstaddr != NULL);
170	PJDLOG_ASSERT(timeout >= -1);
171
172	return (proto_common_setup(srcaddr, dstaddr, timeout,
173	    PROTO_SIDE_CLIENT, connp));
174}
175
176int
177proto_connect_wait(struct proto_conn *conn, int timeout)
178{
179	int error;
180
181	PJDLOG_ASSERT(conn != NULL);
182	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
183	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
184	PJDLOG_ASSERT(conn->pc_proto != NULL);
185	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
186	PJDLOG_ASSERT(timeout >= 0);
187
188	error = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
189	if (error != 0) {
190		errno = error;
191		return (-1);
192	}
193
194	return (0);
195}
196
197int
198proto_server(const char *addr, struct proto_conn **connp)
199{
200
201	PJDLOG_ASSERT(addr != NULL);
202
203	return (proto_common_setup(NULL, addr, -1, PROTO_SIDE_SERVER_LISTEN,
204	    connp));
205}
206
207int
208proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
209{
210	struct proto_conn *newconn;
211	int error;
212
213	PJDLOG_ASSERT(conn != NULL);
214	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
215	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
216	PJDLOG_ASSERT(conn->pc_proto != NULL);
217	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
218
219	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
220	if (newconn == NULL)
221		return (-1);
222
223	error = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
224	if (error != 0) {
225		proto_free(newconn);
226		errno = error;
227		return (-1);
228	}
229
230	*newconnp = newconn;
231
232	return (0);
233}
234
235int
236proto_send(const struct proto_conn *conn, const void *data, size_t size)
237{
238	int error;
239
240	PJDLOG_ASSERT(conn != NULL);
241	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
242	PJDLOG_ASSERT(conn->pc_proto != NULL);
243	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
244
245	error = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
246	if (error != 0) {
247		errno = error;
248		return (-1);
249	}
250	return (0);
251}
252
253int
254proto_recv(const struct proto_conn *conn, void *data, size_t size)
255{
256	int error;
257
258	PJDLOG_ASSERT(conn != NULL);
259	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
260	PJDLOG_ASSERT(conn->pc_proto != NULL);
261	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
262
263	error = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
264	if (error != 0) {
265		errno = error;
266		return (-1);
267	}
268	return (0);
269}
270
271int
272proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
273{
274	const char *protoname;
275	int error, fd;
276
277	PJDLOG_ASSERT(conn != NULL);
278	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
279	PJDLOG_ASSERT(conn->pc_proto != NULL);
280	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
281	PJDLOG_ASSERT(mconn != NULL);
282	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
283	PJDLOG_ASSERT(mconn->pc_proto != NULL);
284	fd = proto_descriptor(mconn);
285	PJDLOG_ASSERT(fd >= 0);
286	protoname = mconn->pc_proto->prt_name;
287	PJDLOG_ASSERT(protoname != NULL);
288
289	error = conn->pc_proto->prt_send(conn->pc_ctx,
290	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
291	proto_close(mconn);
292	if (error != 0) {
293		errno = error;
294		return (-1);
295	}
296	return (0);
297}
298
299int
300proto_wrap(const char *protoname, bool client, int fd,
301    struct proto_conn **newconnp)
302{
303	struct proto *proto;
304	struct proto_conn *newconn;
305	int error;
306
307	TAILQ_FOREACH(proto, &protos, prt_next) {
308		if (strcmp(proto->prt_name, protoname) == 0)
309			break;
310	}
311	if (proto == NULL) {
312		errno = EINVAL;
313		return (-1);
314	}
315
316	newconn = proto_alloc(proto,
317	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
318	if (newconn == NULL)
319		return (-1);
320	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
321	error = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
322	if (error != 0) {
323		proto_free(newconn);
324		errno = error;
325		return (-1);
326	}
327
328	*newconnp = newconn;
329
330	return (0);
331}
332
333int
334proto_connection_recv(const struct proto_conn *conn, bool client,
335    struct proto_conn **newconnp)
336{
337	char protoname[128];
338	int error, fd;
339
340	PJDLOG_ASSERT(conn != NULL);
341	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
342	PJDLOG_ASSERT(conn->pc_proto != NULL);
343	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
344	PJDLOG_ASSERT(newconnp != NULL);
345
346	bzero(protoname, sizeof(protoname));
347
348	error = conn->pc_proto->prt_recv(conn->pc_ctx,
349	    (unsigned char *)protoname, sizeof(protoname) - 1, &fd);
350	if (error != 0) {
351		errno = error;
352		return (-1);
353	}
354
355	PJDLOG_ASSERT(fd >= 0);
356
357	return (proto_wrap(protoname, client, fd, newconnp));
358}
359
360int
361proto_descriptor(const struct proto_conn *conn)
362{
363
364	PJDLOG_ASSERT(conn != NULL);
365	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
366	PJDLOG_ASSERT(conn->pc_proto != NULL);
367	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
368
369	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
370}
371
372bool
373proto_address_match(const struct proto_conn *conn, const char *addr)
374{
375
376	PJDLOG_ASSERT(conn != NULL);
377	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
378	PJDLOG_ASSERT(conn->pc_proto != NULL);
379	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
380
381	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
382}
383
384void
385proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
386{
387
388	PJDLOG_ASSERT(conn != NULL);
389	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
390	PJDLOG_ASSERT(conn->pc_proto != NULL);
391	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
392
393	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
394}
395
396void
397proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
398{
399
400	PJDLOG_ASSERT(conn != NULL);
401	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
402	PJDLOG_ASSERT(conn->pc_proto != NULL);
403	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
404
405	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
406}
407
408int
409proto_timeout(const struct proto_conn *conn, int timeout)
410{
411	struct timeval tv;
412	int fd;
413
414	PJDLOG_ASSERT(conn != NULL);
415	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
416	PJDLOG_ASSERT(conn->pc_proto != NULL);
417
418	fd = proto_descriptor(conn);
419	if (fd < 0)
420		return (-1);
421
422	tv.tv_sec = timeout;
423	tv.tv_usec = 0;
424	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0)
425		return (-1);
426	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
427		return (-1);
428
429	return (0);
430}
431
432void
433proto_close(struct proto_conn *conn)
434{
435
436	PJDLOG_ASSERT(conn != NULL);
437	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
438	PJDLOG_ASSERT(conn->pc_proto != NULL);
439	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
440
441	conn->pc_proto->prt_close(conn->pc_ctx);
442	proto_free(conn);
443}
444
445int
446proto_exec(int argc, char *argv[])
447{
448	struct proto *proto;
449	int error;
450
451	if (argc == 0) {
452		errno = EINVAL;
453		return (-1);
454	}
455	TAILQ_FOREACH(proto, &protos, prt_next) {
456		if (strcmp(proto->prt_name, argv[0]) == 0)
457			break;
458	}
459	if (proto == NULL) {
460		errno = EINVAL;
461		return (-1);
462	}
463	if (proto->prt_exec == NULL) {
464		errno = EOPNOTSUPP;
465		return (-1);
466	}
467	error = proto->prt_exec(argc, argv);
468	if (error != 0) {
469		errno = error;
470		return (-1);
471	}
472	/* NOTREACHED */
473	return (0);
474}
475
/*
 * A single name/value setting stored by proto_set() and returned by
 * proto_get().  Both strings are private heap copies owned by the entry.
 */
struct proto_nvpair {
	char	*pnv_name;
	char	*pnv_value;
	TAILQ_ENTRY(proto_nvpair) pnv_next;
};

/* All settings registered with proto_set(); at most one entry per name. */
static TAILQ_HEAD(, proto_nvpair) proto_nvpairs =
    TAILQ_HEAD_INITIALIZER(proto_nvpairs);

/*
 * Return the entry whose name equals 'name', or NULL if there is none.
 */
static struct proto_nvpair *
proto_nvpair_find(const char *name)
{
	struct proto_nvpair *pnv;

	TAILQ_FOREACH(pnv, &proto_nvpairs, pnv_next) {
		if (strcmp(pnv->pnv_name, name) == 0)
			return (pnv);
	}
	return (NULL);
}

/*
 * Set 'name' to a private copy of 'value', replacing any previous value.
 * Returns 0 on success, or -1 with errno set to ENOMEM if memory could not
 * be allocated.  On failure any previous value of 'name' is left intact.
 */
int
proto_set(const char *name, const char *value)
{
	struct proto_nvpair *pnv;
	char *newvalue;

	/*
	 * Copy the value before touching the list, so an allocation failure
	 * cannot destroy an existing setting.
	 */
	newvalue = strdup(value);
	if (newvalue == NULL) {
		errno = ENOMEM;
		return (-1);
	}

	pnv = proto_nvpair_find(name);
	if (pnv != NULL) {
		free(pnv->pnv_value);
		pnv->pnv_value = newvalue;
		return (0);
	}

	pnv = malloc(sizeof(*pnv));
	if (pnv == NULL) {
		free(newvalue);
		errno = ENOMEM;
		return (-1);
	}
	pnv->pnv_name = strdup(name);
	if (pnv->pnv_name == NULL) {
		free(pnv);
		free(newvalue);
		errno = ENOMEM;
		return (-1);
	}
	pnv->pnv_value = newvalue;
	TAILQ_INSERT_TAIL(&proto_nvpairs, pnv, pnv_next);
	return (0);
}

/*
 * Return the value stored for 'name' by proto_set(), or NULL if 'name' was
 * never set.  The returned string remains owned by this module.
 */
const char *
proto_get(const char *name)
{
	struct proto_nvpair *pnv;

	pnv = proto_nvpair_find(name);
	return (pnv != NULL ? pnv->pnv_value : NULL);
}
530