proto_tcp.c revision 222118
1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
4 * All rights reserved.
5 *
6 * This software was developed by Pawel Jakub Dawidek under sponsorship from
7 * the FreeBSD Foundation.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 * 1. Redistributions of source code must retain the above copyright
13 *    notice, this list of conditions and the following disclaimer.
14 * 2. Redistributions in binary form must reproduce the above copyright
15 *    notice, this list of conditions and the following disclaimer in the
16 *    documentation and/or other materials provided with the distribution.
17 *
18 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28 * SUCH DAMAGE.
29 */
30
31#include <sys/cdefs.h>
32__FBSDID("$FreeBSD: head/sbin/hastd/proto_tcp.c 222118 2011-05-20 11:14:05Z pjd $");
33
34#include <sys/param.h>	/* MAXHOSTNAMELEN */
35#include <sys/socket.h>
36
37#include <arpa/inet.h>
38
39#include <netinet/in.h>
40#include <netinet/tcp.h>
41
42#include <errno.h>
43#include <fcntl.h>
44#include <netdb.h>
45#include <stdbool.h>
46#include <stdint.h>
47#include <stdio.h>
48#include <string.h>
49#include <unistd.h>
50
51#include "pjdlog.h"
52#include "proto_impl.h"
53#include "subr.h"
54
/* Value held in tc_magic while a context is live; tcp_close() clears it. */
#define	TCP_CTX_MAGIC	0x7c41c

/* Roles a context can play, stored in tc_side. */
#define	TCP_SIDE_CLIENT		0	/* outgoing connection */
#define	TCP_SIDE_SERVER_LISTEN	1	/* listening socket */
#define	TCP_SIDE_SERVER_WORK	2	/* accepted or wrapped server connection */

/* Per-socket state of the TCP protocol backend. */
struct tcp_ctx {
	/* TCP_CTX_MAGIC while the context is valid. */
	int			tc_magic;
	/*
	 * Address to connect to (client) or bind to (listening server);
	 * ss_family is AF_UNSPEC for contexts made from a wrapped descriptor.
	 */
	struct sockaddr_storage	tc_sa;
	/* Socket descriptor. */
	int			tc_fd;
	/* One of the TCP_SIDE_* values. */
	int			tc_side;
};

