proto_tcp.c revision 229945
1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
4 * All rights reserved.
5 *
6 * This software was developed by Pawel Jakub Dawidek under sponsorship from
7 * the FreeBSD Foundation.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 * 1. Redistributions of source code must retain the above copyright
13 *    notice, this list of conditions and the following disclaimer.
14 * 2. Redistributions in binary form must reproduce the above copyright
15 *    notice, this list of conditions and the following disclaimer in the
16 *    documentation and/or other materials provided with the distribution.
17 *
18 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28 * SUCH DAMAGE.
29 */
30
31#include <sys/cdefs.h>
32__FBSDID("$FreeBSD: head/sbin/hastd/proto_tcp.c 229945 2012-01-10 22:39:07Z pjd $");
33
34#include <sys/param.h>	/* MAXHOSTNAMELEN */
35#include <sys/socket.h>
36
37#include <arpa/inet.h>
38
39#include <netinet/in.h>
40#include <netinet/tcp.h>
41
42#include <errno.h>
43#include <fcntl.h>
44#include <netdb.h>
45#include <stdbool.h>
46#include <stdint.h>
47#include <stdio.h>
48#include <string.h>
49#include <unistd.h>
50
51#include "pjdlog.h"
52#include "proto_impl.h"
53#include "subr.h"
54
/* Stamped into every live tcp_ctx; tcp_close() clears it before freeing. */
#define	TCP_CTX_MAGIC	0x7c41c

/* Role of a tcp_ctx, kept in tc_side. */
enum {
	TCP_SIDE_CLIENT = 0,		/* Outgoing connection. */
	TCP_SIDE_SERVER_LISTEN = 1,	/* Listening socket. */
	TCP_SIDE_SERVER_WORK = 2	/* From tcp_accept() or a server wrap. */
};

/* Per-connection state of the "tcp" protocol. */
struct tcp_ctx {
	int			tc_magic;	/* TCP_CTX_MAGIC while valid. */
	struct sockaddr_storage	tc_sa;		/* Address from tcp_addr(); */
						/* AF_UNSPEC when wrapped. */
	int			tc_fd;		/* Socket descriptor. */
	int			tc_side;	/* One of TCP_SIDE_*. */
};

static int tcp_connect_wait(void *ctx, int timeout);
static void tcp_close(void *ctx);
68
/*
 * Convert the decimal string "str" to a non-negative integer and store it
 * in *nump.  Only the characters 0-9 are accepted (no sign, whitespace or
 * trailing text) and the value must lie in [minnum, maxnum].
 * Returns 0 on success.  On failure returns -1 with errno set to EINVAL
 * and leaves *nump untouched.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Check for overflow before multiplying: signed overflow is
		 * undefined behaviour, so it cannot be detected afterwards.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
97}
98
/*
 * Parse "addr" into a socket address stored in *sap.
 *
 * An optional, case-insensitive scheme prefix selects the family:
 * "tcp4://" (IPv4 only), "tcp6://" (IPv6 only) or "tcp://" (either).
 * Without a prefix the string is still treated as a TCP address.
 * Accepted forms after the prefix:
 *   host, host:port, a.b.c.d, a.b.c.d:port, ipv6, [ipv6], [ipv6]:port
 * An unbracketed string with more than one colon is a bare IPv6 address
 * and cannot carry a port.  When no port is given, "defport" is used.
 *
 * Returns 0 on success, -1 if addr is NULL, otherwise an errno value:
 * EINVAL (bad port, malformed brackets or getaddrinfo() failure),
 * ENAMETOOLONG (host part longer than MAXHOSTNAMELEN - 1) or ENOENT.
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *host, *pp, *rbracket;
	size_t hostlen;
	intmax_t port;
	int error;

	if (addr == NULL)
		return (-1);

	bzero(&hints, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	} else {
		/*
		 * Because TCP is the default assume IP or host is given without
		 * prefix.
		 */
	}

	/*
	 * Split the string into a host part and an optional port (pp points
	 * at the ':' separator, or is NULL when no port is given):
	 * 1. hostname with port, eg. freefall.freebsd.org:8457
	 * 2. IPv4 address with port, eg. 192.168.0.101:8457
	 * 3. IPv6 address with port, eg. [fe80::1]:8457
	 * A bracketed address may also be given without a port, eg. [fe80::1];
	 * the brackets are always stripped from the host part.
	 */
	if (addr[0] == '[') {
		rbracket = strchr(addr, ']');
		if (rbracket == NULL ||
		    (rbracket[1] != '\0' && rbracket[1] != ':')) {
			return (EINVAL);
		}
		host = addr + 1;
		hostlen = (size_t)(rbracket - host);
		pp = (rbracket[1] == ':') ? rbracket + 1 : NULL;
	} else {
		host = addr;
		pp = strchr(addr, ':');
		/* Two or more colons: bare IPv6 address, no port. */
		if (pp != NULL && pp != strrchr(addr, ':'))
			pp = NULL;
		hostlen = (pp != NULL) ? (size_t)(pp - addr) : strlen(addr);
	}

	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else {
		if (numfromstr(pp + 1, 1, 65535, &port) == -1)
			return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);

	/* Copy out the host part; getaddrinfo() needs a NUL-terminated name. */
	if (hostlen >= sizeof(iporhost))
		return (ENAMETOOLONG);
	memcpy(iporhost, host, hostlen);
	iporhost[hostlen] = '\0';

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	/* Only the first result is used. */
	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
