1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2009-2010 The FreeBSD Foundation
5 * All rights reserved.
6 *
7 * This software was developed by Pawel Jakub Dawidek under sponsorship from
8 * the FreeBSD Foundation.
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 * 1. Redistributions of source code must retain the above copyright
14 *    notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29 * SUCH DAMAGE.
30 */
31
32#include <sys/cdefs.h>
33__FBSDID("$FreeBSD$");
34
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto_impl.h"
47
/* Value kept in sp_magic while a context is live; cleared by sp_close(). */
#define	SP_CTX_MAGIC	0x50c3741

/* Possible values of sp_side: which end of the pair this context uses. */
#define	SP_SIDE_UNDEF		0	/* not decided yet; both ends open */
#define	SP_SIDE_CLIENT		1	/* uses sp_fd[0] */
#define	SP_SIDE_SERVER		2	/* uses sp_fd[1] */

/*
 * State for one "socketpair://" connection.  Both descriptors come from a
 * single socketpair() call; once the first send or recv picks a side, the
 * descriptor belonging to the other side is closed and set to -1.
 */
struct sp_ctx {
	int			sp_magic;	/* SP_CTX_MAGIC, or 0 once freed */
	int			sp_fd[2];	/* [0] client end, [1] server end */
	int			sp_side;	/* one of SP_SIDE_* */
};
57
/* Forward declaration of the close method, which is defined further below. */
static void sp_close(void *ctx);
59
60static int
61sp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
62{
63	struct sp_ctx *spctx;
64	int ret;
65
66	if (strcmp(dstaddr, "socketpair://") != 0)
67		return (-1);
68
69	PJDLOG_ASSERT(srcaddr == NULL);
70
71	spctx = malloc(sizeof(*spctx));
72	if (spctx == NULL)
73		return (errno);
74
75	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
76		ret = errno;
77		free(spctx);
78		return (ret);
79	}
80
81	spctx->sp_side = SP_SIDE_UNDEF;
82	spctx->sp_magic = SP_CTX_MAGIC;
83	*ctxp = spctx;
84
85	return (0);
86}
87
88static int
89sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
90{
91	struct sp_ctx *spctx = ctx;
92	int sock;
93
94	PJDLOG_ASSERT(spctx != NULL);
95	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
96
97	switch (spctx->sp_side) {
98	case SP_SIDE_UNDEF:
99		/*
100		 * If the first operation done by the caller is proto_send(),
101		 * we assume this is the client.
102		 */
103		/* FALLTHROUGH */
104		spctx->sp_side = SP_SIDE_CLIENT;
105		/* Close other end. */
106		close(spctx->sp_fd[1]);
107		spctx->sp_fd[1] = -1;
108	case SP_SIDE_CLIENT:
109		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
110		sock = spctx->sp_fd[0];
111		break;
112	case SP_SIDE_SERVER:
113		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
114		sock = spctx->sp_fd[1];
115		break;
116	default:
117		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
118	}
119
120	/* Someone is just trying to decide about side. */
121	if (data == NULL)
122		return (0);
123
124	return (proto_common_send(sock, data, size, fd));
125}
126
127static int
128sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
129{
130	struct sp_ctx *spctx = ctx;
131	int fd;
132
133	PJDLOG_ASSERT(spctx != NULL);
134	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
135
136	switch (spctx->sp_side) {
137	case SP_SIDE_UNDEF:
138		/*
139		 * If the first operation done by the caller is proto_recv(),
140		 * we assume this is the server.
141		 */
142		/* FALLTHROUGH */
143		spctx->sp_side = SP_SIDE_SERVER;
144		/* Close other end. */
145		close(spctx->sp_fd[0]);
146		spctx->sp_fd[0] = -1;
147	case SP_SIDE_SERVER:
148		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
149		fd = spctx->sp_fd[1];
150		break;
151	case SP_SIDE_CLIENT:
152		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
153		fd = spctx->sp_fd[0];
154		break;
155	default:
156		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
157	}
158
159	/* Someone is just trying to decide about side. */
160	if (data == NULL)
161		return (0);
162
163	return (proto_common_recv(fd, data, size, fdp));
164}
165
166static int
167sp_descriptor(const void *ctx)
168{
169	const struct sp_ctx *spctx = ctx;
170
171	PJDLOG_ASSERT(spctx != NULL);
172	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
173	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
174	    spctx->sp_side == SP_SIDE_SERVER);
175
176	switch (spctx->sp_side) {
177	case SP_SIDE_CLIENT:
178		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
179		return (spctx->sp_fd[0]);
180	case SP_SIDE_SERVER:
181		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
182		return (spctx->sp_fd[1]);
183	}
184
185	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
186}
187
188static void
189sp_close(void *ctx)
190{
191	struct sp_ctx *spctx = ctx;
192
193	PJDLOG_ASSERT(spctx != NULL);
194	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
195
196	switch (spctx->sp_side) {
197	case SP_SIDE_UNDEF:
198		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
199		close(spctx->sp_fd[0]);
200		spctx->sp_fd[0] = -1;
201		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
202		close(spctx->sp_fd[1]);
203		spctx->sp_fd[1] = -1;
204		break;
205	case SP_SIDE_CLIENT:
206		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
207		close(spctx->sp_fd[0]);
208		spctx->sp_fd[0] = -1;
209		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
210		break;
211	case SP_SIDE_SERVER:
212		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
213		close(spctx->sp_fd[1]);
214		spctx->sp_fd[1] = -1;
215		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
216		break;
217	default:
218		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
219	}
220
221	spctx->sp_magic = 0;
222	free(spctx);
223}
224
225static struct proto sp_proto = {
226	.prt_name = "socketpair",
227	.prt_client = sp_client,
228	.prt_send = sp_send,
229	.prt_recv = sp_recv,
230	.prt_descriptor = sp_descriptor,
231	.prt_close = sp_close
232};
233
/*
 * Register the socketpair protocol with the generic proto layer.
 * NOTE(review): __constructor presumably expands to
 * __attribute__((constructor)), so this runs before main() -- confirm in
 * proto_impl.h.  The meaning of the boolean passed to proto_register()
 * (false here) is defined by the proto layer, not visible in this file.
 */
static __constructor void
sp_ctor(void)
{

	proto_register(&sp_proto, false);
}
240