1204076Spjd/*-
2330449Seadler * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3330449Seadler *
4204076Spjd * Copyright (c) 2009-2010 The FreeBSD Foundation
5204076Spjd * All rights reserved.
6204076Spjd *
7204076Spjd * This software was developed by Pawel Jakub Dawidek under sponsorship from
8204076Spjd * the FreeBSD Foundation.
9204076Spjd *
10204076Spjd * Redistribution and use in source and binary forms, with or without
11204076Spjd * modification, are permitted provided that the following conditions
12204076Spjd * are met:
13204076Spjd * 1. Redistributions of source code must retain the above copyright
14204076Spjd *    notice, this list of conditions and the following disclaimer.
15204076Spjd * 2. Redistributions in binary form must reproduce the above copyright
16204076Spjd *    notice, this list of conditions and the following disclaimer in the
17204076Spjd *    documentation and/or other materials provided with the distribution.
18204076Spjd *
19204076Spjd * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20204076Spjd * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21204076Spjd * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22204076Spjd * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23204076Spjd * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24204076Spjd * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25204076Spjd * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26204076Spjd * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27204076Spjd * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28204076Spjd * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29204076Spjd * SUCH DAMAGE.
30204076Spjd */
31204076Spjd
32204076Spjd#include <sys/cdefs.h>
33204076Spjd__FBSDID("$FreeBSD: stable/11/sbin/hastd/proto_socketpair.c 330449 2018-03-05 07:26:05Z eadler $");
34204076Spjd
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto_impl.h"
47204076Spjd
/*
 * Value stored in sp_magic while a context is live.  Every operation checks
 * it with PJDLOG_ASSERT(), and sp_close() clears it before freeing.
 */
#define	SP_CTX_MAGIC	0x50c3741

/*
 * Which end of the socketpair a context uses.  The side is chosen lazily by
 * the first send or receive performed on the context.
 */
enum {
	SP_SIDE_UNDEF = 0,	/* Not chosen yet; both descriptors open. */
	SP_SIDE_CLIENT = 1,	/* Keeps sp_fd[0]. */
	SP_SIDE_SERVER = 2	/* Keeps sp_fd[1]. */
};

/* Per-connection state for the socketpair protocol. */
struct sp_ctx {
	int	sp_magic;	/* SP_CTX_MAGIC while valid. */
	int	sp_fd[2];	/* Descriptors from socketpair(2). */
	int	sp_side;	/* One of the SP_SIDE_* values. */
};

