proto_tcp.c revision 219873
1/*-
2 * Copyright (c) 2009-2010 The FreeBSD Foundation
3 * All rights reserved.
4 *
5 * This software was developed by Pawel Jakub Dawidek under sponsorship from
6 * the FreeBSD Foundation.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 * 1. Redistributions of source code must retain the above copyright
12 *    notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27 * SUCH DAMAGE.
28 */
29
30#include <sys/cdefs.h>
31__FBSDID("$FreeBSD: head/sbin/hastd/proto_tcp4.c 219873 2011-03-22 16:21:11Z pjd $");
32
33#include <sys/param.h>	/* MAXHOSTNAMELEN */
34#include <sys/socket.h>
35
36#include <arpa/inet.h>
37
38#include <netinet/in.h>
39#include <netinet/tcp.h>
40
41#include <errno.h>
42#include <fcntl.h>
43#include <netdb.h>
44#include <stdbool.h>
45#include <stdint.h>
46#include <stdio.h>
47#include <string.h>
48#include <unistd.h>
49
50#include "pjdlog.h"
51#include "proto_impl.h"
52#include "subr.h"
53
/* Marker stored in tc_magic of every initialized context. */
#define	TCP4_CTX_MAGIC	0x7c441c

/* Values for tc_side: the role a context plays. */
#define	TCP4_SIDE_CLIENT	0
#define	TCP4_SIDE_SERVER_LISTEN	1
#define	TCP4_SIDE_SERVER_WORK	2

/*
 * State kept for one TCP/IPv4 endpoint.
 */
struct tcp4_ctx {
	int			tc_magic;	/* TCP4_CTX_MAGIC while valid. */
	struct sockaddr_in	tc_sin;		/* Parsed address, if any. */
	int			tc_fd;		/* Socket descriptor. */
	int			tc_side;	/* One of TCP4_SIDE_*. */
};
64
65static int tcp4_connect_wait(void *ctx, int timeout);
66static void tcp4_close(void *ctx);
67
/*
 * Convert a dotted-quad IPv4 address or a host name to an in_addr_t in
 * network byte order.
 *
 * The string is first tried as a numeric address; if that fails it is
 * resolved as a host name and the first address of the result is used.
 * Returns INADDR_NONE on failure.  Because INADDR_NONE doubles as the
 * error value, an input that inet_addr() maps to INADDR_NONE can never be
 * returned as a valid address by the numeric path.
 */
