1204076Spjd/*-
2204076Spjd * Copyright (c) 2009-2010 The FreeBSD Foundation
3204076Spjd * All rights reserved.
4204076Spjd *
5204076Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
6204076Spjd * the FreeBSD Foundation.
7204076Spjd *
8204076Spjd * Redistribution and use in source and binary forms, with or without
9204076Spjd * modification, are permitted provided that the following conditions
10204076Spjd * are met:
11204076Spjd * 1. Redistributions of source code must retain the above copyright
12204076Spjd *    notice, this list of conditions and the following disclaimer.
13204076Spjd * 2. Redistributions in binary form must reproduce the above copyright
14204076Spjd *    notice, this list of conditions and the following disclaimer in the
15204076Spjd *    documentation and/or other materials provided with the distribution.
16204076Spjd *
17204076Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18204076Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19204076Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20204076Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21204076Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22204076Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23204076Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24204076Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25204076Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26204076Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27204076Spjd * SUCH DAMAGE.
28204076Spjd */
29204076Spjd
30204076Spjd#include <sys/cdefs.h>
31204076Spjd__FBSDID("$FreeBSD$");
32204076Spjd
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
41204076Spjd
42218138Spjd#include "pjdlog.h"
43204076Spjd#include "proto.h"
44204076Spjd#include "proto_impl.h"
45204076Spjd
/*
 * Magic value stamped into every live connection handle; the public entry
 * points below assert on it before touching a handle.
 */
#define	PROTO_CONN_MAGIC	0x907041c

/*
 * Which end of a connection a handle represents.  Client handles come from
 * proto_client() (or proto_connection_recv() with client == true),
 * listening handles from proto_server(), and worker handles from
 * proto_accept() (or proto_connection_recv() with client == false).
 */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/* Generic connection handle wrapping a protocol-specific context. */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* protocol implementation */
	void		*pc_ctx;	/* protocol-private state */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registry of protocol implementations.  Regular protocols are pushed onto
 * the head and the single default protocol is appended to the tail, so
 * address lookups (which walk from the head) try the default last.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
