1/*	$OpenBSD: relay.c,v 1.2 2014/01/08 23:32:17 bluhm Exp $ */
2/*
3 * Copyright (c) 2013 Alexander Bluhm <bluhm@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18/*
19 * Accept tcp or udp from client and connect to server.
20 * Then copy or splice data from client to server.
21 */
22
23#include <sys/types.h>
24#include <sys/socket.h>
25#include <sys/time.h>
26
27#include <netinet/in.h>
28#include <netinet/tcp.h>
29
30#include <errno.h>
31#include <err.h>
32#include <fcntl.h>
33#include <netdb.h>
34#include <stdio.h>
35#include <stdlib.h>
36#include <string.h>
37#include <unistd.h>
38
39#define	BUFSIZE		(1<<16)
40
41__dead void	usage(void);
42void		relay_copy(int, int);
43void		relay_splice(int, int);
44int		socket_listen(int *, struct addrinfo *, const char *,
45		    const char *);
46int		listen_select(const int *, int);
47int		socket_accept(int);
48int		socket_connect(struct addrinfo *, const char *, const char *);
49
50__dead void
51usage(void)
52{
53	fprintf(stderr,
54	    "usage: relay copy|splice [-46tu] [-b bindaddress] listenport "
55	    "hostname port\n"
56	    "       copy [-46tu] [-b bindaddress] listenport hostname port\n"
57	    "       splice [-46tu] [-b bindaddress] listenport hostname port\n"
58	    "           -4              IPv4 only\n"
59	    "           -6              IPv6 only\n"
60	    "           -b bindaddress  bind listen socket to address\n"
61	    "           -t              TCP (default)\n"
62	    "           -u              UDP\n"
63	    );
64	exit(1);
65}
66
67void
68relay_copy(int fdin, int fdout)
69{
70	char buf[BUFSIZE];
71	off_t len;
72	size_t off;
73	ssize_t nr, nw;
74
75	printf("copy...\n");
76	len = 0;
77	while (1) {
78		nr = read(fdin, buf, sizeof(buf));
79		if (nr == -1)
80			err(1, "read");
81		if (nr == 0)
82			break;
83		len += nr;
84		off = 0;
85		do {
86			nw = write(fdout, buf + off, nr);
87			if (nw == -1)
88				err(1, "write");
89			off += nw;
90			nr -= nw;
91		} while (nr);
92	}
93	printf("len %lld\n", len);
94}
95
96void
97relay_splice(int fdin, int fdout)
98{
99	fd_set fdset;
100	socklen_t optlen;
101	off_t len;
102	int error;
103
104	printf("splice...\n");
105	if (setsockopt(fdin, SOL_SOCKET, SO_SPLICE, &fdout, sizeof(int)) == -1)
106		err(1, "setsockopt splice");
107	FD_ZERO(&fdset);
108	FD_SET(fdin, &fdset);
109	if (select(fdin+1, &fdset, NULL, NULL, NULL) == -1)
110		err(1, "select");
111	optlen = sizeof(error);
112	if (getsockopt(fdin, SOL_SOCKET, SO_ERROR, &error, &optlen) == -1)
113		err(1, "getsockopt error");
114	if (error)
115		printf("error %s\n", strerror(error));
116	optlen = sizeof(len);
117	if (getsockopt(fdin, SOL_SOCKET, SO_SPLICE, &len, &optlen) == -1)
118		err(1, "getsockopt splice");
119	printf("len %lld\n", len);
120}
121
/*
 * Resolve listenaddr/listenport passively (listenaddr NULL means any
 * address) and open one socket per returned address, storing up to
 * FD_SETSIZE descriptors in ls.  Each socket gets SO_RCVBUF 100000 and
 * SO_REUSEADDR, is bound, and -- for SOCK_STREAM -- put into listening
 * state with a backlog of 5; its local address is printed.  Sets
 * hints->ai_flags to AI_PASSIVE as a side effect.  Addresses whose
 * socket(2) or bind(2) fails are skipped; if none succeed the program
 * exits via err(3).  Returns the number of descriptors stored in ls.
 */
