proto_tcp.c revision 222118
1204076Spjd/*-
2204076Spjd * Copyright (c) 2009-2010 The FreeBSD Foundation
3222118Spjd * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
4204076Spjd * All rights reserved.
5204076Spjd *
6204076Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
7204076Spjd * the FreeBSD Foundation.
8204076Spjd *
9204076Spjd * Redistribution and use in source and binary forms, with or without
10204076Spjd * modification, are permitted provided that the following conditions
11204076Spjd * are met:
12204076Spjd * 1. Redistributions of source code must retain the above copyright
13204076Spjd *    notice, this list of conditions and the following disclaimer.
14204076Spjd * 2. Redistributions in binary form must reproduce the above copyright
15204076Spjd *    notice, this list of conditions and the following disclaimer in the
16204076Spjd *    documentation and/or other materials provided with the distribution.
17204076Spjd *
18204076Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19204076Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20204076Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21204076Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22204076Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23204076Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24204076Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25204076Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26204076Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27204076Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28204076Spjd * SUCH DAMAGE.
29204076Spjd */
30204076Spjd
31204076Spjd#include <sys/cdefs.h>
32204076Spjd__FBSDID("$FreeBSD: head/sbin/hastd/proto_tcp.c 222118 2011-05-20 11:14:05Z pjd $");
33204076Spjd
34204076Spjd#include <sys/param.h>	/* MAXHOSTNAMELEN */
35219873Spjd#include <sys/socket.h>
36204076Spjd
37219873Spjd#include <arpa/inet.h>
38219873Spjd
39204076Spjd#include <netinet/in.h>
40204076Spjd#include <netinet/tcp.h>
41204076Spjd
42204076Spjd#include <errno.h>
43207390Spjd#include <fcntl.h>
44204076Spjd#include <netdb.h>
45204076Spjd#include <stdbool.h>
46204076Spjd#include <stdint.h>
47204076Spjd#include <stdio.h>
48204076Spjd#include <string.h>
49204076Spjd#include <unistd.h>
50204076Spjd
51204076Spjd#include "pjdlog.h"
52204076Spjd#include "proto_impl.h"
53207390Spjd#include "subr.h"
54204076Spjd
/* Stored in tc_magic of every live context; cleared by tcp_close(). */
#define	TCP_CTX_MAGIC	0x7c41c

/*
 * Per-socket state of the TCP protocol module.
 */
struct tcp_ctx {
	int			tc_magic;	/* TCP_CTX_MAGIC while valid. */
	struct sockaddr_storage	tc_sa;		/* Address from tcp_addr(); */
						/* AF_UNSPEC when wrapped. */
	int			tc_fd;		/* Socket descriptor. */
	int			tc_side;	/* One of TCP_SIDE_* below. */
};

/* Values of tc_side. */
#define	TCP_SIDE_CLIENT		0
#define	TCP_SIDE_SERVER_LISTEN	1
#define	TCP_SIDE_SERVER_WORK	2

/* Forward declarations for functions used before their definitions. */
static int tcp_connect_wait(void *ctx, int timeout);
static void tcp_close(void *ctx);
68204076Spjd
/*
 * Convert a string of decimal digits to a number in [minnum, maxnum] and
 * store it in *nump.
 *
 * Only the characters '0'-'9' are accepted: no sign, whitespace or base
 * prefix.  Returns 0 on success.  On failure (empty string, non-digit
 * character, overflow, or value out of range) returns -1 with errno set to
 * EINVAL and leaves *nump untouched.
 */
static int
numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
{
	intmax_t digit, num;

	if (str[0] == '\0')
		goto invalid;	/* Empty string. */
	num = 0;
	for (; *str != '\0'; str++) {
		if (*str < '0' || *str > '9')
			goto invalid;	/* Non-digit character. */
		digit = *str - '0';
		/*
		 * Detect overflow before computing num * 10 + digit: signed
		 * overflow is undefined behaviour, so it cannot be reliably
		 * detected after the fact.
		 */
		if (num > (INTMAX_MAX - digit) / 10)
			goto invalid;	/* Overflow. */
		num = num * 10 + digit;
		if (num > maxnum)
			goto invalid;	/* Too big. */
	}
	if (num < minnum)
		goto invalid;	/* Too small. */
	*nump = num;
	return (0);
invalid:
	errno = EINVAL;
	return (-1);
}
98204076Spjd
/*
 * Parse a TCP address of the form [tcp://|tcp4://|tcp6://]HOST[:PORT] and
 * resolve it into *sap.
 *
 * HOST may be a host name, an IPv4 address or an IPv6 address.  An IPv6
 * address may be written bare (fe80::1), in which case no port can be given,
 * or in brackets, with or without a port ([fe80::1] or [fe80::1]:8457).
 * When no port is given, defport is used.  The tcp4:// and tcp6:// prefixes
 * restrict resolution to IPv4 or IPv6 respectively.
 *
 * Returns 0 on success, -1 if addr is NULL, or an errno value:
 * EINVAL for a malformed port, malformed brackets or a resolution failure;
 * ENAMETOOLONG if HOST is MAXHOSTNAMELEN characters or longer;
 * ENOENT if resolution produced no address.
 */