static in_addr_t
str2ip(const char *str)
{
	struct hostent *hp;
	struct in_addr in;
	in_addr_t ip;

	ip = inet_addr(str);
	if (ip != INADDR_NONE) {
		/* It is a valid IP address. */
		return (ip);
	}
	/* Check if it is a valid host name. */
	hp = gethostbyname(str);
	if (hp == NULL)
		return (INADDR_NONE);
	/*
	 * Only accept a result that really is an IPv4 address of the
	 * expected size; anything else cannot be stored in an in_addr_t.
	 */
	if (hp->h_addrtype != AF_INET ||
	    hp->h_length != (int)sizeof(in) ||
	    hp->h_addr_list[0] == NULL) {
		return (INADDR_NONE);
	}
	/* Copy instead of casting: the buffer alignment is not guaranteed. */
	memcpy(&in, hp->h_addr_list[0], sizeof(in));
	return (in.s_addr);
}
85
/*
 * Convert a string of decimal digits to a non-negative number.
 *
 * The string must be non-empty and consist only of the characters '0'-'9'
 * (no sign, no white space).  The resulting value must lie in the range
 * [minnum, maxnum].
 *
 * On success the value is stored in *nump and 0 is returned.  On failure
 * -1 is returned, errno is set to EINVAL and *nump is left untouched.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Detect overflow before it happens; signed overflow is
		 * undefined behaviour, so it cannot be tested after the fact.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
115
116static int
117tcp4_addr(const char *addr, int defport, struct sockaddr_in *sinp)
118{
119	char iporhost[MAXHOSTNAMELEN];
120	const char *pp;
121	size_t size;
122	in_addr_t ip;
123
124	if (addr == NULL)
125		return (-1);
126
127	if (strncasecmp(addr, "tcp4://", 7) == 0)
128		addr += 7;
129	else if (strncasecmp(addr, "tcp://", 6) == 0)
130		addr += 6;
131	else {
132		/*
133		 * Because TCP4 is the default assume IP or host is given without
134		 * prefix.
135		 */
136	}
137
138	sinp->sin_family = AF_INET;
139	sinp->sin_len = sizeof(*sinp);
140	/* Extract optional port. */
141	pp = strrchr(addr, ':');
142	if (pp == NULL) {
143		/* Port not given, use the default. */
144		sinp->sin_port = htons(defport);
145	} else {
146		intmax_t port;
147
148		if (numfromstr(pp + 1, 1, 65535, &port) < 0)
149			return (errno);
150		sinp->sin_port = htons(port);
151	}
152	/* Extract host name or IP address. */
153	if (pp == NULL) {
154		size = sizeof(iporhost);
155		if (strlcpy(iporhost, addr, size) >= size)
156			return (ENAMETOOLONG);
157	} else {
158		size = (size_t)(pp - addr + 1);
159		if (size > sizeof(iporhost))
160			return (ENAMETOOLONG);
161		(void)strlcpy(iporhost, addr, size);
162	}
163	/* Convert string (IP address or host name) to in_addr_t. */
164	ip = str2ip(iporhost);
165	if (ip == INADDR_NONE)
166		return (EINVAL);
167	sinp->sin_addr.s_addr = ip;
168
169	return (0);
170}
171
172static int
173tcp4_setup_new(const char *addr, int side, void **ctxp)
174{
175	struct tcp4_ctx *tctx;
176	int ret, nodelay;
177
178	PJDLOG_ASSERT(addr != NULL);
179	PJDLOG_ASSERT(side == TCP4_SIDE_CLIENT ||
180	    side == TCP4_SIDE_SERVER_LISTEN);
181	PJDLOG_ASSERT(ctxp != NULL);
182
183	tctx = malloc(sizeof(*tctx));
184	if (tctx == NULL)
185		return (errno);
186
187	/* Parse given address. */
188	if ((ret = tcp4_addr(addr, PROTO_TCP4_DEFAULT_PORT,
189	    &tctx->tc_sin)) != 0) {
190		free(tctx);
191		return (ret);
192	}
193
194	PJDLOG_ASSERT(tctx->tc_sin.sin_family != AF_UNSPEC);
195
196	tctx->tc_fd = socket(AF_INET, SOCK_STREAM, 0);
197	if (tctx->tc_fd == -1) {
198		ret = errno;
199		free(tctx);
200		return (ret);
201	}
202
203	PJDLOG_ASSERT(tctx->tc_sin.sin_family != AF_UNSPEC);
204
205	/* Socket settings. */
206	nodelay = 1;
207	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
208	    sizeof(nodelay)) == -1) {
209		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
210	}
211
212	tctx->tc_side = side;
213	tctx->tc_magic = TCP4_CTX_MAGIC;
214	*ctxp = tctx;
215
216	return (0);
217}
218
219static int
220tcp4_setup_wrap(int fd, int side, void **ctxp)
221{
222	struct tcp4_ctx *tctx;
223
224	PJDLOG_ASSERT(fd >= 0);
225	PJDLOG_ASSERT(side == TCP4_SIDE_CLIENT ||
226	    side == TCP4_SIDE_SERVER_WORK);
227	PJDLOG_ASSERT(ctxp != NULL);
228
229	tctx = malloc(sizeof(*tctx));
230	if (tctx == NULL)
231		return (errno);
232
233	tctx->tc_fd = fd;
234	tctx->tc_sin.sin_family = AF_UNSPEC;
235	tctx->tc_side = side;
236	tctx->tc_magic = TCP4_CTX_MAGIC;
237	*ctxp = tctx;
238
239	return (0);
240}
241
242static int
243tcp4_client(const char *srcaddr, const char *dstaddr, void **ctxp)
244{
245	struct tcp4_ctx *tctx;
246	struct sockaddr_in sin;
247	int ret;
248
249	ret = tcp4_setup_new(dstaddr, TCP4_SIDE_CLIENT, ctxp);
250	if (ret != 0)
251		return (ret);
252	tctx = *ctxp;
253	if (srcaddr == NULL)
254		return (0);
255	ret = tcp4_addr(srcaddr, 0, &sin);
256	if (ret != 0) {
257		tcp4_close(tctx);
258		return (ret);
259	}
260	if (bind(tctx->tc_fd, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
261		ret = errno;
262		tcp4_close(tctx);
263		return (ret);
264	}
265	return (0);
266}
267
268static int
269tcp4_connect(void *ctx, int timeout)
270{
271	struct tcp4_ctx *tctx = ctx;
272	int error, flags;
273
274	PJDLOG_ASSERT(tctx != NULL);
275	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
276	PJDLOG_ASSERT(tctx->tc_side == TCP4_SIDE_CLIENT);
277	PJDLOG_ASSERT(tctx->tc_fd >= 0);
278	PJDLOG_ASSERT(tctx->tc_sin.sin_family != AF_UNSPEC);
279	PJDLOG_ASSERT(timeout >= -1);
280
281	flags = fcntl(tctx->tc_fd, F_GETFL);
282	if (flags == -1) {
283		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
284		    "fcntl(F_GETFL) failed"));
285		return (errno);
286	}
287	/*
288	 * We make socket non-blocking so we can handle connection timeout
289	 * manually.
290	 */
291	flags |= O_NONBLOCK;
292	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
293		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
294		    "fcntl(F_SETFL, O_NONBLOCK) failed"));
295		return (errno);
296	}
297
298	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
299	    sizeof(tctx->tc_sin)) == 0) {
300		if (timeout == -1)
301			return (0);
302		error = 0;
303		goto done;
304	}
305	if (errno != EINPROGRESS) {
306		error = errno;
307		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
308		goto done;
309	}
310	if (timeout == -1)
311		return (0);
312	return (tcp4_connect_wait(ctx, timeout));
313done:
314	flags &= ~O_NONBLOCK;
315	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
316		if (error == 0)
317			error = errno;
318		pjdlog_common(LOG_DEBUG, 1, errno,
319		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
320	}
321	return (error);
322}
323
324static int
325tcp4_connect_wait(void *ctx, int timeout)
326{
327	struct tcp4_ctx *tctx = ctx;
328	struct timeval tv;
329	fd_set fdset;
330	socklen_t esize;
331	int error, flags, ret;
332
333	PJDLOG_ASSERT(tctx != NULL);
334	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
335	PJDLOG_ASSERT(tctx->tc_side == TCP4_SIDE_CLIENT);
336	PJDLOG_ASSERT(tctx->tc_fd >= 0);
337	PJDLOG_ASSERT(timeout >= 0);
338
339	tv.tv_sec = timeout;
340	tv.tv_usec = 0;
341again:
342	FD_ZERO(&fdset);
343	FD_SET(tctx->tc_fd, &fdset);
344	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
345	if (ret == 0) {
346		error = ETIMEDOUT;
347		goto done;
348	} else if (ret == -1) {
349		if (errno == EINTR)
350			goto again;
351		error = errno;
352		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
353		goto done;
354	}
355	PJDLOG_ASSERT(ret > 0);
356	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
357	esize = sizeof(error);
358	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
359	    &esize) == -1) {
360		error = errno;
361		pjdlog_common(LOG_DEBUG, 1, errno,
362		    "getsockopt(SO_ERROR) failed");
363		goto done;
364	}
365	if (error != 0) {
366		pjdlog_common(LOG_DEBUG, 1, error,
367		    "getsockopt(SO_ERROR) returned error");
368		goto done;
369	}
370	error = 0;
371done:
372	flags = fcntl(tctx->tc_fd, F_GETFL);
373	if (flags == -1) {
374		if (error == 0)
375			error = errno;
376		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
377		return (error);
378	}
379	flags &= ~O_NONBLOCK;
380	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
381		if (error == 0)
382			error = errno;
383		pjdlog_common(LOG_DEBUG, 1, errno,
384		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
385	}
386	return (error);
387}
388
389static int
390tcp4_server(const char *addr, void **ctxp)
391{
392	struct tcp4_ctx *tctx;
393	int ret, val;
394
395	ret = tcp4_setup_new(addr, TCP4_SIDE_SERVER_LISTEN, ctxp);
396	if (ret != 0)
397		return (ret);
398
399	tctx = *ctxp;
400
401	val = 1;
402	/* Ignore failure. */
403	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
404	   sizeof(val));
405
406	PJDLOG_ASSERT(tctx->tc_sin.sin_family != AF_UNSPEC);
407
408	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
409	    sizeof(tctx->tc_sin)) < 0) {
410		ret = errno;
411		tcp4_close(tctx);
412		return (ret);
413	}
414	if (listen(tctx->tc_fd, 8) < 0) {
415		ret = errno;
416		tcp4_close(tctx);
417		return (ret);
418	}
419
420	return (0);
421}
422
423static int
424tcp4_accept(void *ctx, void **newctxp)
425{
426	struct tcp4_ctx *tctx = ctx;
427	struct tcp4_ctx *newtctx;
428	socklen_t fromlen;
429	int ret;
430
431	PJDLOG_ASSERT(tctx != NULL);
432	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
433	PJDLOG_ASSERT(tctx->tc_side == TCP4_SIDE_SERVER_LISTEN);
434	PJDLOG_ASSERT(tctx->tc_fd >= 0);
435	PJDLOG_ASSERT(tctx->tc_sin.sin_family != AF_UNSPEC);
436
437	newtctx = malloc(sizeof(*newtctx));
438	if (newtctx == NULL)
439		return (errno);
440
441	fromlen = sizeof(tctx->tc_sin);
442	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
443	    &fromlen);
444	if (newtctx->tc_fd < 0) {
445		ret = errno;
446		free(newtctx);
447		return (ret);
448	}
449
450	newtctx->tc_side = TCP4_SIDE_SERVER_WORK;
451	newtctx->tc_magic = TCP4_CTX_MAGIC;
452	*newctxp = newtctx;
453
454	return (0);
455}
456
457static int
458tcp4_wrap(int fd, bool client, void **ctxp)
459{
460
461	return (tcp4_setup_wrap(fd,
462	    client ? TCP4_SIDE_CLIENT : TCP4_SIDE_SERVER_WORK, ctxp));
463}
464
465static int
466tcp4_send(void *ctx, const unsigned char *data, size_t size, int fd)
467{
468	struct tcp4_ctx *tctx = ctx;
469
470	PJDLOG_ASSERT(tctx != NULL);
471	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
472	PJDLOG_ASSERT(tctx->tc_fd >= 0);
473	PJDLOG_ASSERT(fd == -1);
474
475	return (proto_common_send(tctx->tc_fd, data, size, -1));
476}
477
478static int
479tcp4_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
480{
481	struct tcp4_ctx *tctx = ctx;
482
483	PJDLOG_ASSERT(tctx != NULL);
484	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
485	PJDLOG_ASSERT(tctx->tc_fd >= 0);
486	PJDLOG_ASSERT(fdp == NULL);
487
488	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
489}
490
491static int
492tcp4_descriptor(const void *ctx)
493{
494	const struct tcp4_ctx *tctx = ctx;
495
496	PJDLOG_ASSERT(tctx != NULL);
497	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
498
499	return (tctx->tc_fd);
500}
501
502static bool
503tcp4_address_match(const void *ctx, const char *addr)
504{
505	const struct tcp4_ctx *tctx = ctx;
506	struct sockaddr_in sin;
507	socklen_t sinlen;
508	in_addr_t ip1, ip2;
509
510	PJDLOG_ASSERT(tctx != NULL);
511	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
512
513	if (tcp4_addr(addr, PROTO_TCP4_DEFAULT_PORT, &sin) != 0)
514		return (false);
515	ip1 = sin.sin_addr.s_addr;
516
517	sinlen = sizeof(sin);
518	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0)
519		return (false);
520	ip2 = sin.sin_addr.s_addr;
521
522	return (ip1 == ip2);
523}
524
525static void
526tcp4_local_address(const void *ctx, char *addr, size_t size)
527{
528	const struct tcp4_ctx *tctx = ctx;
529	struct sockaddr_in sin;
530	socklen_t sinlen;
531
532	PJDLOG_ASSERT(tctx != NULL);
533	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
534
535	sinlen = sizeof(sin);
536	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0) {
537		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
538		return;
539	}
540	PJDLOG_VERIFY(snprintf(addr, size, "tcp4://%S", &sin) < (ssize_t)size);
541}
542
543static void
544tcp4_remote_address(const void *ctx, char *addr, size_t size)
545{
546	const struct tcp4_ctx *tctx = ctx;
547	struct sockaddr_in sin;
548	socklen_t sinlen;
549
550	PJDLOG_ASSERT(tctx != NULL);
551	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
552
553	sinlen = sizeof(sin);
554	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0) {
555		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
556		return;
557	}
558	PJDLOG_VERIFY(snprintf(addr, size, "tcp4://%S", &sin) < (ssize_t)size);
559}
560
561static void
562tcp4_close(void *ctx)
563{
564	struct tcp4_ctx *tctx = ctx;
565
566	PJDLOG_ASSERT(tctx != NULL);
567	PJDLOG_ASSERT(tctx->tc_magic == TCP4_CTX_MAGIC);
568
569	if (tctx->tc_fd >= 0)
570		close(tctx->tc_fd);
571	tctx->tc_magic = 0;
572	free(tctx);
573}
574
575static struct proto tcp4_proto = {
576	.prt_name = "tcp4",
577	.prt_client = tcp4_client,
578	.prt_connect = tcp4_connect,
579	.prt_connect_wait = tcp4_connect_wait,
580	.prt_server = tcp4_server,
581	.prt_accept = tcp4_accept,
582	.prt_wrap = tcp4_wrap,
583	.prt_send = tcp4_send,
584	.prt_recv = tcp4_recv,
585	.prt_descriptor = tcp4_descriptor,
586	.prt_address_match = tcp4_address_match,
587	.prt_local_address = tcp4_local_address,
588	.prt_remote_address = tcp4_remote_address,
589	.prt_close = tcp4_close
590};
591
592static __constructor void
593tcp4_ctor(void)
594{
595
596	proto_register(&tcp4_proto, true);
597}
598