int
socket_listen(int *ls, struct addrinfo *hints, const char *listenaddr,
    const char *listenport)
{
	char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV];
	struct sockaddr_storage local;
	socklen_t locallen;
	struct addrinfo *ai, *ailist;
	const char *why = NULL;
	int fd, on, bufsz, gaierr, saved, count = 0;

	hints->ai_flags = AI_PASSIVE;
	gaierr = getaddrinfo(listenaddr, listenport, hints, &ailist);
	if (gaierr != 0)
		errx(1, "getaddrinfo %s %s: %s",
		    listenaddr != NULL ? listenaddr : "*", listenport,
		    gai_strerror(gaierr));

	for (ai = ailist; ai != NULL && count < FD_SETSIZE;
	    ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "listen socket";
			continue;
		}

		bufsz = 100000;
		if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF,
		    &bufsz, sizeof(bufsz)) == -1)
			err(1, "setsockopt rcvbuf");
		on = 1;
		if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
		    &on, sizeof(on)) == -1)
			err(1, "setsockopt reuseaddr");

		if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1) {
			why = "bind";
			/* close(2) may clobber errno; keep bind's for err(3). */
			saved = errno;
			close(fd);
			errno = saved;
			continue;
		}
		/* Datagram sockets are used as-is; only streams listen. */
		if (hints->ai_socktype == SOCK_STREAM && listen(fd, 5) == -1)
			err(1, "listen");

		locallen = sizeof(local);
		if (getsockname(fd, (struct sockaddr *)&local,
		    &locallen) == -1)
			err(1, "listen getsockname");
		gaierr = getnameinfo((struct sockaddr *)&local, locallen,
		    hbuf, sizeof(hbuf), sbuf, sizeof(sbuf),
		    NI_NUMERICHOST | NI_NUMERICSERV);
		if (gaierr != 0)
			errx(1, "listen getnameinfo: %s", gai_strerror(gaierr));
		printf("listen %s %s\n", hbuf, sbuf);

		ls[count++] = fd;
	}
	if (count == 0)
		err(1, "%s", why);
	freeaddrinfo(ailist);

	return count;
}
181
/*
 * Block in select(2), without timeout, until at least one of the nls
 * descriptors in ls is readable, and return the readable descriptor with
 * the lowest index in ls.  A select failure exits via err(3); a wakeup
 * with none of them set exits via errx(3).
 */
int
listen_select(const int *ls, int nls)
{
	fd_set rset;
	int maxfd = 0, k;

	FD_ZERO(&rset);
	for (k = 0; k < nls; k++) {
		FD_SET(ls[k], &rset);
		maxfd = ls[k] > maxfd ? ls[k] : maxfd;
	}

	if (select(maxfd + 1, &rset, NULL, NULL, NULL) == -1)
		err(1, "select");

	for (k = 0; k < nls; k++) {
		if (FD_ISSET(ls[k], &rset))
			return ls[k];
	}
	errx(1, "select: no fd set");
}
205
/*
 * Accept one connection on listening socket ls, print the peer's numeric
 * host and service, and return the connected descriptor.  Exits via
 * err(3) if accept(2) fails and via errx(3) if getnameinfo(3) fails.
 */
int
socket_accept(int ls)
{
	struct sockaddr_storage peer;
	socklen_t peerlen = sizeof(peer);
	char addr[NI_MAXHOST], port[NI_MAXSERV];
	int fd, gaierr;

	fd = accept(ls, (struct sockaddr *)&peer, &peerlen);
	if (fd == -1)
		err(1, "accept");

	gaierr = getnameinfo((struct sockaddr *)&peer, peerlen,
	    addr, sizeof(addr), port, sizeof(port),
	    NI_NUMERICHOST | NI_NUMERICSERV);
	if (gaierr != 0)
		errx(1, "accept getnameinfo: %s", gai_strerror(gaierr));

	printf("accept %s %s\n", addr, port);
	return fd;
}
227
/*
 * Resolve hostname/port (clearing hints->ai_flags first) and try each
 * returned address in order: create a socket, set SO_SNDBUF to 100000 and
 * connect(2).  The first successful connection wins; its peer address is
 * printed and its descriptor returned.  If every address fails the
 * program exits via err(3) naming the last failing step.
 */