static int
tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
{
	char iporhost[MAXHOSTNAMELEN], portstr[6];
	struct addrinfo hints;
	struct addrinfo *res;
	const char *host, *pp, *rbracket;
	intmax_t port;
	size_t hostlen;
	int error;

	if (addr == NULL)
		return (-1);

	bzero(&hints, sizeof(hints));
	hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;

	if (strncasecmp(addr, "tcp4://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET;
	} else if (strncasecmp(addr, "tcp6://", 7) == 0) {
		addr += 7;
		hints.ai_family = PF_INET6;
	} else if (strncasecmp(addr, "tcp://", 6) == 0) {
		addr += 6;
	} else {
		/*
		 * Because TCP is the default assume IP or host is given without
		 * prefix.
		 */
	}

	/*
	 * Split the address into host and optional port.
	 * There are four cases to consider:
	 * 1. host name with port, eg. freefall.freebsd.org:8457
	 * 2. IPv4 address with port, eg. 192.168.0.101:8457
	 * 3. bracketed IPv6 address, with or without port,
	 *    eg. [fe80::1]:8457 or [fe80::1]
	 * 4. bare IPv6 address, eg. fe80::1 - it has more than one colon,
	 *    so a port cannot be given.
	 */
	host = addr;
	hostlen = strlen(addr);
	pp = NULL;
	if (addr[0] == '[') {
		rbracket = strchr(addr, ']');
		if (rbracket == NULL)
			return (EINVAL);	/* Unterminated bracket. */
		if (rbracket[1] == ':')
			pp = rbracket + 1;
		else if (rbracket[1] != '\0')
			return (EINVAL);	/* Garbage after bracket. */
		host = addr + 1;
		hostlen = (size_t)(rbracket - host);
	} else if (strchr(addr, ':') == strrchr(addr, ':')) {
		/* At most one colon: whatever follows it is the port. */
		pp = strchr(addr, ':');
		if (pp != NULL)
			hostlen = (size_t)(pp - addr);
	}
	if (pp == NULL) {
		/* Port not given, use the default. */
		port = defport;
	} else {
		if (numfromstr(pp + 1, 1, 65535, &port) < 0)
			return (errno);
	}
	(void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);

	/* Extract host name or IP address. */
	if (hostlen >= sizeof(iporhost))
		return (ENAMETOOLONG);
	memcpy(iporhost, host, hostlen);
	iporhost[hostlen] = '\0';

	error = getaddrinfo(iporhost, portstr, &hints, &res);
	if (error != 0) {
		pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
		    portstr, gai_strerror(error));
		return (EINVAL);
	}
	if (res == NULL)
		return (ENOENT);

	/* Never copy more than the caller's sockaddr_storage can hold. */
	if (res->ai_addrlen > sizeof(*sap)) {
		freeaddrinfo(res);
		return (EINVAL);
	}
	memcpy(sap, res->ai_addr, res->ai_addrlen);

	freeaddrinfo(res);

	return (0);
}
190204076Spjd
191204076Spjdstatic int
192222116Spjdtcp_setup_new(const char *addr, int side, void **ctxp)
193204076Spjd{
194222116Spjd	struct tcp_ctx *tctx;
195218158Spjd	int ret, nodelay;
196204076Spjd
197218194Spjd	PJDLOG_ASSERT(addr != NULL);
198222116Spjd	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
199222116Spjd	    side == TCP_SIDE_SERVER_LISTEN);
200218194Spjd	PJDLOG_ASSERT(ctxp != NULL);
201218194Spjd
202204076Spjd	tctx = malloc(sizeof(*tctx));
203204076Spjd	if (tctx == NULL)
204204076Spjd		return (errno);
205204076Spjd
206204076Spjd	/* Parse given address. */
207222118Spjd	if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
208204076Spjd		free(tctx);
209204076Spjd		return (ret);
210204076Spjd	}
211204076Spjd
212222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
213218194Spjd
214222118Spjd	tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
215204076Spjd	if (tctx->tc_fd == -1) {
216204076Spjd		ret = errno;
217204076Spjd		free(tctx);
218204076Spjd		return (ret);
219204076Spjd	}
220204076Spjd
221222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
222219818Spjd
223204076Spjd	/* Socket settings. */
224218158Spjd	nodelay = 1;
225218158Spjd	if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
226218158Spjd	    sizeof(nodelay)) == -1) {
227218194Spjd		pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
228204076Spjd	}
229204076Spjd
230204076Spjd	tctx->tc_side = side;
231222116Spjd	tctx->tc_magic = TCP_CTX_MAGIC;
232204076Spjd	*ctxp = tctx;
233204076Spjd
234204076Spjd	return (0);
235204076Spjd}
236204076Spjd
237204076Spjdstatic int
238222116Spjdtcp_setup_wrap(int fd, int side, void **ctxp)
239218194Spjd{
240222116Spjd	struct tcp_ctx *tctx;
241218194Spjd
242218194Spjd	PJDLOG_ASSERT(fd >= 0);
243222116Spjd	PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
244222116Spjd	    side == TCP_SIDE_SERVER_WORK);
245218194Spjd	PJDLOG_ASSERT(ctxp != NULL);
246218194Spjd
247218194Spjd	tctx = malloc(sizeof(*tctx));
248218194Spjd	if (tctx == NULL)
249218194Spjd		return (errno);
250218194Spjd
251218194Spjd	tctx->tc_fd = fd;
252222118Spjd	tctx->tc_sa.ss_family = AF_UNSPEC;
253218194Spjd	tctx->tc_side = side;
254222116Spjd	tctx->tc_magic = TCP_CTX_MAGIC;
255218194Spjd	*ctxp = tctx;
256218194Spjd
257218194Spjd	return (0);
258218194Spjd}
259218194Spjd
260218194Spjdstatic int
261222116Spjdtcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
262204076Spjd{
263222116Spjd	struct tcp_ctx *tctx;
264222118Spjd	struct sockaddr_storage sa;
265219818Spjd	int ret;
266204076Spjd
267222116Spjd	ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
268219818Spjd	if (ret != 0)
269219818Spjd		return (ret);
270219818Spjd	tctx = *ctxp;
271219818Spjd	if (srcaddr == NULL)
272219818Spjd		return (0);
273222118Spjd	ret = tcp_addr(srcaddr, 0, &sa);
274219818Spjd	if (ret != 0) {
275222116Spjd		tcp_close(tctx);
276219818Spjd		return (ret);
277219818Spjd	}
278222118Spjd	if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) < 0) {
279219818Spjd		ret = errno;
280222116Spjd		tcp_close(tctx);
281219818Spjd		return (ret);
282219818Spjd	}
283219818Spjd	return (0);
284204076Spjd}
285204076Spjd
286204076Spjdstatic int
287222116Spjdtcp_connect(void *ctx, int timeout)
288204076Spjd{
289222116Spjd	struct tcp_ctx *tctx = ctx;
290218193Spjd	int error, flags;
291204076Spjd
292218138Spjd	PJDLOG_ASSERT(tctx != NULL);
293222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
294222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
295218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
296222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
297218193Spjd	PJDLOG_ASSERT(timeout >= -1);
298204076Spjd
299207390Spjd	flags = fcntl(tctx->tc_fd, F_GETFL);
300207390Spjd	if (flags == -1) {
301207390Spjd		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
302207390Spjd		    "fcntl(F_GETFL) failed"));
303204076Spjd		return (errno);
304204076Spjd	}
305207390Spjd	/*
306211875Spjd	 * We make socket non-blocking so we can handle connection timeout
307211875Spjd	 * manually.
308207390Spjd	 */
309207390Spjd	flags |= O_NONBLOCK;
310207390Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
311207390Spjd		KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
312207390Spjd		    "fcntl(F_SETFL, O_NONBLOCK) failed"));
313207390Spjd		return (errno);
314207390Spjd	}
315204076Spjd
316222118Spjd	if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
317222118Spjd	    tctx->tc_sa.ss_len) == 0) {
318218193Spjd		if (timeout == -1)
319218193Spjd			return (0);
320207390Spjd		error = 0;
321207390Spjd		goto done;
322207390Spjd	}
323207390Spjd	if (errno != EINPROGRESS) {
324207390Spjd		error = errno;
325207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
326207390Spjd		goto done;
327207390Spjd	}
328218193Spjd	if (timeout == -1)
329218193Spjd		return (0);
330222116Spjd	return (tcp_connect_wait(ctx, timeout));
331218193Spjddone:
332218193Spjd	flags &= ~O_NONBLOCK;
333218193Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
334218193Spjd		if (error == 0)
335218193Spjd			error = errno;
336218193Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
337218193Spjd		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
338218193Spjd	}
339218193Spjd	return (error);
340218193Spjd}
341218193Spjd
342218193Spjdstatic int
343222116Spjdtcp_connect_wait(void *ctx, int timeout)
344218193Spjd{
345222116Spjd	struct tcp_ctx *tctx = ctx;
346218193Spjd	struct timeval tv;
347218193Spjd	fd_set fdset;
348218193Spjd	socklen_t esize;
349218193Spjd	int error, flags, ret;
350218193Spjd
351218193Spjd	PJDLOG_ASSERT(tctx != NULL);
352222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
353222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
354218193Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
355218193Spjd	PJDLOG_ASSERT(timeout >= 0);
356218193Spjd
357218192Spjd	tv.tv_sec = timeout;
358207390Spjd	tv.tv_usec = 0;
359207390Spjdagain:
360207390Spjd	FD_ZERO(&fdset);
361219864Spjd	FD_SET(tctx->tc_fd, &fdset);
362207390Spjd	ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
363207390Spjd	if (ret == 0) {
364207390Spjd		error = ETIMEDOUT;
365207390Spjd		goto done;
366207390Spjd	} else if (ret == -1) {
367207390Spjd		if (errno == EINTR)
368207390Spjd			goto again;
369207390Spjd		error = errno;
370207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
371207390Spjd		goto done;
372207390Spjd	}
373218138Spjd	PJDLOG_ASSERT(ret > 0);
374218138Spjd	PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
375207390Spjd	esize = sizeof(error);
376207390Spjd	if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
377207390Spjd	    &esize) == -1) {
378207390Spjd		error = errno;
379207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
380207390Spjd		    "getsockopt(SO_ERROR) failed");
381207390Spjd		goto done;
382207390Spjd	}
383207390Spjd	if (error != 0) {
384207390Spjd		pjdlog_common(LOG_DEBUG, 1, error,
385207390Spjd		    "getsockopt(SO_ERROR) returned error");
386207390Spjd		goto done;
387207390Spjd	}
388207390Spjd	error = 0;
389207390Spjddone:
390218193Spjd	flags = fcntl(tctx->tc_fd, F_GETFL);
391218193Spjd	if (flags == -1) {
392218193Spjd		if (error == 0)
393218193Spjd			error = errno;
394218193Spjd		pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
395218193Spjd		return (error);
396218193Spjd	}
397207390Spjd	flags &= ~O_NONBLOCK;
398207390Spjd	if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
399207390Spjd		if (error == 0)
400207390Spjd			error = errno;
401207390Spjd		pjdlog_common(LOG_DEBUG, 1, errno,
402207390Spjd		    "fcntl(F_SETFL, ~O_NONBLOCK) failed");
403207390Spjd	}
404207390Spjd	return (error);
405204076Spjd}
406204076Spjd
407204076Spjdstatic int
408222116Spjdtcp_server(const char *addr, void **ctxp)
409204076Spjd{
410222116Spjd	struct tcp_ctx *tctx;
411204076Spjd	int ret, val;
412204076Spjd
413222116Spjd	ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
414204076Spjd	if (ret != 0)
415204076Spjd		return (ret);
416204076Spjd
417204076Spjd	tctx = *ctxp;
418204076Spjd
419204076Spjd	val = 1;
420204076Spjd	/* Ignore failure. */
421204076Spjd	(void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
422204076Spjd	   sizeof(val));
423204076Spjd
424222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
425218194Spjd
426222118Spjd	if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
427222118Spjd	    tctx->tc_sa.ss_len) < 0) {
428204076Spjd		ret = errno;
429222116Spjd		tcp_close(tctx);
430204076Spjd		return (ret);
431204076Spjd	}
432204076Spjd	if (listen(tctx->tc_fd, 8) < 0) {
433204076Spjd		ret = errno;
434222116Spjd		tcp_close(tctx);
435204076Spjd		return (ret);
436204076Spjd	}
437204076Spjd
438204076Spjd	return (0);
439204076Spjd}
440204076Spjd
441204076Spjdstatic int
442222116Spjdtcp_accept(void *ctx, void **newctxp)
443204076Spjd{
444222116Spjd	struct tcp_ctx *tctx = ctx;
445222116Spjd	struct tcp_ctx *newtctx;
446204076Spjd	socklen_t fromlen;
447204076Spjd	int ret;
448204076Spjd
449218138Spjd	PJDLOG_ASSERT(tctx != NULL);
450222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
451222116Spjd	PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
452218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
453222118Spjd	PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
454204076Spjd
455204076Spjd	newtctx = malloc(sizeof(*newtctx));
456204076Spjd	if (newtctx == NULL)
457204076Spjd		return (errno);
458204076Spjd
459222118Spjd	fromlen = tctx->tc_sa.ss_len;
460222118Spjd	newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
461204076Spjd	    &fromlen);
462204076Spjd	if (newtctx->tc_fd < 0) {
463204076Spjd		ret = errno;
464204076Spjd		free(newtctx);
465204076Spjd		return (ret);
466204076Spjd	}
467204076Spjd
468222116Spjd	newtctx->tc_side = TCP_SIDE_SERVER_WORK;
469222116Spjd	newtctx->tc_magic = TCP_CTX_MAGIC;
470204076Spjd	*newctxp = newtctx;
471204076Spjd
472204076Spjd	return (0);
473204076Spjd}
474204076Spjd
475204076Spjdstatic int
476222116Spjdtcp_wrap(int fd, bool client, void **ctxp)
477204076Spjd{
478218194Spjd
479222116Spjd	return (tcp_setup_wrap(fd,
480222116Spjd	    client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
481218194Spjd}
482218194Spjd
483218194Spjdstatic int
484222116Spjdtcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
485218194Spjd{
486222116Spjd	struct tcp_ctx *tctx = ctx;
487204076Spjd
488218138Spjd	PJDLOG_ASSERT(tctx != NULL);
489222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
490218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
491218194Spjd	PJDLOG_ASSERT(fd == -1);
492204076Spjd
493218194Spjd	return (proto_common_send(tctx->tc_fd, data, size, -1));
494204076Spjd}
495204076Spjd
496204076Spjdstatic int
497222116Spjdtcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
498204076Spjd{
499222116Spjd	struct tcp_ctx *tctx = ctx;
500204076Spjd
501218138Spjd	PJDLOG_ASSERT(tctx != NULL);
502222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
503218138Spjd	PJDLOG_ASSERT(tctx->tc_fd >= 0);
504218194Spjd	PJDLOG_ASSERT(fdp == NULL);
505204076Spjd
506218194Spjd	return (proto_common_recv(tctx->tc_fd, data, size, NULL));
507204076Spjd}
508204076Spjd
509204076Spjdstatic int
510222116Spjdtcp_descriptor(const void *ctx)
511204076Spjd{
512222116Spjd	const struct tcp_ctx *tctx = ctx;
513204076Spjd
514218138Spjd	PJDLOG_ASSERT(tctx != NULL);
515222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
516204076Spjd
517204076Spjd	return (tctx->tc_fd);
518204076Spjd}
519204076Spjd
520204076Spjdstatic bool
521222116Spjdtcp_address_match(const void *ctx, const char *addr)
522204076Spjd{
523222116Spjd	const struct tcp_ctx *tctx = ctx;
524222118Spjd	struct sockaddr_storage sa1, sa2;
525222118Spjd	socklen_t salen;
526204076Spjd
527218138Spjd	PJDLOG_ASSERT(tctx != NULL);
528222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
529204076Spjd
530222118Spjd	if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
531204076Spjd		return (false);
532204076Spjd
533222118Spjd	salen = sizeof(sa2);
534222118Spjd	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) < 0)
535204076Spjd		return (false);
536204076Spjd
537222118Spjd	if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
538222118Spjd		return (false);
539222118Spjd
540222118Spjd	switch (sa1.ss_family) {
541222118Spjd	case AF_INET:
542222118Spjd	    {
543222118Spjd		struct sockaddr_in *sin1, *sin2;
544222118Spjd
545222118Spjd		sin1 = (struct sockaddr_in *)&sa1;
546222118Spjd		sin2 = (struct sockaddr_in *)&sa2;
547222118Spjd
548222118Spjd		return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
549222118Spjd		    sizeof(sin1->sin_addr)) == 0);
550222118Spjd	    }
551222118Spjd	case AF_INET6:
552222118Spjd	    {
553222118Spjd		struct sockaddr_in6 *sin1, *sin2;
554222118Spjd
555222118Spjd		sin1 = (struct sockaddr_in6 *)&sa1;
556222118Spjd		sin2 = (struct sockaddr_in6 *)&sa2;
557222118Spjd
558222118Spjd		return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
559222118Spjd		    sizeof(sin1->sin6_addr)) == 0);
560222118Spjd	    }
561222118Spjd	default:
562222118Spjd		return (false);
563222118Spjd	}
564204076Spjd}
565204076Spjd
566204076Spjdstatic void
567222116Spjdtcp_local_address(const void *ctx, char *addr, size_t size)
568204076Spjd{
569222116Spjd	const struct tcp_ctx *tctx = ctx;
570222118Spjd	struct sockaddr_storage sa;
571222118Spjd	socklen_t salen;
572204076Spjd
573218138Spjd	PJDLOG_ASSERT(tctx != NULL);
574222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
575204076Spjd
576222118Spjd	salen = sizeof(sa);
577222118Spjd	if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
578210876Spjd		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
579204076Spjd		return;
580204076Spjd	}
581222118Spjd	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
582204076Spjd}
583204076Spjd
584204076Spjdstatic void
585222116Spjdtcp_remote_address(const void *ctx, char *addr, size_t size)
586204076Spjd{
587222116Spjd	const struct tcp_ctx *tctx = ctx;
588222118Spjd	struct sockaddr_storage sa;
589222118Spjd	socklen_t salen;
590204076Spjd
591218138Spjd	PJDLOG_ASSERT(tctx != NULL);
592222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
593204076Spjd
594222118Spjd	salen = sizeof(sa);
595222118Spjd	if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
596210876Spjd		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
597204076Spjd		return;
598204076Spjd	}
599222118Spjd	PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
600204076Spjd}
601204076Spjd
602204076Spjdstatic void
603222116Spjdtcp_close(void *ctx)
604204076Spjd{
605222116Spjd	struct tcp_ctx *tctx = ctx;
606204076Spjd
607218138Spjd	PJDLOG_ASSERT(tctx != NULL);
608222116Spjd	PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
609204076Spjd
610204076Spjd	if (tctx->tc_fd >= 0)
611204076Spjd		close(tctx->tc_fd);
612204076Spjd	tctx->tc_magic = 0;
613204076Spjd	free(tctx);
614204076Spjd}
615204076Spjd
616222116Spjdstatic struct proto tcp_proto = {
617222116Spjd	.prt_name = "tcp",
618222116Spjd	.prt_client = tcp_client,
619222116Spjd	.prt_connect = tcp_connect,
620222116Spjd	.prt_connect_wait = tcp_connect_wait,
621222116Spjd	.prt_server = tcp_server,
622222116Spjd	.prt_accept = tcp_accept,
623222116Spjd	.prt_wrap = tcp_wrap,
624222116Spjd	.prt_send = tcp_send,
625222116Spjd	.prt_recv = tcp_recv,
626222116Spjd	.prt_descriptor = tcp_descriptor,
627222116Spjd	.prt_address_match = tcp_address_match,
628222116Spjd	.prt_local_address = tcp_local_address,
629222116Spjd	.prt_remote_address = tcp_remote_address,
630222116Spjd	.prt_close = tcp_close
631204076Spjd};
632204076Spjd
633204076Spjdstatic __constructor void
634222116Spjdtcp_ctor(void)
635204076Spjd{
636204076Spjd
637222116Spjd	proto_register(&tcp_proto, true);
638204076Spjd}
639