static void sp_close(void *ctx);
59204076Spjd
60204076Spjdstatic int
61219818Spjdsp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
62204076Spjd{
63204076Spjd	struct sp_ctx *spctx;
64204076Spjd	int ret;
65204076Spjd
66219818Spjd	if (strcmp(dstaddr, "socketpair://") != 0)
67204076Spjd		return (-1);
68204076Spjd
69219818Spjd	PJDLOG_ASSERT(srcaddr == NULL);
70219818Spjd
71204076Spjd	spctx = malloc(sizeof(*spctx));
72204076Spjd	if (spctx == NULL)
73204076Spjd		return (errno);
74204076Spjd
75229945Spjd	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
76204076Spjd		ret = errno;
77204076Spjd		free(spctx);
78204076Spjd		return (ret);
79204076Spjd	}
80204076Spjd
81204076Spjd	spctx->sp_side = SP_SIDE_UNDEF;
82204076Spjd	spctx->sp_magic = SP_CTX_MAGIC;
83204076Spjd	*ctxp = spctx;
84204076Spjd
85204076Spjd	return (0);
86204076Spjd}
87204076Spjd
88204076Spjdstatic int
89218194Spjdsp_send(void *ctx, const unsigned char *data, size_t size, int fd)
90204076Spjd{
91204076Spjd	struct sp_ctx *spctx = ctx;
92218194Spjd	int sock;
93204076Spjd
94218138Spjd	PJDLOG_ASSERT(spctx != NULL);
95218138Spjd	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
96204076Spjd
97204076Spjd	switch (spctx->sp_side) {
98204076Spjd	case SP_SIDE_UNDEF:
99204076Spjd		/*
100204076Spjd		 * If the first operation done by the caller is proto_send(),
101218194Spjd		 * we assume this is the client.
102204076Spjd		 */
103204076Spjd		/* FALLTHROUGH */
104204076Spjd		spctx->sp_side = SP_SIDE_CLIENT;
105204076Spjd		/* Close other end. */
106204076Spjd		close(spctx->sp_fd[1]);
107218138Spjd		spctx->sp_fd[1] = -1;
108204076Spjd	case SP_SIDE_CLIENT:
109218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
110218194Spjd		sock = spctx->sp_fd[0];
111204076Spjd		break;
112204076Spjd	case SP_SIDE_SERVER:
113218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
114218194Spjd		sock = spctx->sp_fd[1];
115204076Spjd		break;
116204076Spjd	default:
117218138Spjd		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
118204076Spjd	}
119204076Spjd
120212036Spjd	/* Someone is just trying to decide about side. */
121212036Spjd	if (data == NULL)
122212036Spjd		return (0);
123212036Spjd
124218194Spjd	return (proto_common_send(sock, data, size, fd));
125204076Spjd}
126204076Spjd
127204076Spjdstatic int
128218194Spjdsp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
129204076Spjd{
130204076Spjd	struct sp_ctx *spctx = ctx;
131204076Spjd	int fd;
132204076Spjd
133218138Spjd	PJDLOG_ASSERT(spctx != NULL);
134218138Spjd	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
135204076Spjd
136204076Spjd	switch (spctx->sp_side) {
137204076Spjd	case SP_SIDE_UNDEF:
138204076Spjd		/*
139204076Spjd		 * If the first operation done by the caller is proto_recv(),
140218194Spjd		 * we assume this is the server.
141204076Spjd		 */
142204076Spjd		/* FALLTHROUGH */
143204076Spjd		spctx->sp_side = SP_SIDE_SERVER;
144204076Spjd		/* Close other end. */
145204076Spjd		close(spctx->sp_fd[0]);
146218138Spjd		spctx->sp_fd[0] = -1;
147204076Spjd	case SP_SIDE_SERVER:
148218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
149204076Spjd		fd = spctx->sp_fd[1];
150204076Spjd		break;
151204076Spjd	case SP_SIDE_CLIENT:
152218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
153204076Spjd		fd = spctx->sp_fd[0];
154204076Spjd		break;
155204076Spjd	default:
156218138Spjd		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
157204076Spjd	}
158204076Spjd
159212036Spjd	/* Someone is just trying to decide about side. */
160212036Spjd	if (data == NULL)
161212036Spjd		return (0);
162212036Spjd
163218194Spjd	return (proto_common_recv(fd, data, size, fdp));
164204076Spjd}
165204076Spjd
166204076Spjdstatic int
167204076Spjdsp_descriptor(const void *ctx)
168204076Spjd{
169204076Spjd	const struct sp_ctx *spctx = ctx;
170204076Spjd
171218138Spjd	PJDLOG_ASSERT(spctx != NULL);
172218138Spjd	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
173218138Spjd	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
174204076Spjd	    spctx->sp_side == SP_SIDE_SERVER);
175204076Spjd
176204076Spjd	switch (spctx->sp_side) {
177204076Spjd	case SP_SIDE_CLIENT:
178218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
179204076Spjd		return (spctx->sp_fd[0]);
180204076Spjd	case SP_SIDE_SERVER:
181218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
182204076Spjd		return (spctx->sp_fd[1]);
183204076Spjd	}
184204076Spjd
185218138Spjd	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
186204076Spjd}
187204076Spjd
188204076Spjdstatic void
189204076Spjdsp_close(void *ctx)
190204076Spjd{
191204076Spjd	struct sp_ctx *spctx = ctx;
192204076Spjd
193218138Spjd	PJDLOG_ASSERT(spctx != NULL);
194218138Spjd	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
195204076Spjd
196204076Spjd	switch (spctx->sp_side) {
197204076Spjd	case SP_SIDE_UNDEF:
198218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
199204076Spjd		close(spctx->sp_fd[0]);
200218138Spjd		spctx->sp_fd[0] = -1;
201218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
202204076Spjd		close(spctx->sp_fd[1]);
203218138Spjd		spctx->sp_fd[1] = -1;
204204076Spjd		break;
205204076Spjd	case SP_SIDE_CLIENT:
206218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
207204076Spjd		close(spctx->sp_fd[0]);
208218138Spjd		spctx->sp_fd[0] = -1;
209218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
210204076Spjd		break;
211204076Spjd	case SP_SIDE_SERVER:
212218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
213204076Spjd		close(spctx->sp_fd[1]);
214218138Spjd		spctx->sp_fd[1] = -1;
215218138Spjd		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
216204076Spjd		break;
217204076Spjd	default:
218218138Spjd		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
219204076Spjd	}
220204076Spjd
221204076Spjd	spctx->sp_magic = 0;
222204076Spjd	free(spctx);
223204076Spjd}
224204076Spjd
225219873Spjdstatic struct proto sp_proto = {
226219873Spjd	.prt_name = "socketpair",
227219873Spjd	.prt_client = sp_client,
228219873Spjd	.prt_send = sp_send,
229219873Spjd	.prt_recv = sp_recv,
230219873Spjd	.prt_descriptor = sp_descriptor,
231219873Spjd	.prt_close = sp_close
232204076Spjd};
233204076Spjd
234204076Spjdstatic __constructor void
235204076Spjdsp_ctor(void)
236204076Spjd{
237204076Spjd
238210869Spjd	proto_register(&sp_proto, false);
239204076Spjd}
240