int
socket_connect(struct addrinfo *hints, const char *hostname, const char *port)
{
	char hbuf[NI_MAXHOST], sbuf[NI_MAXSERV];
	struct sockaddr_storage peer;
	socklen_t peerlen;
	struct addrinfo *ai, *ailist;
	const char *why = NULL;
	int fd = -1, bufsz, gaierr, saved;

	hints->ai_flags = 0;
	gaierr = getaddrinfo(hostname, port, hints, &ailist);
	if (gaierr != 0)
		errx(1, "getaddrinfo %s %s: %s", hostname, port,
		    gai_strerror(gaierr));

	/* Stop scanning as soon as one address yields a connected fd. */
	for (ai = ailist; ai != NULL && fd == -1; ai = ai->ai_next) {
		fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (fd == -1) {
			why = "connect socket";
			continue;
		}
		bufsz = 100000;
		if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF,
		    &bufsz, sizeof(bufsz)) == -1)
			err(1, "setsockopt sndbuf");
		if (connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) {
			why = "connect";
			/* close(2) may clobber errno; keep connect's for err(3). */
			saved = errno;
			close(fd);
			errno = saved;
			fd = -1;
		}
	}
	if (fd == -1)
		err(1, "%s", why);

	peerlen = sizeof(peer);
	if (getpeername(fd, (struct sockaddr *)&peer, &peerlen) == -1)
		err(1, "connect getpeername");
	gaierr = getnameinfo((struct sockaddr *)&peer, peerlen,
	    hbuf, sizeof(hbuf), sbuf, sizeof(sbuf),
	    NI_NUMERICHOST | NI_NUMERICSERV);
	if (gaierr != 0)
		errx(1, "connect getnameinfo: %s", gai_strerror(gaierr));
	printf("connect %s %s\n", hbuf, sbuf);
	freeaddrinfo(ailist);

	return fd;
}
280
281int
282main(int argc, char *argv[])
283{
284	struct addrinfo hints;
285	int ch, ls[FD_SETSIZE], nls, as, cs, optval;
286	const char *listenaddr, *listenport, *hostname, *port;
287	const char *relayname;
288	void (*relayfunc)(int, int);
289
290	relayname = strrchr(argv[0], '/');
291	relayname = relayname ? relayname + 1 : argv[0];
292	if (strcmp(relayname, "copy") == 0)
293		relayfunc = relay_copy;
294	else if (strcmp(relayname, "splice") == 0)
295		relayfunc = relay_splice;
296	else {
297		argc--;
298		argv++;
299		if (argv[0] == NULL)
300			usage();
301		relayname = argv[0];
302		if (strcmp(relayname, "copy") == 0)
303			relayfunc = relay_copy;
304		else if (strcmp(relayname, "splice") == 0)
305			relayfunc = relay_splice;
306		else
307			usage();
308	}
309
310	memset(&hints, 0, sizeof(hints));
311	hints.ai_family = PF_UNSPEC;
312	hints.ai_socktype = SOCK_STREAM;
313	listenaddr = NULL;
314	while ((ch = getopt(argc, argv, "46b:tu")) != -1) {
315		switch (ch) {
316		case '4':
317			hints.ai_family = PF_INET;
318			break;
319		case '6':
320			hints.ai_family = PF_INET6;
321			break;
322		case 'b':
323			listenaddr = optarg;
324			break;
325		case 't':
326			hints.ai_socktype = SOCK_STREAM;
327			break;
328		case 'u':
329			hints.ai_socktype = SOCK_DGRAM;
330			break;
331		default:
332			usage();
333		}
334	}
335	argc -= optind;
336	argv += optind;
337	if (argc != 3)
338		usage();
339	listenport = argv[0];
340	hostname = argv[1];
341	port = argv[2];
342
343	nls = socket_listen(ls, &hints, listenaddr, listenport);
344
345	while (1) {
346		if (hints.ai_socktype == SOCK_STREAM) {
347			as = socket_accept(listen_select(ls, nls));
348			cs = socket_connect(&hints, hostname, port);
349			optval = 1;
350			if (setsockopt(cs, IPPROTO_TCP, TCP_NODELAY,
351			    &optval, sizeof(optval)) == -1)
352				err(1, "setsockopt nodelay");
353		} else {
354			cs = socket_connect(&hints, hostname, port);
355			as = listen_select(ls, nls);
356		}
357
358		relayfunc(as, cs);
359
360		if (close(cs) == -1)
361			err(1, "connect close");
362		if (hints.ai_socktype == SOCK_STREAM) {
363			if (close(as) == -1)
364				err(1, "accept close");
365		}
366		printf("close\n");
367	}
368}
369