1/* $NetBSD: test_server.c,v 1.2 2024/02/21 22:51:13 christos Exp $ */ 2 3/* 4 * Copyright (C) Internet Systems Consortium, Inc. ("ISC") 5 * 6 * SPDX-License-Identifier: MPL-2.0 7 * 8 * This Source Code Form is subject to the terms of the Mozilla Public 9 * License, v. 2.0. If a copy of the MPL was not distributed with this 10 * file, you can obtain one at https://mozilla.org/MPL/2.0/. 11 * 12 * See the COPYRIGHT file distributed with this work for additional 13 * information regarding copyright ownership. 14 */ 15 16#include <getopt.h> 17#include <netinet/in.h> 18#include <signal.h> 19#include <stdbool.h> 20#include <stdio.h> 21#include <stdlib.h> 22#include <strings.h> 23 24#include <isc/managers.h> 25#include <isc/mem.h> 26#include <isc/netaddr.h> 27#include <isc/netmgr.h> 28#include <isc/os.h> 29#include <isc/sockaddr.h> 30#include <isc/string.h> 31#include <isc/util.h> 32 33typedef enum { UDP, TCP, DOT, HTTPS, HTTP } protocol_t; 34 35static const char *protocols[] = { "udp", "tcp", "dot", "https", "http-plain" }; 36 37static isc_mem_t *mctx = NULL; 38static isc_nm_t *netmgr = NULL; 39 40static protocol_t protocol; 41static in_port_t port; 42static isc_netaddr_t netaddr; 43static isc_sockaddr_t sockaddr __attribute__((unused)); 44static int workers; 45 46static isc_tlsctx_t *tls_ctx = NULL; 47 48static void 49read_cb(isc_nmhandle_t *handle, isc_result_t eresult, isc_region_t *region, 50 void *cbarg); 51static void 52send_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg); 53 54static isc_result_t 55parse_port(const char *input) { 56 char *endptr = NULL; 57 long val = strtol(input, &endptr, 10); 58 59 if ((*endptr != '\0') || (val <= 0) || (val >= 65536)) { 60 return (ISC_R_BADNUMBER); 61 } 62 63 port = (in_port_t)val; 64 65 return (ISC_R_SUCCESS); 66} 67 68static isc_result_t 69parse_protocol(const char *input) { 70 for (size_t i = 0; i < ARRAY_SIZE(protocols); i++) { 71 if (!strcasecmp(input, protocols[i])) { 72 protocol = i; 73 return (ISC_R_SUCCESS); 74 } 75 } 76 77 return (ISC_R_BADNUMBER); 78} 79 80static isc_result_t 81parse_address(const char *input) { 82 struct in6_addr in6; 83 struct in_addr in; 84 85 if (inet_pton(AF_INET6, input, &in6) == 1) { 86 isc_netaddr_fromin6(&netaddr, &in6); 87 return (ISC_R_SUCCESS); 88 } 89 90 if (inet_pton(AF_INET, input, &in) == 1) { 91 isc_netaddr_fromin(&netaddr, &in); 92 return (ISC_R_SUCCESS); 93 } 94 95 return (ISC_R_BADADDRESSFORM); 96} 97 98static int 99parse_workers(const char *input) { 100 char *endptr = NULL; 101 long val = strtol(input, &endptr, 10); 102 103 if ((*endptr != '\0') || (val <= 0) || (val >= 128)) { 104 return (ISC_R_BADNUMBER); 105 } 106 107 workers = val; 108 109 return (ISC_R_SUCCESS); 110} 111 112static void 113parse_options(int argc, char **argv) { 114 char buf[ISC_NETADDR_FORMATSIZE]; 115 116 /* Set defaults */ 117 RUNTIME_CHECK(parse_protocol("UDP") == ISC_R_SUCCESS); 118 RUNTIME_CHECK(parse_port("53000") == ISC_R_SUCCESS); 119 RUNTIME_CHECK(parse_address("::1") == ISC_R_SUCCESS); 120 workers = isc_os_ncpus(); 121 122 while (true) { 123 int c; 124 int option_index = 0; 125 static struct option long_options[] = { 126 { "port", required_argument, NULL, 'p' }, 127 { "address", required_argument, NULL, 'a' }, 128 { "protocol", required_argument, NULL, 'P' }, 129 { "workers", required_argument, NULL, 'w' }, 130 { 0, 0, NULL, 0 } 131 }; 132 133 c = getopt_long(argc, argv, "a:p:P:w:", long_options, 134 &option_index); 135 if (c == -1) { 136 break; 137 } 138 139 switch (c) { 140 case 'a': 141 RUNTIME_CHECK(parse_address(optarg) == ISC_R_SUCCESS); 142 break; 143 144 case 'p': 145 RUNTIME_CHECK(parse_port(optarg) == ISC_R_SUCCESS); 146 break; 147 148 case 'P': 149 RUNTIME_CHECK(parse_protocol(optarg) == ISC_R_SUCCESS); 150 break; 151 152 case 'w': 153 RUNTIME_CHECK(parse_workers(optarg) == ISC_R_SUCCESS); 154 break; 155 156 default: 157 UNREACHABLE(); 158 } 159 } 160 161 isc_sockaddr_fromnetaddr(&sockaddr, &netaddr, port); 162 163 isc_sockaddr_format(&sockaddr, buf, sizeof(buf)); 164 165 printf("Will listen at %s://%s, %d workers\n", protocols[protocol], buf, 166 workers); 167} 168 169static void 170_signal(int sig, void (*handler)(int)) { 171 struct sigaction sa = { .sa_handler = handler }; 172 173 RUNTIME_CHECK(sigfillset(&sa.sa_mask) == 0); 174 RUNTIME_CHECK(sigaction(sig, &sa, NULL) >= 0); 175} 176 177static void 178setup(void) { 179 sigset_t sset; 180 181 _signal(SIGPIPE, SIG_IGN); 182 _signal(SIGHUP, SIG_DFL); 183 _signal(SIGTERM, SIG_DFL); 184 _signal(SIGINT, SIG_DFL); 185 186 RUNTIME_CHECK(sigemptyset(&sset) == 0); 187 RUNTIME_CHECK(sigaddset(&sset, SIGHUP) == 0); 188 RUNTIME_CHECK(sigaddset(&sset, SIGINT) == 0); 189 RUNTIME_CHECK(sigaddset(&sset, SIGTERM) == 0); 190 RUNTIME_CHECK(pthread_sigmask(SIG_BLOCK, &sset, NULL) == 0); 191 192 isc_mem_create(&mctx); 193 194 isc_managers_create(mctx, workers, 0, &netmgr, NULL, NULL); 195} 196 197static void 198teardown(void) { 199 isc_managers_destroy(&netmgr, NULL, NULL); 200 isc_mem_destroy(&mctx); 201 if (tls_ctx) { 202 isc_tlsctx_free(&tls_ctx); 203 } 204} 205 206static void 207test_server_yield(void) { 208 sigset_t sset; 209 int sig; 210 211 RUNTIME_CHECK(sigemptyset(&sset) == 0); 212 RUNTIME_CHECK(sigaddset(&sset, SIGHUP) == 0); 213 RUNTIME_CHECK(sigaddset(&sset, SIGINT) == 0); 214 RUNTIME_CHECK(sigaddset(&sset, SIGTERM) == 0); 215 RUNTIME_CHECK(sigwait(&sset, &sig) == 0); 216 217 fprintf(stderr, "Shutting down...\n"); 218} 219 220static void 221read_cb(isc_nmhandle_t *handle, isc_result_t eresult, isc_region_t *region, 222 void *cbarg) { 223 isc_region_t *reply = NULL; 224 225 REQUIRE(handle != NULL); 226 REQUIRE(eresult == ISC_R_SUCCESS); 227 UNUSED(cbarg); 228 229 fprintf(stderr, "RECEIVED %u bytes\n", region->length); 230 231 if (region->length >= 12) { 232 /* long enough to be a DNS header, set QR bit */ 233 ((uint8_t *)region->base)[2] ^= 0x80; 234 } 235 236 reply = isc_mem_get(mctx, sizeof(isc_region_t) + region->length); 237 reply->length = region->length; 238 reply->base = (uint8_t *)reply + sizeof(isc_region_t); 239 memmove(reply->base, region->base, region->length); 240 241 isc_nm_send(handle, reply, send_cb, reply); 242 return; 243} 244 245static void 246send_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg) { 247 isc_region_t *reply = cbarg; 248 249 REQUIRE(handle != NULL); 250 REQUIRE(eresult == ISC_R_SUCCESS); 251 252 isc_mem_put(mctx, cbarg, sizeof(isc_region_t) + reply->length); 253} 254 255static isc_result_t 256accept_cb(isc_nmhandle_t *handle, isc_result_t eresult, void *cbarg) { 257 REQUIRE(handle != NULL); 258 REQUIRE(eresult == ISC_R_SUCCESS); 259 UNUSED(cbarg); 260 261 return (ISC_R_SUCCESS); 262} 263 264static void 265run(void) { 266 isc_result_t result; 267 isc_nmsocket_t *sock = NULL; 268 269 switch (protocol) { 270 case UDP: 271 result = isc_nm_listenudp(netmgr, &sockaddr, read_cb, NULL, 0, 272 &sock); 273 break; 274 case TCP: 275 result = isc_nm_listentcpdns(netmgr, &sockaddr, read_cb, NULL, 276 accept_cb, NULL, 0, 0, NULL, 277 &sock); 278 break; 279 case DOT: { 280 isc_tlsctx_createserver(NULL, NULL, &tls_ctx); 281 282 result = isc_nm_listentlsdns(netmgr, &sockaddr, read_cb, NULL, 283 accept_cb, NULL, 0, 0, NULL, 284 tls_ctx, &sock); 285 break; 286 } 287#if HAVE_LIBNGHTTP2 288 case HTTPS: 289 case HTTP: { 290 bool is_https = protocol == HTTPS; 291 isc_nm_http_endpoints_t *eps = NULL; 292 if (is_https) { 293 isc_tlsctx_createserver(NULL, NULL, &tls_ctx); 294 } 295 eps = isc_nm_http_endpoints_new(mctx); 296 result = isc_nm_http_endpoints_add( 297 eps, ISC_NM_HTTP_DEFAULT_PATH, read_cb, NULL, 0); 298 299 if (result == ISC_R_SUCCESS) { 300 result = isc_nm_listenhttp(netmgr, &sockaddr, 0, NULL, 301 tls_ctx, eps, 0, &sock); 302 } 303 isc_nm_http_endpoints_detach(&eps); 304 } break; 305#endif 306 default: 307 UNREACHABLE(); 308 } 309 REQUIRE(result == ISC_R_SUCCESS); 310 311 test_server_yield(); 312 313 isc_nm_stoplistening(sock); 314 isc_nmsocket_close(&sock); 315} 316 317int 318main(int argc, char **argv) { 319 parse_options(argc, argv); 320 321 setup(); 322 323 run(); 324 325 teardown(); 326 327 exit(EXIT_SUCCESS); 328} 329