1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 */
29
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
38
39#include "pjdlog.h"
40#include "proto.h"
41#include "proto_impl.h"
42
/*
 * Sanity marker stored in pc_magic by proto_alloc() and wiped by
 * proto_free(); every public entry point asserts it to catch stale or
 * foreign pointers.
 */
#define	PROTO_CONN_MAGIC	0x907041c
/*
 * Generic connection handle returned to callers.  It pairs a registered
 * protocol implementation with that protocol's private per-connection
 * context, and records which role the connection plays.
 */
struct proto_conn {
	/* Always PROTO_CONN_MAGIC while the handle is live. */
	int		 pc_magic;
	/* Protocol whose prt_* methods operate on pc_ctx. */
	struct proto	*pc_proto;
	/* Protocol-private context; opaque to this file. */
	void		*pc_ctx;
	/* One of the PROTO_SIDE_* values below. */
	int		 pc_side;
/* Created by proto_connect(), or by proto_wrap() with client == true. */
#define	PROTO_SIDE_CLIENT		0
/* Listening endpoint created by proto_server(). */
#define	PROTO_SIDE_SERVER_LISTEN	1
/* Created by proto_accept(), or by proto_wrap() with client == false. */
#define	PROTO_SIDE_SERVER_WORK		2
};

/*
 * Registered protocols, in lookup order.  proto_register() pushes ordinary
 * protocols at the head and the default protocol at the tail.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
55
56void
57proto_register(struct proto *proto, bool isdefault)
58{
59	static bool seen_default = false;
60
61	if (!isdefault)
62		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
63	else {
64		PJDLOG_ASSERT(!seen_default);
65		seen_default = true;
66		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
67	}
68}
69
70static struct proto_conn *
71proto_alloc(struct proto *proto, int side)
72{
73	struct proto_conn *conn;
74
75	PJDLOG_ASSERT(proto != NULL);
76	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
77	    side == PROTO_SIDE_SERVER_LISTEN ||
78	    side == PROTO_SIDE_SERVER_WORK);
79
80	conn = malloc(sizeof(*conn));
81	if (conn != NULL) {
82		conn->pc_proto = proto;
83		conn->pc_side = side;
84		conn->pc_magic = PROTO_CONN_MAGIC;
85	}
86	return (conn);
87}
88
89static void
90proto_free(struct proto_conn *conn)
91{
92
93	PJDLOG_ASSERT(conn != NULL);
94	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
95	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
96	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
97	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
98	PJDLOG_ASSERT(conn->pc_proto != NULL);
99
100	bzero(conn, sizeof(*conn));
101	free(conn);
102}
103
104static int
105proto_common_setup(const char *srcaddr, const char *dstaddr, int timeout,
106    int side, struct proto_conn **connp)
107{
108	struct proto *proto;
109	struct proto_conn *conn;
110	void *ctx;
111	int ret;
112
113	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
114	    side == PROTO_SIDE_SERVER_LISTEN);
115
116	TAILQ_FOREACH(proto, &protos, prt_next) {
117		if (side == PROTO_SIDE_CLIENT) {
118			if (proto->prt_connect == NULL) {
119				ret = -1;
120			} else {
121				ret = proto->prt_connect(srcaddr, dstaddr,
122				    timeout, &ctx);
123			}
124		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
125			if (proto->prt_server == NULL)
126				ret = -1;
127			else
128				ret = proto->prt_server(dstaddr, &ctx);
129		}
130		/*
131		 * ret == 0  - success
132		 * ret == -1 - dstaddr is not for this protocol
133		 * ret > 0   - right protocol, but an error occured
134		 */
135		if (ret >= 0)
136			break;
137	}
138	if (proto == NULL) {
139		/* Unrecognized address. */
140		errno = EINVAL;
141		return (-1);
142	}
143	if (ret > 0) {
144		/* An error occured. */
145		errno = ret;
146		return (-1);
147	}
148	conn = proto_alloc(proto, side);
149	if (conn == NULL) {
150		if (proto->prt_close != NULL)
151			proto->prt_close(ctx);
152		errno = ENOMEM;
153		return (-1);
154	}
155	conn->pc_ctx = ctx;
156	*connp = conn;
157
158	return (0);
159}
160
161int
162proto_connect(const char *srcaddr, const char *dstaddr, int timeout,
163    struct proto_conn **connp)
164{
165
166	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
167	PJDLOG_ASSERT(dstaddr != NULL);
168	PJDLOG_ASSERT(timeout >= -1);
169
170	return (proto_common_setup(srcaddr, dstaddr, timeout,
171	    PROTO_SIDE_CLIENT, connp));
172}
173
174int
175proto_connect_wait(struct proto_conn *conn, int timeout)
176{
177	int error;
178
179	PJDLOG_ASSERT(conn != NULL);
180	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
181	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
182	PJDLOG_ASSERT(conn->pc_proto != NULL);
183	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
184	PJDLOG_ASSERT(timeout >= 0);
185
186	error = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
187	if (error != 0) {
188		errno = error;
189		return (-1);
190	}
191
192	return (0);
193}
194
195int
196proto_server(const char *addr, struct proto_conn **connp)
197{
198
199	PJDLOG_ASSERT(addr != NULL);
200
201	return (proto_common_setup(NULL, addr, -1, PROTO_SIDE_SERVER_LISTEN,
202	    connp));
203}
204
205int
206proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
207{
208	struct proto_conn *newconn;
209	int error;
210
211	PJDLOG_ASSERT(conn != NULL);
212	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
213	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
214	PJDLOG_ASSERT(conn->pc_proto != NULL);
215	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
216
217	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
218	if (newconn == NULL)
219		return (-1);
220
221	error = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
222	if (error != 0) {
223		proto_free(newconn);
224		errno = error;
225		return (-1);
226	}
227
228	*newconnp = newconn;
229
230	return (0);
231}
232
233int
234proto_send(const struct proto_conn *conn, const void *data, size_t size)
235{
236	int error;
237
238	PJDLOG_ASSERT(conn != NULL);
239	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
240	PJDLOG_ASSERT(conn->pc_proto != NULL);
241	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
242
243	error = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
244	if (error != 0) {
245		errno = error;
246		return (-1);
247	}
248	return (0);
249}
250
251int
252proto_recv(const struct proto_conn *conn, void *data, size_t size)
253{
254	int error;
255
256	PJDLOG_ASSERT(conn != NULL);
257	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
258	PJDLOG_ASSERT(conn->pc_proto != NULL);
259	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
260
261	error = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
262	if (error != 0) {
263		errno = error;
264		return (-1);
265	}
266	return (0);
267}
268
269int
270proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
271{
272	const char *protoname;
273	int error, fd;
274
275	PJDLOG_ASSERT(conn != NULL);
276	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
277	PJDLOG_ASSERT(conn->pc_proto != NULL);
278	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
279	PJDLOG_ASSERT(mconn != NULL);
280	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
281	PJDLOG_ASSERT(mconn->pc_proto != NULL);
282	fd = proto_descriptor(mconn);
283	PJDLOG_ASSERT(fd >= 0);
284	protoname = mconn->pc_proto->prt_name;
285	PJDLOG_ASSERT(protoname != NULL);
286
287	error = conn->pc_proto->prt_send(conn->pc_ctx,
288	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
289	proto_close(mconn);
290	if (error != 0) {
291		errno = error;
292		return (-1);
293	}
294	return (0);
295}
296
297int
298proto_wrap(const char *protoname, bool client, int fd,
299    struct proto_conn **newconnp)
300{
301	struct proto *proto;
302	struct proto_conn *newconn;
303	int error;
304
305	TAILQ_FOREACH(proto, &protos, prt_next) {
306		if (strcmp(proto->prt_name, protoname) == 0)
307			break;
308	}
309	if (proto == NULL) {
310		errno = EINVAL;
311		return (-1);
312	}
313
314	newconn = proto_alloc(proto,
315	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
316	if (newconn == NULL)
317		return (-1);
318	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
319	error = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
320	if (error != 0) {
321		proto_free(newconn);
322		errno = error;
323		return (-1);
324	}
325
326	*newconnp = newconn;
327
328	return (0);
329}
330
331int
332proto_connection_recv(const struct proto_conn *conn, bool client,
333    struct proto_conn **newconnp)
334{
335	char protoname[128];
336	int error, fd;
337
338	PJDLOG_ASSERT(conn != NULL);
339	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
340	PJDLOG_ASSERT(conn->pc_proto != NULL);
341	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
342	PJDLOG_ASSERT(newconnp != NULL);
343
344	bzero(protoname, sizeof(protoname));
345
346	error = conn->pc_proto->prt_recv(conn->pc_ctx,
347	    (unsigned char *)protoname, sizeof(protoname) - 1, &fd);
348	if (error != 0) {
349		errno = error;
350		return (-1);
351	}
352
353	PJDLOG_ASSERT(fd >= 0);
354
355	return (proto_wrap(protoname, client, fd, newconnp));
356}
357
358int
359proto_descriptor(const struct proto_conn *conn)
360{
361
362	PJDLOG_ASSERT(conn != NULL);
363	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
364	PJDLOG_ASSERT(conn->pc_proto != NULL);
365	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
366
367	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
368}
369
370bool
371proto_address_match(const struct proto_conn *conn, const char *addr)
372{
373
374	PJDLOG_ASSERT(conn != NULL);
375	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
376	PJDLOG_ASSERT(conn->pc_proto != NULL);
377	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
378
379	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
380}
381
382void
383proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
384{
385
386	PJDLOG_ASSERT(conn != NULL);
387	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
388	PJDLOG_ASSERT(conn->pc_proto != NULL);
389	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
390
391	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
392}
393
394void
395proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
396{
397
398	PJDLOG_ASSERT(conn != NULL);
399	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
400	PJDLOG_ASSERT(conn->pc_proto != NULL);
401	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
402
403	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
404}
405
406int
407proto_timeout(const struct proto_conn *conn, int timeout)
408{
409	struct timeval tv;
410	int fd;
411
412	PJDLOG_ASSERT(conn != NULL);
413	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
414	PJDLOG_ASSERT(conn->pc_proto != NULL);
415
416	fd = proto_descriptor(conn);
417	if (fd < 0)
418		return (-1);
419
420	tv.tv_sec = timeout;
421	tv.tv_usec = 0;
422	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0)
423		return (-1);
424	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
425		return (-1);
426
427	return (0);
428}
429
430void
431proto_close(struct proto_conn *conn)
432{
433
434	PJDLOG_ASSERT(conn != NULL);
435	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
436	PJDLOG_ASSERT(conn->pc_proto != NULL);
437	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
438
439	conn->pc_proto->prt_close(conn->pc_ctx);
440	proto_free(conn);
441}
442
443int
444proto_exec(int argc, char *argv[])
445{
446	struct proto *proto;
447	int error;
448
449	if (argc == 0) {
450		errno = EINVAL;
451		return (-1);
452	}
453	TAILQ_FOREACH(proto, &protos, prt_next) {
454		if (strcmp(proto->prt_name, argv[0]) == 0)
455			break;
456	}
457	if (proto == NULL) {
458		errno = EINVAL;
459		return (-1);
460	}
461	if (proto->prt_exec == NULL) {
462		errno = EOPNOTSUPP;
463		return (-1);
464	}
465	error = proto->prt_exec(argc, argv);
466	if (error != 0) {
467		errno = error;
468		return (-1);
469	}
470	/* NOTREACHED */
471	return (0);
472}
473
/*
 * One name/value setting stored by proto_set() and looked up by
 * proto_get().  Both strings are heap copies (strdup) owned by the entry.
 */
struct proto_nvpair {
	char	*pnv_name;
	char	*pnv_value;
	TAILQ_ENTRY(proto_nvpair) pnv_next;
};

/* All settings; proto_set() keeps at most one entry per name. */
static TAILQ_HEAD(, proto_nvpair) proto_nvpairs =
    TAILQ_HEAD_INITIALIZER(proto_nvpairs);
482
483int
484proto_set(const char *name, const char *value)
485{
486	struct proto_nvpair *pnv;
487
488	TAILQ_FOREACH(pnv, &proto_nvpairs, pnv_next) {
489		if (strcmp(pnv->pnv_name, name) == 0)
490			break;
491	}
492	if (pnv != NULL) {
493		TAILQ_REMOVE(&proto_nvpairs, pnv, pnv_next);
494		free(pnv->pnv_value);
495	} else {
496		pnv = malloc(sizeof(*pnv));
497		if (pnv == NULL)
498			return (-1);
499		pnv->pnv_name = strdup(name);
500		if (pnv->pnv_name == NULL) {
501			free(pnv);
502			return (-1);
503		}
504	}
505	pnv->pnv_value = strdup(value);
506	if (pnv->pnv_value == NULL) {
507		free(pnv->pnv_name);
508		free(pnv);
509		return (-1);
510	}
511	TAILQ_INSERT_TAIL(&proto_nvpairs, pnv, pnv_next);
512	return (0);
513}
514
515const char *
516proto_get(const char *name)
517{
518	struct proto_nvpair *pnv;
519
520	TAILQ_FOREACH(pnv, &proto_nvpairs, pnv_next) {
521		if (strcmp(pnv->pnv_name, name) == 0)
522			break;
523	}
524	if (pnv != NULL)
525		return (pnv->pnv_value);
526	return (NULL);
527}
528