58204076Spjd
59204076Spjdvoid
60219873Spjdproto_register(struct proto *proto, bool isdefault)
61204076Spjd{
62210869Spjd	static bool seen_default = false;
63204076Spjd
64210869Spjd	if (!isdefault)
65219873Spjd		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
66210869Spjd	else {
67218138Spjd		PJDLOG_ASSERT(!seen_default);
68210869Spjd		seen_default = true;
69219873Spjd		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
70210869Spjd	}
71204076Spjd}
72204076Spjd
73218191Spjdstatic struct proto_conn *
74219873Spjdproto_alloc(struct proto *proto, int side)
75218191Spjd{
76218191Spjd	struct proto_conn *conn;
77218191Spjd
78218191Spjd	PJDLOG_ASSERT(proto != NULL);
79218191Spjd	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
80218191Spjd	    side == PROTO_SIDE_SERVER_LISTEN ||
81218191Spjd	    side == PROTO_SIDE_SERVER_WORK);
82218191Spjd
83218191Spjd	conn = malloc(sizeof(*conn));
84218191Spjd	if (conn != NULL) {
85218191Spjd		conn->pc_proto = proto;
86218191Spjd		conn->pc_side = side;
87218191Spjd		conn->pc_magic = PROTO_CONN_MAGIC;
88218191Spjd	}
89218191Spjd	return (conn);
90218191Spjd}
91218191Spjd
92218191Spjdstatic void
93218191Spjdproto_free(struct proto_conn *conn)
94218191Spjd{
95218191Spjd
96218191Spjd	PJDLOG_ASSERT(conn != NULL);
97218191Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
98218191Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
99218191Spjd	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
100218191Spjd	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
101218191Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
102218191Spjd
103218191Spjd	bzero(conn, sizeof(*conn));
104218191Spjd	free(conn);
105218191Spjd}
106218191Spjd
107204076Spjdstatic int
108219818Spjdproto_common_setup(const char *srcaddr, const char *dstaddr,
109219818Spjd    struct proto_conn **connp, int side)
110204076Spjd{
111219873Spjd	struct proto *proto;
112204076Spjd	struct proto_conn *conn;
113204076Spjd	void *ctx;
114204076Spjd	int ret;
115204076Spjd
116218191Spjd	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
117218191Spjd	    side == PROTO_SIDE_SERVER_LISTEN);
118204076Spjd
119219873Spjd	TAILQ_FOREACH(proto, &protos, prt_next) {
120218185Spjd		if (side == PROTO_SIDE_CLIENT) {
121219873Spjd			if (proto->prt_client == NULL)
122218185Spjd				ret = -1;
123218185Spjd			else
124219873Spjd				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
125218185Spjd		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
126219873Spjd			if (proto->prt_server == NULL)
127218185Spjd				ret = -1;
128218185Spjd			else
129219873Spjd				ret = proto->prt_server(dstaddr, &ctx);
130218185Spjd		}
131204076Spjd		/*
132204076Spjd		 * ret == 0  - success
133219818Spjd		 * ret == -1 - dstaddr is not for this protocol
134231017Strociny		 * ret > 0   - right protocol, but an error occurred
135204076Spjd		 */
136204076Spjd		if (ret >= 0)
137204076Spjd			break;
138204076Spjd	}
139204076Spjd	if (proto == NULL) {
140204076Spjd		/* Unrecognized address. */
141204076Spjd		errno = EINVAL;
142204076Spjd		return (-1);
143204076Spjd	}
144204076Spjd	if (ret > 0) {
145231017Strociny		/* An error occurred. */
146204076Spjd		errno = ret;
147204076Spjd		return (-1);
148204076Spjd	}
149218191Spjd	conn = proto_alloc(proto, side);
150218191Spjd	if (conn == NULL) {
151219873Spjd		if (proto->prt_close != NULL)
152219873Spjd			proto->prt_close(ctx);
153218191Spjd		errno = ENOMEM;
154218191Spjd		return (-1);
155218191Spjd	}
156204076Spjd	conn->pc_ctx = ctx;
157204076Spjd	*connp = conn;
158218191Spjd
159204076Spjd	return (0);
160204076Spjd}
161204076Spjd
162204076Spjdint
163219818Spjdproto_client(const char *srcaddr, const char *dstaddr,
164219818Spjd    struct proto_conn **connp)
165204076Spjd{
166204076Spjd
167219818Spjd	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
168204076Spjd}
169204076Spjd
170204076Spjdint
171218192Spjdproto_connect(struct proto_conn *conn, int timeout)
172204076Spjd{
173204076Spjd	int ret;
174204076Spjd
175218138Spjd	PJDLOG_ASSERT(conn != NULL);
176218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
177218138Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
178218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
179219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
180218193Spjd	PJDLOG_ASSERT(timeout >= -1);
181204076Spjd
182219873Spjd	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
183204076Spjd	if (ret != 0) {
184204076Spjd		errno = ret;
185204076Spjd		return (-1);
186204076Spjd	}
187204076Spjd
188204076Spjd	return (0);
189204076Spjd}
190204076Spjd
191204076Spjdint
192218193Spjdproto_connect_wait(struct proto_conn *conn, int timeout)
193218193Spjd{
194218193Spjd	int ret;
195218193Spjd
196218193Spjd	PJDLOG_ASSERT(conn != NULL);
197218193Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
198218193Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
199218193Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
200219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
201218193Spjd	PJDLOG_ASSERT(timeout >= 0);
202218193Spjd
203219873Spjd	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
204218193Spjd	if (ret != 0) {
205218193Spjd		errno = ret;
206218193Spjd		return (-1);
207218193Spjd	}
208218193Spjd
209218193Spjd	return (0);
210218193Spjd}
211218193Spjd
212218193Spjdint
213204076Spjdproto_server(const char *addr, struct proto_conn **connp)
214204076Spjd{
215204076Spjd
216219818Spjd	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
217204076Spjd}
218204076Spjd
219204076Spjdint
220204076Spjdproto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
221204076Spjd{
222204076Spjd	struct proto_conn *newconn;
223204076Spjd	int ret;
224204076Spjd
225218138Spjd	PJDLOG_ASSERT(conn != NULL);
226218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
227218138Spjd	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
228218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
229219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
230204076Spjd
231218191Spjd	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
232204076Spjd	if (newconn == NULL)
233204076Spjd		return (-1);
234204076Spjd
235219873Spjd	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
236204076Spjd	if (ret != 0) {
237218191Spjd		proto_free(newconn);
238204076Spjd		errno = ret;
239204076Spjd		return (-1);
240204076Spjd	}
241204076Spjd
242204076Spjd	*newconnp = newconn;
243204076Spjd
244204076Spjd	return (0);
245204076Spjd}
246204076Spjd
247204076Spjdint
248212033Spjdproto_send(const struct proto_conn *conn, const void *data, size_t size)
249204076Spjd{
250204076Spjd	int ret;
251204076Spjd
252218138Spjd	PJDLOG_ASSERT(conn != NULL);
253218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
254218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
255219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
256204076Spjd
257219873Spjd	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
258204076Spjd	if (ret != 0) {
259204076Spjd		errno = ret;
260204076Spjd		return (-1);
261204076Spjd	}
262204076Spjd	return (0);
263204076Spjd}
264204076Spjd
265204076Spjdint
266212033Spjdproto_recv(const struct proto_conn *conn, void *data, size_t size)
267204076Spjd{
268204076Spjd	int ret;
269204076Spjd
270218138Spjd	PJDLOG_ASSERT(conn != NULL);
271218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
272218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
273219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
274204076Spjd
275219873Spjd	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
276204076Spjd	if (ret != 0) {
277204076Spjd		errno = ret;
278204076Spjd		return (-1);
279204076Spjd	}
280204076Spjd	return (0);
281204076Spjd}
282204076Spjd
283204076Spjdint
284218194Spjdproto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
285218139Spjd{
286218194Spjd	const char *protoname;
287218194Spjd	int ret, fd;
288218139Spjd
289218139Spjd	PJDLOG_ASSERT(conn != NULL);
290218139Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
291218139Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
292219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
293218194Spjd	PJDLOG_ASSERT(mconn != NULL);
294218194Spjd	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
295218194Spjd	PJDLOG_ASSERT(mconn->pc_proto != NULL);
296218194Spjd	fd = proto_descriptor(mconn);
297218194Spjd	PJDLOG_ASSERT(fd >= 0);
298219873Spjd	protoname = mconn->pc_proto->prt_name;
299218194Spjd	PJDLOG_ASSERT(protoname != NULL);
300218139Spjd
301260007Strociny	ret = conn->pc_proto->prt_send(conn->pc_ctx,
302260007Strociny	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
303218194Spjd	proto_close(mconn);
304218139Spjd	if (ret != 0) {
305218139Spjd		errno = ret;
306218139Spjd		return (-1);
307218139Spjd	}
308218139Spjd	return (0);
309218139Spjd}
310218139Spjd
311218139Spjdint
312218194Spjdproto_connection_recv(const struct proto_conn *conn, bool client,
313218194Spjd    struct proto_conn **newconnp)
314218139Spjd{
315218194Spjd	char protoname[128];
316219873Spjd	struct proto *proto;
317218194Spjd	struct proto_conn *newconn;
318218194Spjd	int ret, fd;
319218139Spjd
320218139Spjd	PJDLOG_ASSERT(conn != NULL);
321218139Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
322218139Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
323219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
324218194Spjd	PJDLOG_ASSERT(newconnp != NULL);
325218139Spjd
326218194Spjd	bzero(protoname, sizeof(protoname));
327218194Spjd
328260007Strociny	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
329218194Spjd	    sizeof(protoname) - 1, &fd);
330218139Spjd	if (ret != 0) {
331218139Spjd		errno = ret;
332218139Spjd		return (-1);
333218139Spjd	}
334218194Spjd
335218194Spjd	PJDLOG_ASSERT(fd >= 0);
336218194Spjd
337219873Spjd	TAILQ_FOREACH(proto, &protos, prt_next) {
338219873Spjd		if (strcmp(proto->prt_name, protoname) == 0)
339218194Spjd			break;
340218194Spjd	}
341218194Spjd	if (proto == NULL) {
342218194Spjd		errno = EINVAL;
343218194Spjd		return (-1);
344218194Spjd	}
345218194Spjd
346218194Spjd	newconn = proto_alloc(proto,
347218194Spjd	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
348218194Spjd	if (newconn == NULL)
349218194Spjd		return (-1);
350219873Spjd	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
351219873Spjd	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
352218194Spjd	if (ret != 0) {
353218194Spjd		proto_free(newconn);
354218194Spjd		errno = ret;
355218194Spjd		return (-1);
356218194Spjd	}
357218194Spjd
358218194Spjd	*newconnp = newconn;
359218194Spjd
360218139Spjd	return (0);
361218139Spjd}
362218139Spjd
363218139Spjdint
364204076Spjdproto_descriptor(const struct proto_conn *conn)
365204076Spjd{
366204076Spjd
367218138Spjd	PJDLOG_ASSERT(conn != NULL);
368218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
369218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
370219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
371204076Spjd
372219873Spjd	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
373204076Spjd}
374204076Spjd
375204076Spjdbool
376204076Spjdproto_address_match(const struct proto_conn *conn, const char *addr)
377204076Spjd{
378204076Spjd
379218138Spjd	PJDLOG_ASSERT(conn != NULL);
380218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
381218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
382219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
383204076Spjd
384219873Spjd	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
385204076Spjd}
386204076Spjd
387204076Spjdvoid
388204076Spjdproto_local_address(const struct proto_conn *conn, char *addr, size_t size)
389204076Spjd{
390204076Spjd
391218138Spjd	PJDLOG_ASSERT(conn != NULL);
392218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
393218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
394219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
395204076Spjd
396219873Spjd	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
397204076Spjd}
398204076Spjd
399204076Spjdvoid
400204076Spjdproto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
401204076Spjd{
402204076Spjd
403218138Spjd	PJDLOG_ASSERT(conn != NULL);
404218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
405218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
406219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
407204076Spjd
408219873Spjd	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
409204076Spjd}
410204076Spjd
411207371Spjdint
412207371Spjdproto_timeout(const struct proto_conn *conn, int timeout)
413207371Spjd{
414207371Spjd	struct timeval tv;
415207371Spjd	int fd;
416207371Spjd
417218138Spjd	PJDLOG_ASSERT(conn != NULL);
418218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
419218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
420207371Spjd
421207371Spjd	fd = proto_descriptor(conn);
422231017Strociny	if (fd == -1)
423207371Spjd		return (-1);
424207371Spjd
425207371Spjd	tv.tv_sec = timeout;
426207371Spjd	tv.tv_usec = 0;
427231017Strociny	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
428207371Spjd		return (-1);
429231017Strociny	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
430207371Spjd		return (-1);
431207371Spjd
432207371Spjd	return (0);
433207371Spjd}
434207371Spjd
435204076Spjdvoid
436204076Spjdproto_close(struct proto_conn *conn)
437204076Spjd{
438204076Spjd
439218138Spjd	PJDLOG_ASSERT(conn != NULL);
440218138Spjd	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
441218138Spjd	PJDLOG_ASSERT(conn->pc_proto != NULL);
442219873Spjd	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
443204076Spjd
444219873Spjd	conn->pc_proto->prt_close(conn->pc_ctx);
445218191Spjd	proto_free(conn);
446204076Spjd}
447