1/*	$NetBSD: test_server.c,v 1.2 2024/02/21 22:51:13 christos Exp $	*/
2
3/*
4 * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
5 *
6 * SPDX-License-Identifier: MPL-2.0
7 *
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this
10 * file, you can obtain one at https://mozilla.org/MPL/2.0/.
11 *
12 * See the COPYRIGHT file distributed with this work for additional
13 * information regarding copyright ownership.
14 */
15
#include <arpa/inet.h>
#include <getopt.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>

#include <isc/managers.h>
#include <isc/mem.h>
#include <isc/netaddr.h>
#include <isc/netmgr.h>
#include <isc/os.h>
#include <isc/result.h>
#include <isc/sockaddr.h>
#include <isc/string.h>
#include <isc/util.h>
32
33typedef enum { UDP, TCP, DOT, HTTPS, HTTP } protocol_t;
34
35static const char *protocols[] = { "udp", "tcp", "dot", "https", "http-plain" };
36
37static isc_mem_t *mctx = NULL;
38static isc_nm_t *netmgr = NULL;
39
40static protocol_t protocol;
41static in_port_t port;
42static isc_netaddr_t netaddr;
43static isc_sockaddr_t sockaddr __attribute__((unused));
44static int workers;
45
46static isc_tlsctx_t *tls_ctx = NULL;
47
48static void
49read_cb(isc_nmhandle_t *handle, isc_result_t eresult, isc_region_t *region,
50	void *cbarg);
51static void
52send_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg);
53
54static isc_result_t
55parse_port(const char *input) {
56	char *endptr = NULL;
57	long val = strtol(input, &endptr, 10);
58
59	if ((*endptr != '\0') || (val <= 0) || (val >= 65536)) {
60		return (ISC_R_BADNUMBER);
61	}
62
63	port = (in_port_t)val;
64
65	return (ISC_R_SUCCESS);
66}
67
68static isc_result_t
69parse_protocol(const char *input) {
70	for (size_t i = 0; i < ARRAY_SIZE(protocols); i++) {
71		if (!strcasecmp(input, protocols[i])) {
72			protocol = i;
73			return (ISC_R_SUCCESS);
74		}
75	}
76
77	return (ISC_R_BADNUMBER);
78}
79
80static isc_result_t
81parse_address(const char *input) {
82	struct in6_addr in6;
83	struct in_addr in;
84
85	if (inet_pton(AF_INET6, input, &in6) == 1) {
86		isc_netaddr_fromin6(&netaddr, &in6);
87		return (ISC_R_SUCCESS);
88	}
89
90	if (inet_pton(AF_INET, input, &in) == 1) {
91		isc_netaddr_fromin(&netaddr, &in);
92		return (ISC_R_SUCCESS);
93	}
94
95	return (ISC_R_BADADDRESSFORM);
96}
97
98static int
99parse_workers(const char *input) {
100	char *endptr = NULL;
101	long val = strtol(input, &endptr, 10);
102
103	if ((*endptr != '\0') || (val <= 0) || (val >= 128)) {
104		return (ISC_R_BADNUMBER);
105	}
106
107	workers = val;
108
109	return (ISC_R_SUCCESS);
110}
111
112static void
113parse_options(int argc, char **argv) {
114	char buf[ISC_NETADDR_FORMATSIZE];
115
116	/* Set defaults */
117	RUNTIME_CHECK(parse_protocol("UDP") == ISC_R_SUCCESS);
118	RUNTIME_CHECK(parse_port("53000") == ISC_R_SUCCESS);
119	RUNTIME_CHECK(parse_address("::1") == ISC_R_SUCCESS);
120	workers = isc_os_ncpus();
121
122	while (true) {
123		int c;
124		int option_index = 0;
125		static struct option long_options[] = {
126			{ "port", required_argument, NULL, 'p' },
127			{ "address", required_argument, NULL, 'a' },
128			{ "protocol", required_argument, NULL, 'P' },
129			{ "workers", required_argument, NULL, 'w' },
130			{ 0, 0, NULL, 0 }
131		};
132
133		c = getopt_long(argc, argv, "a:p:P:w:", long_options,
134				&option_index);
135		if (c == -1) {
136			break;
137		}
138
139		switch (c) {
140		case 'a':
141			RUNTIME_CHECK(parse_address(optarg) == ISC_R_SUCCESS);
142			break;
143
144		case 'p':
145			RUNTIME_CHECK(parse_port(optarg) == ISC_R_SUCCESS);
146			break;
147
148		case 'P':
149			RUNTIME_CHECK(parse_protocol(optarg) == ISC_R_SUCCESS);
150			break;
151
152		case 'w':
153			RUNTIME_CHECK(parse_workers(optarg) == ISC_R_SUCCESS);
154			break;
155
156		default:
157			UNREACHABLE();
158		}
159	}
160
161	isc_sockaddr_fromnetaddr(&sockaddr, &netaddr, port);
162
163	isc_sockaddr_format(&sockaddr, buf, sizeof(buf));
164
165	printf("Will listen at %s://%s, %d workers\n", protocols[protocol], buf,
166	       workers);
167}
168
169static void
170_signal(int sig, void (*handler)(int)) {
171	struct sigaction sa = { .sa_handler = handler };
172
173	RUNTIME_CHECK(sigfillset(&sa.sa_mask) == 0);
174	RUNTIME_CHECK(sigaction(sig, &sa, NULL) >= 0);
175}
176
177static void
178setup(void) {
179	sigset_t sset;
180
181	_signal(SIGPIPE, SIG_IGN);
182	_signal(SIGHUP, SIG_DFL);
183	_signal(SIGTERM, SIG_DFL);
184	_signal(SIGINT, SIG_DFL);
185
186	RUNTIME_CHECK(sigemptyset(&sset) == 0);
187	RUNTIME_CHECK(sigaddset(&sset, SIGHUP) == 0);
188	RUNTIME_CHECK(sigaddset(&sset, SIGINT) == 0);
189	RUNTIME_CHECK(sigaddset(&sset, SIGTERM) == 0);
190	RUNTIME_CHECK(pthread_sigmask(SIG_BLOCK, &sset, NULL) == 0);
191
192	isc_mem_create(&mctx);
193
194	isc_managers_create(mctx, workers, 0, &netmgr, NULL, NULL);
195}
196
197static void
198teardown(void) {
199	isc_managers_destroy(&netmgr, NULL, NULL);
200	isc_mem_destroy(&mctx);
201	if (tls_ctx) {
202		isc_tlsctx_free(&tls_ctx);
203	}
204}
205
206static void
207test_server_yield(void) {
208	sigset_t sset;
209	int sig;
210
211	RUNTIME_CHECK(sigemptyset(&sset) == 0);
212	RUNTIME_CHECK(sigaddset(&sset, SIGHUP) == 0);
213	RUNTIME_CHECK(sigaddset(&sset, SIGINT) == 0);
214	RUNTIME_CHECK(sigaddset(&sset, SIGTERM) == 0);
215	RUNTIME_CHECK(sigwait(&sset, &sig) == 0);
216
217	fprintf(stderr, "Shutting down...\n");
218}
219
220static void
221read_cb(isc_nmhandle_t *handle, isc_result_t eresult, isc_region_t *region,
222	void *cbarg) {
223	isc_region_t *reply = NULL;
224
225	REQUIRE(handle != NULL);
226	REQUIRE(eresult == ISC_R_SUCCESS);
227	UNUSED(cbarg);
228
229	fprintf(stderr, "RECEIVED %u bytes\n", region->length);
230
231	if (region->length >= 12) {
232		/* long enough to be a DNS header, set QR bit */
233		((uint8_t *)region->base)[2] ^= 0x80;
234	}
235
236	reply = isc_mem_get(mctx, sizeof(isc_region_t) + region->length);
237	reply->length = region->length;
238	reply->base = (uint8_t *)reply + sizeof(isc_region_t);
239	memmove(reply->base, region->base, region->length);
240
241	isc_nm_send(handle, reply, send_cb, reply);
242	return;
243}
244
245static void
246send_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg) {
247	isc_region_t *reply = cbarg;
248
249	REQUIRE(handle != NULL);
250	REQUIRE(eresult == ISC_R_SUCCESS);
251
252	isc_mem_put(mctx, cbarg, sizeof(isc_region_t) + reply->length);
253}
254
255static isc_result_t
256accept_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg) {
257	REQUIRE(handle != NULL);
258	REQUIRE(eresult == ISC_R_SUCCESS);
259	UNUSED(cbarg);
260
261	return (ISC_R_SUCCESS);
262}
263
264static void
265run(void) {
266	isc_result_t result;
267	isc_nmsocket_t *sock = NULL;
268
269	switch (protocol) {
270	case UDP:
271		result = isc_nm_listenudp(netmgr, &sockaddr, read_cb, NULL, 0,
272					  &sock);
273		break;
274	case TCP:
275		result = isc_nm_listentcpdns(netmgr, &sockaddr, read_cb, NULL,
276					     accept_cb, NULL, 0, 0, NULL,
277					     &sock);
278		break;
279	case DOT: {
280		isc_tlsctx_createserver(NULL, NULL, &tls_ctx);
281
282		result = isc_nm_listentlsdns(netmgr, &sockaddr, read_cb, NULL,
283					     accept_cb, NULL, 0, 0, NULL,
284					     tls_ctx, &sock);
285		break;
286	}
287#if HAVE_LIBNGHTTP2
288	case HTTPS:
289	case HTTP: {
290		bool is_https = protocol == HTTPS;
291		isc_nm_http_endpoints_t *eps = NULL;
292		if (is_https) {
293			isc_tlsctx_createserver(NULL, NULL, &tls_ctx);
294		}
295		eps = isc_nm_http_endpoints_new(mctx);
296		result = isc_nm_http_endpoints_add(
297			eps, ISC_NM_HTTP_DEFAULT_PATH, read_cb, NULL, 0);
298
299		if (result == ISC_R_SUCCESS) {
300			result = isc_nm_listenhttp(netmgr, &sockaddr, 0, NULL,
301						   tls_ctx, eps, 0, &sock);
302		}
303		isc_nm_http_endpoints_detach(&eps);
304	} break;
305#endif
306	default:
307		UNREACHABLE();
308	}
309	REQUIRE(result == ISC_R_SUCCESS);
310
311	test_server_yield();
312
313	isc_nm_stoplistening(sock);
314	isc_nmsocket_close(&sock);
315}
316
/*
 * Entry point: read the command line, bring up the memory context and
 * network manager, serve until a shutdown signal, then tear everything
 * down and report success.
 */
int
main(int argc, char **argv) {
	parse_options(argc, argv);
	setup();
	run();
	teardown();

	return (EXIT_SUCCESS);
}
329