189}
190
191static int
192tcp_setup_new(const char *addr, int side, void **ctxp)
193{
194	struct tcp_ctx *tctx;
195	int ret, nodelay;
196
197	PJDLOG_ASSERT(addr != NULL);
198	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
199	    side == TCP_SIDE_SERVER_LISTEN);
200	PJDLOG_ASSERT(ctxp != NULL);
201
202	tctx = malloc(sizeof(*tctx));
203	if (tctx == NULL)
204		return (errno);
205
206	/* Parse given address. */
207	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
208		free(tctx);
209		return (ret);
210	}
211
212	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
213
214	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
215	if (tctx->tc_fd == -1) {
216		ret = errno;
217		free(tctx);
218		return (ret);
219	}
220
221	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
222
223	/* Socket settings. */
224	nodelay = 1;
225	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
226	    sizeof(nodelay)) == -1) {
227		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
228	}
229
230	tctx->tc_side = side;
231	tctx->tc_magic = TCP_CTX_MAGIC;
232	*ctxp = tctx;
233
234	return (0);
235}
236
237static int
238tcp_setup_wrap(int fd, int side, void **ctxp)
239{
240	struct tcp_ctx *tctx;
241
242	PJDLOG_ASSERT(fd >= 0);
243	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
244	    side == TCP_SIDE_SERVER_WORK);
245	PJDLOG_ASSERT(ctxp != NULL);
246
247	tctx = malloc(sizeof(*tctx));
248	if (tctx == NULL)
249		return (errno);
250
251	tctx->tc_fd = fd;
252	tctx->tc_sa.ss_family = AF_UNSPEC;
253	tctx->tc_side = side;
254	tctx->tc_magic = TCP_CTX_MAGIC;
255	*ctxp = tctx;
256
257	return (0);
258}
259
260static int
261tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
262{
263	struct tcp_ctx *tctx;
264	struct sockaddr_storage sa;
265	int ret;
266
267	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
268	if (ret != 0)
269		return (ret);
270	tctx = *ctxp;
271	if (srcaddr == NULL)
272		return (0);
273	ret = tcp_addr(srcaddr, 0, &sa);
274	if (ret != 0) {
275		tcp_close(tctx);
276		return (ret);
277	}
278	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
279		ret = errno;
280		tcp_close(tctx);
281		return (ret);
282	}
283	return (0);
284}
285
286static int
287tcp_connect(void *ctx, int timeout)
288{
289	struct tcp_ctx *tctx = ctx;
290	int error, flags;
291
292	PJDLOG_ASSERT(tctx != NULL);
293	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
294	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
295	PJDLOG_ASSERT(tctx->tc_fd >= 0);
296	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
297	PJDLOG_ASSERT(timeout >= -1);
298
299	flags = fcntl(tctx->tc_fd, F_GETFL);
300	if (flags == -1) {
301		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
302		return (errno);
303	}
304	/*
305	 * We make socket non-blocking so we can handle connection timeout
306	 * manually.
307	 */
308	flags |= O_NONBLOCK;
309	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
310		pjdlog_common(LOG_DEBUG, 1, errno,
311		    "fcntl(F_SETFL, O_NONBLOCK) failed");
312		return (errno);
313	}
314
315	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
316	    tctx->tc_sa.ss_len) == 0) {
317		if (timeout == -1)
318			return (0);
319		error = 0;
320		goto done;
321	}
322	if (errno != EINPROGRESS) {
323		error = errno;
324		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
325		goto done;
326	}
327	if (timeout == -1)
328		return (0);
329	return (tcp_connect_wait(ctx, timeout));
330done:
331	flags &= ~O_NONBLOCK;
332	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
333		if (error == 0)
334			error = errno;
335		pjdlog_common(LOG_DEBUG, 1, errno,
336		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
337	}
338	return (error);
339}
340
341static int
342tcp_connect_wait(void *ctx, int timeout)
343{
344	struct tcp_ctx *tctx = ctx;
345	struct timeval tv;
346	fd_set fdset;
347	socklen_t esize;
348	int error, flags, ret;
349
350	PJDLOG_ASSERT(tctx != NULL);
351	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
352	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
353	PJDLOG_ASSERT(tctx->tc_fd >= 0);
354	PJDLOG_ASSERT(timeout >= 0);
355
356	tv.tv_sec = timeout;
357	tv.tv_usec = 0;
358again:
359	FD_ZERO(&fdset);
360	FD_SET(tctx->tc_fd, &fdset);
361	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
362	if (ret == 0) {
363		error = ETIMEDOUT;
364		goto done;
365	} else if (ret == -1) {
366		if (errno == EINTR)
367			goto again;
368		error = errno;
369		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
370		goto done;
371	}
372	PJDLOG_ASSERT(ret > 0);
373	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
374	esize = sizeof(error);
375	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
376	    &esize) == -1) {
377		error = errno;
378		pjdlog_common(LOG_DEBUG, 1, errno,
379		    "getsockopt(SO_ERROR) failed");
380		goto done;
381	}
382	if (error != 0) {
383		pjdlog_common(LOG_DEBUG, 1, error,
384		    "getsockopt(SO_ERROR) returned error");
385		goto done;
386	}
387	error = 0;
388done:
389	flags = fcntl(tctx->tc_fd, F_GETFL);
390	if (flags == -1) {
391		if (error == 0)
392			error = errno;
393		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
394		return (error);
395	}
396	flags &= ~O_NONBLOCK;
397	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
398		if (error == 0)
399			error = errno;
400		pjdlog_common(LOG_DEBUG, 1, errno,
401		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
402	}
403	return (error);
404}
405
406static int
407tcp_server(const char *addr, void **ctxp)
408{
409	struct tcp_ctx *tctx;
410	int ret, val;
411
412	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
413	if (ret != 0)
414		return (ret);
415
416	tctx = *ctxp;
417
418	val = 1;
419	/* Ignore failure. */
420	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
421	   sizeof(val));
422
423	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
424
425	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
426	    tctx->tc_sa.ss_len) == -1) {
427		ret = errno;
428		tcp_close(tctx);
429		return (ret);
430	}
431	if (listen(tctx->tc_fd, 8) == -1) {
432		ret = errno;
433		tcp_close(tctx);
434		return (ret);
435	}
436
437	return (0);
438}
439
440static int
441tcp_accept(void *ctx, void **newctxp)
442{
443	struct tcp_ctx *tctx = ctx;
444	struct tcp_ctx *newtctx;
445	socklen_t fromlen;
446	int ret;
447
448	PJDLOG_ASSERT(tctx != NULL);
449	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
450	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
451	PJDLOG_ASSERT(tctx->tc_fd >= 0);
452	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
453
454	newtctx = malloc(sizeof(*newtctx));
455	if (newtctx == NULL)
456		return (errno);
457
458	fromlen = tctx->tc_sa.ss_len;
459	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
460	    &fromlen);
461	if (newtctx->tc_fd == -1) {
462		ret = errno;
463		free(newtctx);
464		return (ret);
465	}
466
467	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
468	newtctx->tc_magic = TCP_CTX_MAGIC;
469	*newctxp = newtctx;
470
471	return (0);
472}
473
474static int
475tcp_wrap(int fd, bool client, void **ctxp)
476{
477
478	return (tcp_setup_wrap(fd,
479	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
480}
481
482static int
483tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
484{
485	struct tcp_ctx *tctx = ctx;
486
487	PJDLOG_ASSERT(tctx != NULL);
488	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
489	PJDLOG_ASSERT(tctx->tc_fd >= 0);
490	PJDLOG_ASSERT(fd == -1);
491
492	return (proto_common_send(tctx->tc_fd, data, size, -1));
493}
494
495static int
496tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
497{
498	struct tcp_ctx *tctx = ctx;
499
500	PJDLOG_ASSERT(tctx != NULL);
501	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
502	PJDLOG_ASSERT(tctx->tc_fd >= 0);
503	PJDLOG_ASSERT(fdp == NULL);
504
505	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
506}
507
508static int
509tcp_descriptor(const void *ctx)
510{
511	const struct tcp_ctx *tctx = ctx;
512
513	PJDLOG_ASSERT(tctx != NULL);
514	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
515
516	return (tctx->tc_fd);
517}
518
519static bool
520tcp_address_match(const void *ctx, const char *addr)
521{
522	const struct tcp_ctx *tctx = ctx;
523	struct sockaddr_storage sa1, sa2;
524	socklen_t salen;
525
526	PJDLOG_ASSERT(tctx != NULL);
527	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
528
529	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
530		return (false);
531
532	salen = sizeof(sa2);
533	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
534		return (false);
535
536	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
537		return (false);
538
539	switch (sa1.ss_family) {
540	case AF_INET:
541	    {
542		struct sockaddr_in *sin1, *sin2;
543
544		sin1 = (struct sockaddr_in *)&sa1;
545		sin2 = (struct sockaddr_in *)&sa2;
546
547		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
548		    sizeof(sin1->sin_addr)) == 0);
549	    }
550	case AF_INET6:
551	    {
552		struct sockaddr_in6 *sin1, *sin2;
553
554		sin1 = (struct sockaddr_in6 *)&sa1;
555		sin2 = (struct sockaddr_in6 *)&sa2;
556
557		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
558		    sizeof(sin1->sin6_addr)) == 0);
559	    }
560	default:
561		return (false);
562	}
563}
564
565static void
566tcp_local_address(const void *ctx, char *addr, size_t size)
567{
568	const struct tcp_ctx *tctx = ctx;
569	struct sockaddr_storage sa;
570	socklen_t salen;
571
572	PJDLOG_ASSERT(tctx != NULL);
573	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
574
575	salen = sizeof(sa);
576	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
577		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
578		return;
579	}
580	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
581}
582
583static void
584tcp_remote_address(const void *ctx, char *addr, size_t size)
585{
586	const struct tcp_ctx *tctx = ctx;
587	struct sockaddr_storage sa;
588	socklen_t salen;
589
590	PJDLOG_ASSERT(tctx != NULL);
591	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
592
593	salen = sizeof(sa);
594	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
595		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
596		return;
597	}
598	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
599}
600
601static void
602tcp_close(void *ctx)
603{
604	struct tcp_ctx *tctx = ctx;
605
606	PJDLOG_ASSERT(tctx != NULL);
607	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
608
609	if (tctx->tc_fd >= 0)
610		close(tctx->tc_fd);
611	tctx->tc_magic = 0;
612	free(tctx);
613}
614
615static struct proto tcp_proto = {
616	.prt_name = "tcp",
617	.prt_client = tcp_client,
618	.prt_connect = tcp_connect,
619	.prt_connect_wait = tcp_connect_wait,
620	.prt_server = tcp_server,
621	.prt_accept = tcp_accept,
622	.prt_wrap = tcp_wrap,
623	.prt_send = tcp_send,
624	.prt_recv = tcp_recv,
625	.prt_descriptor = tcp_descriptor,
626	.prt_address_match = tcp_address_match,
627	.prt_local_address = tcp_local_address,
628	.prt_remote_address = tcp_remote_address,
629	.prt_close = tcp_close
630};
631
632static __constructor void
633tcp_ctor(void)
634{
635
636	proto_register(&tcp_proto, true);
637}
638