/* Referenced before their definitions below. */
static int tcp_connect_wait(void *ctx, int timeout);
static void tcp_close(void *ctx);
68
/*
 * Convert a non-empty string of decimal digits to a number in the range
 * [minnum, maxnum].
 *
 * Signs, whitespace and any other non-digit characters are rejected.
 * On success the value is stored in *nump and 0 is returned.  On failure
 * (empty string, non-digit, value that does not fit in intmax_t, or value
 * outside [minnum, maxnum]) *nump is left untouched, errno is set to EINVAL
 * and -1 is returned.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Check that num * 10 + digit fits in intmax_t *before*
		 * computing it: signed overflow is undefined behaviour, so it
		 * cannot be detected after the fact.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
98
/*
 * Parse a TCP address into *sap.
 *
 * Accepted forms (the scheme prefix is optional and case-insensitive):
 *	[tcp://|tcp4://|tcp6://]host[:port]
 *	[tcp://|tcp6://]ipv6-address
 *	[tcp://|tcp6://][ipv6-address][:port]
 * "tcp4://" and "tcp6://" restrict resolution to IPv4 or IPv6.  When no
 * port is given, defport is used.  Only the first address returned by
 * getaddrinfo(3) is stored.
 *
 * Returns 0 on success or an errno value: EINVAL for a NULL address, a bad
 * port or a failed lookup, ENAMETOOLONG for an overlong host part, ENOENT
 * if the lookup produced no result.
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *pp;
	intmax_t port;
	size_t len, size;
	int error;

	/* Report an errno value like every other failure path. */
	if (addr == NULL)
		return (EINVAL);

	bzero(&hints, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	} else {
		/*
		 * Because TCP is the default assume IP or host is given without
		 * prefix.
		 */
	}

	/*
	 * Extract optional port.
	 * There are three cases to consider.
	 * 1. hostname with port, eg. freefall.freebsd.org:8457
	 * 2. IPv4 address with port, eg. 192.168.0.101:8457
	 * 3. IPv6 address with port, eg. [fe80::1]:8457
	 * An address with two or more colons is IPv6.  It can carry a port
	 * only when written in brackets, with the port right after the
	 * closing bracket; "[fe80::1]" (brackets, no port) is accepted too.
	 */
	pp = NULL;
	if (strchr(addr, ':') != strrchr(addr, ':')) {
		if (addr[0] == '[') {
			pp = strrchr(addr, ':');
			/* Last colon is inside the brackets: no port given. */
			if (pp[-1] != ']')
				pp = NULL;
		}
	} else {
		pp = strrchr(addr, ':');
	}
	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else {
		if (numfromstr(pp + 1, 1, 65535, &port) < 0)
			return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);

	/* Extract host name or IP address. */
	len = strlen(addr);
	if (pp == NULL && len >= 2 && addr[0] == '[' && addr[len - 1] == ']') {
		/* Bracketed address without a port: strip the brackets. */
		size = len - 2 + 1;
		if (size > sizeof(iporhost))
			return (ENAMETOOLONG);
		(void)strlcpy(iporhost, addr + 1, size);
	} else if (pp == NULL) {
		size = sizeof(iporhost);
		if (strlcpy(iporhost, addr, size) >= size)
			return (ENAMETOOLONG);
	} else if (addr[0] == '[' && pp[-1] == ']') {
		/* "[host]:port": copy what is between the brackets. */
		size = (size_t)(pp - addr - 2 + 1);
		if (size > sizeof(iporhost))
			return (ENAMETOOLONG);
		(void)strlcpy(iporhost, addr + 1, size);
	} else {
		/* "host:port": copy everything before the colon. */
		size = (size_t)(pp - addr + 1);
		if (size > sizeof(iporhost))
			return (ENAMETOOLONG);
		(void)strlcpy(iporhost, addr, size);
	}

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
190
191static int
192tcp_setup_new(const char *addr, int side, void **ctxp)
193{
194	struct tcp_ctx *tctx;
195	int ret, nodelay;
196
197	PJDLOG_ASSERT(addr != NULL);
198	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
199	    side == TCP_SIDE_SERVER_LISTEN);
200	PJDLOG_ASSERT(ctxp != NULL);
201
202	tctx = malloc(sizeof(*tctx));
203	if (tctx == NULL)
204		return (errno);
205
206	/* Parse given address. */
207	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
208		free(tctx);
209		return (ret);
210	}
211
212	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
213
214	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
215	if (tctx->tc_fd == -1) {
216		ret = errno;
217		free(tctx);
218		return (ret);
219	}
220
221	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
222
223	/* Socket settings. */
224	nodelay = 1;
225	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
226	    sizeof(nodelay)) == -1) {
227		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
228	}
229
230	tctx->tc_side = side;
231	tctx->tc_magic = TCP_CTX_MAGIC;
232	*ctxp = tctx;
233
234	return (0);
235}
236
237static int
238tcp_setup_wrap(int fd, int side, void **ctxp)
239{
240	struct tcp_ctx *tctx;
241
242	PJDLOG_ASSERT(fd >= 0);
243	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
244	    side == TCP_SIDE_SERVER_WORK);
245	PJDLOG_ASSERT(ctxp != NULL);
246
247	tctx = malloc(sizeof(*tctx));
248	if (tctx == NULL)
249		return (errno);
250
251	tctx->tc_fd = fd;
252	tctx->tc_sa.ss_family = AF_UNSPEC;
253	tctx->tc_side = side;
254	tctx->tc_magic = TCP_CTX_MAGIC;
255	*ctxp = tctx;
256
257	return (0);
258}
259
260static int
261tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
262{
263	struct tcp_ctx *tctx;
264	struct sockaddr_storage sa;
265	int ret;
266
267	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
268	if (ret != 0)
269		return (ret);
270	tctx = *ctxp;
271	if (srcaddr == NULL)
272		return (0);
273	ret = tcp_addr(srcaddr, 0, &sa);
274	if (ret != 0) {
275		tcp_close(tctx);
276		return (ret);
277	}
278	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) < 0) {
279		ret = errno;
280		tcp_close(tctx);
281		return (ret);
282	}
283	return (0);
284}
285
286static int
287tcp_connect(void *ctx, int timeout)
288{
289	struct tcp_ctx *tctx = ctx;
290	int error, flags;
291
292	PJDLOG_ASSERT(tctx != NULL);
293	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
294	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
295	PJDLOG_ASSERT(tctx->tc_fd >= 0);
296	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
297	PJDLOG_ASSERT(timeout >= -1);
298
299	flags = fcntl(tctx->tc_fd, F_GETFL);
300	if (flags == -1) {
301		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
302		    "fcntl(F_GETFL) failed"));
303		return (errno);
304	}
305	/*
306	 * We make socket non-blocking so we can handle connection timeout
307	 * manually.
308	 */
309	flags |= O_NONBLOCK;
310	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
311		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
312		    "fcntl(F_SETFL, O_NONBLOCK) failed"));
313		return (errno);
314	}
315
316	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
317	    tctx->tc_sa.ss_len) == 0) {
318		if (timeout == -1)
319			return (0);
320		error = 0;
321		goto done;
322	}
323	if (errno != EINPROGRESS) {
324		error = errno;
325		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
326		goto done;
327	}
328	if (timeout == -1)
329		return (0);
330	return (tcp_connect_wait(ctx, timeout));
331done:
332	flags &= ~O_NONBLOCK;
333	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
334		if (error == 0)
335			error = errno;
336		pjdlog_common(LOG_DEBUG, 1, errno,
337		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
338	}
339	return (error);
340}
341
342static int
343tcp_connect_wait(void *ctx, int timeout)
344{
345	struct tcp_ctx *tctx = ctx;
346	struct timeval tv;
347	fd_set fdset;
348	socklen_t esize;
349	int error, flags, ret;
350
351	PJDLOG_ASSERT(tctx != NULL);
352	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
353	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
354	PJDLOG_ASSERT(tctx->tc_fd >= 0);
355	PJDLOG_ASSERT(timeout >= 0);
356
357	tv.tv_sec = timeout;
358	tv.tv_usec = 0;
359again:
360	FD_ZERO(&fdset);
361	FD_SET(tctx->tc_fd, &fdset);
362	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
363	if (ret == 0) {
364		error = ETIMEDOUT;
365		goto done;
366	} else if (ret == -1) {
367		if (errno == EINTR)
368			goto again;
369		error = errno;
370		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
371		goto done;
372	}
373	PJDLOG_ASSERT(ret > 0);
374	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
375	esize = sizeof(error);
376	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
377	    &esize) == -1) {
378		error = errno;
379		pjdlog_common(LOG_DEBUG, 1, errno,
380		    "getsockopt(SO_ERROR) failed");
381		goto done;
382	}
383	if (error != 0) {
384		pjdlog_common(LOG_DEBUG, 1, error,
385		    "getsockopt(SO_ERROR) returned error");
386		goto done;
387	}
388	error = 0;
389done:
390	flags = fcntl(tctx->tc_fd, F_GETFL);
391	if (flags == -1) {
392		if (error == 0)
393			error = errno;
394		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
395		return (error);
396	}
397	flags &= ~O_NONBLOCK;
398	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
399		if (error == 0)
400			error = errno;
401		pjdlog_common(LOG_DEBUG, 1, errno,
402		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
403	}
404	return (error);
405}
406
407static int
408tcp_server(const char *addr, void **ctxp)
409{
410	struct tcp_ctx *tctx;
411	int ret, val;
412
413	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
414	if (ret != 0)
415		return (ret);
416
417	tctx = *ctxp;
418
419	val = 1;
420	/* Ignore failure. */
421	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
422	   sizeof(val));
423
424	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
425
426	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
427	    tctx->tc_sa.ss_len) < 0) {
428		ret = errno;
429		tcp_close(tctx);
430		return (ret);
431	}
432	if (listen(tctx->tc_fd, 8) < 0) {
433		ret = errno;
434		tcp_close(tctx);
435		return (ret);
436	}
437
438	return (0);
439}
440
441static int
442tcp_accept(void *ctx, void **newctxp)
443{
444	struct tcp_ctx *tctx = ctx;
445	struct tcp_ctx *newtctx;
446	socklen_t fromlen;
447	int ret;
448
449	PJDLOG_ASSERT(tctx != NULL);
450	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
451	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
452	PJDLOG_ASSERT(tctx->tc_fd >= 0);
453	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
454
455	newtctx = malloc(sizeof(*newtctx));
456	if (newtctx == NULL)
457		return (errno);
458
459	fromlen = tctx->tc_sa.ss_len;
460	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
461	    &fromlen);
462	if (newtctx->tc_fd < 0) {
463		ret = errno;
464		free(newtctx);
465		return (ret);
466	}
467
468	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
469	newtctx->tc_magic = TCP_CTX_MAGIC;
470	*newctxp = newtctx;
471
472	return (0);
473}
474
475static int
476tcp_wrap(int fd, bool client, void **ctxp)
477{
478
479	return (tcp_setup_wrap(fd,
480	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
481}
482
483static int
484tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
485{
486	struct tcp_ctx *tctx = ctx;
487
488	PJDLOG_ASSERT(tctx != NULL);
489	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
490	PJDLOG_ASSERT(tctx->tc_fd >= 0);
491	PJDLOG_ASSERT(fd == -1);
492
493	return (proto_common_send(tctx->tc_fd, data, size, -1));
494}
495
496static int
497tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
498{
499	struct tcp_ctx *tctx = ctx;
500
501	PJDLOG_ASSERT(tctx != NULL);
502	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
503	PJDLOG_ASSERT(tctx->tc_fd >= 0);
504	PJDLOG_ASSERT(fdp == NULL);
505
506	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
507}
508
509static int
510tcp_descriptor(const void *ctx)
511{
512	const struct tcp_ctx *tctx = ctx;
513
514	PJDLOG_ASSERT(tctx != NULL);
515	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
516
517	return (tctx->tc_fd);
518}
519
520static bool
521tcp_address_match(const void *ctx, const char *addr)
522{
523	const struct tcp_ctx *tctx = ctx;
524	struct sockaddr_storage sa1, sa2;
525	socklen_t salen;
526
527	PJDLOG_ASSERT(tctx != NULL);
528	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
529
530	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
531		return (false);
532
533	salen = sizeof(sa2);
534	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) < 0)
535		return (false);
536
537	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
538		return (false);
539
540	switch (sa1.ss_family) {
541	case AF_INET:
542	    {
543		struct sockaddr_in *sin1, *sin2;
544
545		sin1 = (struct sockaddr_in *)&sa1;
546		sin2 = (struct sockaddr_in *)&sa2;
547
548		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
549		    sizeof(sin1->sin_addr)) == 0);
550	    }
551	case AF_INET6:
552	    {
553		struct sockaddr_in6 *sin1, *sin2;
554
555		sin1 = (struct sockaddr_in6 *)&sa1;
556		sin2 = (struct sockaddr_in6 *)&sa2;
557
558		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
559		    sizeof(sin1->sin6_addr)) == 0);
560	    }
561	default:
562		return (false);
563	}
564}
565
566static void
567tcp_local_address(const void *ctx, char *addr, size_t size)
568{
569	const struct tcp_ctx *tctx = ctx;
570	struct sockaddr_storage sa;
571	socklen_t salen;
572
573	PJDLOG_ASSERT(tctx != NULL);
574	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
575
576	salen = sizeof(sa);
577	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
578		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
579		return;
580	}
581	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
582}
583
584static void
585tcp_remote_address(const void *ctx, char *addr, size_t size)
586{
587	const struct tcp_ctx *tctx = ctx;
588	struct sockaddr_storage sa;
589	socklen_t salen;
590
591	PJDLOG_ASSERT(tctx != NULL);
592	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
593
594	salen = sizeof(sa);
595	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
596		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
597		return;
598	}
599	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
600}
601
602static void
603tcp_close(void *ctx)
604{
605	struct tcp_ctx *tctx = ctx;
606
607	PJDLOG_ASSERT(tctx != NULL);
608	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
609
610	if (tctx->tc_fd >= 0)
611		close(tctx->tc_fd);
612	tctx->tc_magic = 0;
613	free(tctx);
614}
615
616static struct proto tcp_proto = {
617	.prt_name = "tcp",
618	.prt_client = tcp_client,
619	.prt_connect = tcp_connect,
620	.prt_connect_wait = tcp_connect_wait,
621	.prt_server = tcp_server,
622	.prt_accept = tcp_accept,
623	.prt_wrap = tcp_wrap,
624	.prt_send = tcp_send,
625	.prt_recv = tcp_recv,
626	.prt_descriptor = tcp_descriptor,
627	.prt_address_match = tcp_address_match,
628	.prt_local_address = tcp_local_address,
629	.prt_remote_address = tcp_remote_address,
630	.prt_close = tcp_close
631};
632
633static __constructor void
634tcp_ctor(void)
635{
636
637	proto_register(&tcp_proto, true);
638}
639