1/*	$NetBSD: transport.c,v 1.2 2024/02/21 22:52:08 christos Exp $	*/
2
3/*
4 * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
5 *
6 * SPDX-License-Identifier: MPL-2.0
7 *
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this
10 * file, you can obtain one at https://mozilla.org/MPL/2.0/.
11 *
12 * See the COPYRIGHT file distributed with this work for additional
13 * information regarding copyright ownership.
14 */
15
16#include <inttypes.h>
17
18#include <isc/list.h>
19#include <isc/mem.h>
20#include <isc/refcount.h>
21#include <isc/result.h>
22#include <isc/rwlock.h>
23#include <isc/util.h>
24
25#include <dns/name.h>
26#include <dns/rbt.h>
27#include <dns/transport.h>
28
29#define TRANSPORT_MAGIC	     ISC_MAGIC('T', 'r', 'n', 's')
30#define VALID_TRANSPORT(ptr) ISC_MAGIC_VALID(ptr, TRANSPORT_MAGIC)
31
32#define TRANSPORT_LIST_MAGIC	  ISC_MAGIC('T', 'r', 'L', 's')
33#define VALID_TRANSPORT_LIST(ptr) ISC_MAGIC_VALID(ptr, TRANSPORT_LIST_MAGIC)
34
/*
 * A reference-counted collection of transports, holding one name-keyed
 * tree per transport type.  Insertions take 'lock' for writing and
 * lookups take it for reading.
 */
struct dns_transport_list {
	unsigned int magic;	   /* TRANSPORT_LIST_MAGIC while valid */
	isc_refcount_t references; /* list_attach/list_detach count */
	isc_mem_t *mctx;	   /* used for the list, its trees and transports */
	isc_rwlock_t lock;	   /* guards access to 'transports' trees */
	dns_rbt_t *transports[DNS_TRANSPORT_COUNT]; /* indexed by transport type */
};
42
/*
 * Tri-state setting: ter_none means "never set", while ter_true/ter_false
 * record an explicitly configured boolean value.
 */
typedef enum ternary { ter_none = 0, ter_true = 1, ter_false = 2 } ternary_t;
44
/*
 * A single named transport configuration.  All string members are heap
 * copies allocated from 'mctx' and owned by the transport; they are
 * released when the last reference is dropped.
 */
struct dns_transport {
	unsigned int magic;	   /* TRANSPORT_MAGIC while valid */
	isc_refcount_t references; /* attach/detach/find reference count */
	isc_mem_t *mctx;	   /* attached from the owning list's mctx */
	dns_transport_type_t type; /* fixed at creation time */
	/* TLS settings; setters accept only TLS and HTTP transports. */
	struct {
		char *tlsname;
		char *certfile;
		char *keyfile;
		char *cafile;
		char *remote_hostname;
		char *ciphers;
		/* value passed to dns_transport_set_tls_versions();
		 * presumably a bitmask of versions -- TODO confirm */
		uint32_t protocol_versions;
		ternary_t prefer_server_ciphers; /* ter_none == not set */
	} tls;
	/* DNS-over-HTTPS settings; setters accept only HTTP transports. */
	struct {
		char *endpoint;
		dns_http_mode_t mode;
	} doh;
};
65
66static void
67free_dns_transport(void *node, void *arg) {
68	dns_transport_t *transport = node;
69
70	REQUIRE(node != NULL);
71
72	UNUSED(arg);
73
74	dns_transport_detach(&transport);
75}
76
77static isc_result_t
78list_add(dns_transport_list_t *list, const dns_name_t *name,
79	 const dns_transport_type_t type, dns_transport_t *transport) {
80	isc_result_t result;
81	dns_rbt_t *rbt = NULL;
82
83	RWLOCK(&list->lock, isc_rwlocktype_write);
84	rbt = list->transports[type];
85	INSIST(rbt != NULL);
86
87	result = dns_rbt_addname(rbt, name, transport);
88
89	RWUNLOCK(&list->lock, isc_rwlocktype_write);
90
91	return (result);
92}
93
94dns_transport_type_t
95dns_transport_get_type(dns_transport_t *transport) {
96	REQUIRE(VALID_TRANSPORT(transport));
97
98	return (transport->type);
99}
100
101char *
102dns_transport_get_certfile(dns_transport_t *transport) {
103	REQUIRE(VALID_TRANSPORT(transport));
104
105	return (transport->tls.certfile);
106}
107
108char *
109dns_transport_get_keyfile(dns_transport_t *transport) {
110	REQUIRE(VALID_TRANSPORT(transport));
111
112	return (transport->tls.keyfile);
113}
114
115char *
116dns_transport_get_cafile(dns_transport_t *transport) {
117	REQUIRE(VALID_TRANSPORT(transport));
118
119	return (transport->tls.cafile);
120}
121
122char *
123dns_transport_get_remote_hostname(dns_transport_t *transport) {
124	REQUIRE(VALID_TRANSPORT(transport));
125
126	return (transport->tls.remote_hostname);
127}
128
129char *
130dns_transport_get_endpoint(dns_transport_t *transport) {
131	REQUIRE(VALID_TRANSPORT(transport));
132
133	return (transport->doh.endpoint);
134}
135
136dns_http_mode_t
137dns_transport_get_mode(dns_transport_t *transport) {
138	REQUIRE(VALID_TRANSPORT(transport));
139
140	return (transport->doh.mode);
141}
142
143dns_transport_t *
144dns_transport_new(const dns_name_t *name, dns_transport_type_t type,
145		  dns_transport_list_t *list) {
146	dns_transport_t *transport = isc_mem_get(list->mctx,
147						 sizeof(*transport));
148	*transport = (dns_transport_t){ .type = type };
149	isc_refcount_init(&transport->references, 1);
150	isc_mem_attach(list->mctx, &transport->mctx);
151	transport->magic = TRANSPORT_MAGIC;
152
153	list_add(list, name, type, transport);
154
155	return (transport);
156}
157
158void
159dns_transport_set_certfile(dns_transport_t *transport, const char *certfile) {
160	REQUIRE(VALID_TRANSPORT(transport));
161	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
162		transport->type == DNS_TRANSPORT_HTTP);
163
164	if (transport->tls.certfile != NULL) {
165		isc_mem_free(transport->mctx, transport->tls.certfile);
166	}
167
168	if (certfile != NULL) {
169		transport->tls.certfile = isc_mem_strdup(transport->mctx,
170							 certfile);
171	}
172}
173
174void
175dns_transport_set_keyfile(dns_transport_t *transport, const char *keyfile) {
176	REQUIRE(VALID_TRANSPORT(transport));
177	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
178		transport->type == DNS_TRANSPORT_HTTP);
179
180	if (transport->tls.keyfile != NULL) {
181		isc_mem_free(transport->mctx, transport->tls.keyfile);
182	}
183
184	if (keyfile != NULL) {
185		transport->tls.keyfile = isc_mem_strdup(transport->mctx,
186							keyfile);
187	}
188}
189
190void
191dns_transport_set_cafile(dns_transport_t *transport, const char *cafile) {
192	REQUIRE(VALID_TRANSPORT(transport));
193	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
194		transport->type == DNS_TRANSPORT_HTTP);
195
196	if (transport->tls.cafile != NULL) {
197		isc_mem_free(transport->mctx, transport->tls.cafile);
198	}
199
200	if (cafile != NULL) {
201		transport->tls.cafile = isc_mem_strdup(transport->mctx, cafile);
202	}
203}
204
205void
206dns_transport_set_remote_hostname(dns_transport_t *transport,
207				  const char *hostname) {
208	REQUIRE(VALID_TRANSPORT(transport));
209	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
210		transport->type == DNS_TRANSPORT_HTTP);
211
212	if (transport->tls.remote_hostname != NULL) {
213		isc_mem_free(transport->mctx, transport->tls.remote_hostname);
214	}
215
216	if (hostname != NULL) {
217		transport->tls.remote_hostname = isc_mem_strdup(transport->mctx,
218								hostname);
219	}
220}
221
222void
223dns_transport_set_endpoint(dns_transport_t *transport, const char *endpoint) {
224	REQUIRE(VALID_TRANSPORT(transport));
225	REQUIRE(transport->type == DNS_TRANSPORT_HTTP);
226
227	if (transport->doh.endpoint != NULL) {
228		isc_mem_free(transport->mctx, transport->doh.endpoint);
229	}
230
231	if (endpoint != NULL) {
232		transport->doh.endpoint = isc_mem_strdup(transport->mctx,
233							 endpoint);
234	}
235}
236
237void
238dns_transport_set_mode(dns_transport_t *transport, dns_http_mode_t mode) {
239	REQUIRE(VALID_TRANSPORT(transport));
240	REQUIRE(transport->type == DNS_TRANSPORT_HTTP);
241
242	transport->doh.mode = mode;
243}
244
245void
246dns_transport_set_tls_versions(dns_transport_t *transport,
247			       const uint32_t tls_versions) {
248	REQUIRE(VALID_TRANSPORT(transport));
249	REQUIRE(transport->type == DNS_TRANSPORT_HTTP ||
250		transport->type == DNS_TRANSPORT_TLS);
251
252	transport->tls.protocol_versions = tls_versions;
253}
254
255uint32_t
256dns_transport_get_tls_versions(const dns_transport_t *transport) {
257	REQUIRE(VALID_TRANSPORT(transport));
258
259	return (transport->tls.protocol_versions);
260}
261
262void
263dns_transport_set_ciphers(dns_transport_t *transport, const char *ciphers) {
264	REQUIRE(VALID_TRANSPORT(transport));
265	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
266		transport->type == DNS_TRANSPORT_HTTP);
267
268	if (transport->tls.ciphers != NULL) {
269		isc_mem_free(transport->mctx, transport->tls.ciphers);
270	}
271
272	if (ciphers != NULL) {
273		transport->tls.ciphers = isc_mem_strdup(transport->mctx,
274							ciphers);
275	}
276}
277
278void
279dns_transport_set_tlsname(dns_transport_t *transport, const char *tlsname) {
280	REQUIRE(VALID_TRANSPORT(transport));
281	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
282		transport->type == DNS_TRANSPORT_HTTP);
283
284	if (transport->tls.tlsname != NULL) {
285		isc_mem_free(transport->mctx, transport->tls.tlsname);
286	}
287
288	if (tlsname != NULL) {
289		transport->tls.tlsname = isc_mem_strdup(transport->mctx,
290							tlsname);
291	}
292}
293
294char *
295dns_transport_get_ciphers(dns_transport_t *transport) {
296	REQUIRE(VALID_TRANSPORT(transport));
297
298	return (transport->tls.ciphers);
299}
300
301char *
302dns_transport_get_tlsname(dns_transport_t *transport) {
303	REQUIRE(VALID_TRANSPORT(transport));
304
305	return (transport->tls.tlsname);
306}
307
308void
309dns_transport_set_prefer_server_ciphers(dns_transport_t *transport,
310					const bool prefer) {
311	REQUIRE(VALID_TRANSPORT(transport));
312	REQUIRE(transport->type == DNS_TRANSPORT_TLS ||
313		transport->type == DNS_TRANSPORT_HTTP);
314
315	transport->tls.prefer_server_ciphers = prefer ? ter_true : ter_false;
316}
317
318bool
319dns_transport_get_prefer_server_ciphers(const dns_transport_t *transport,
320					bool *preferp) {
321	REQUIRE(VALID_TRANSPORT(transport));
322	REQUIRE(preferp != NULL);
323	if (transport->tls.prefer_server_ciphers == ter_none) {
324		return (false);
325	} else if (transport->tls.prefer_server_ciphers == ter_true) {
326		*preferp = true;
327		return (true);
328	} else if (transport->tls.prefer_server_ciphers == ter_false) {
329		*preferp = false;
330		return (true);
331	}
332
333	UNREACHABLE();
334	return false;
335}
336
337static void
338transport_destroy(dns_transport_t *transport) {
339	isc_refcount_destroy(&transport->references);
340	transport->magic = 0;
341
342	if (transport->doh.endpoint != NULL) {
343		isc_mem_free(transport->mctx, transport->doh.endpoint);
344	}
345	if (transport->tls.remote_hostname != NULL) {
346		isc_mem_free(transport->mctx, transport->tls.remote_hostname);
347	}
348	if (transport->tls.cafile != NULL) {
349		isc_mem_free(transport->mctx, transport->tls.cafile);
350	}
351	if (transport->tls.keyfile != NULL) {
352		isc_mem_free(transport->mctx, transport->tls.keyfile);
353	}
354	if (transport->tls.certfile != NULL) {
355		isc_mem_free(transport->mctx, transport->tls.certfile);
356	}
357	if (transport->tls.ciphers != NULL) {
358		isc_mem_free(transport->mctx, transport->tls.ciphers);
359	}
360
361	if (transport->tls.tlsname != NULL) {
362		isc_mem_free(transport->mctx, transport->tls.tlsname);
363	}
364
365	isc_mem_putanddetach(&transport->mctx, transport, sizeof(*transport));
366}
367
368void
369dns_transport_attach(dns_transport_t *source, dns_transport_t **targetp) {
370	REQUIRE(source != NULL);
371	REQUIRE(targetp != NULL && *targetp == NULL);
372
373	isc_refcount_increment(&source->references);
374
375	*targetp = source;
376}
377
378void
379dns_transport_detach(dns_transport_t **transportp) {
380	dns_transport_t *transport = NULL;
381
382	REQUIRE(transportp != NULL);
383	REQUIRE(VALID_TRANSPORT(*transportp));
384
385	transport = *transportp;
386	*transportp = NULL;
387
388	if (isc_refcount_decrement(&transport->references) == 1) {
389		transport_destroy(transport);
390	}
391}
392
393dns_transport_t *
394dns_transport_find(const dns_transport_type_t type, const dns_name_t *name,
395		   dns_transport_list_t *list) {
396	isc_result_t result;
397	dns_transport_t *transport = NULL;
398	dns_rbt_t *rbt = NULL;
399
400	REQUIRE(VALID_TRANSPORT_LIST(list));
401	REQUIRE(list->transports[type] != NULL);
402
403	rbt = list->transports[type];
404
405	RWLOCK(&list->lock, isc_rwlocktype_read);
406	result = dns_rbt_findname(rbt, name, 0, NULL, (void *)&transport);
407	if (result == ISC_R_SUCCESS) {
408		isc_refcount_increment(&transport->references);
409	}
410	RWUNLOCK(&list->lock, isc_rwlocktype_read);
411
412	return (transport);
413}
414
415dns_transport_list_t *
416dns_transport_list_new(isc_mem_t *mctx) {
417	dns_transport_list_t *list = isc_mem_get(mctx, sizeof(*list));
418
419	*list = (dns_transport_list_t){ 0 };
420
421	isc_rwlock_init(&list->lock, 0, 0);
422
423	isc_mem_attach(mctx, &list->mctx);
424	isc_refcount_init(&list->references, 1);
425
426	list->magic = TRANSPORT_LIST_MAGIC;
427
428	for (size_t type = 0; type < DNS_TRANSPORT_COUNT; type++) {
429		isc_result_t result;
430		result = dns_rbt_create(list->mctx, free_dns_transport, NULL,
431					&list->transports[type]);
432		RUNTIME_CHECK(result == ISC_R_SUCCESS);
433	}
434
435	return (list);
436}
437
438void
439dns_transport_list_attach(dns_transport_list_t *source,
440			  dns_transport_list_t **targetp) {
441	REQUIRE(VALID_TRANSPORT_LIST(source));
442	REQUIRE(targetp != NULL && *targetp == NULL);
443
444	isc_refcount_increment(&source->references);
445
446	*targetp = source;
447}
448
449static void
450transport_list_destroy(dns_transport_list_t *list) {
451	isc_refcount_destroy(&list->references);
452	list->magic = 0;
453
454	for (size_t type = 0; type < DNS_TRANSPORT_COUNT; type++) {
455		if (list->transports[type] != NULL) {
456			dns_rbt_destroy(&list->transports[type]);
457		}
458	}
459	isc_rwlock_destroy(&list->lock);
460	isc_mem_putanddetach(&list->mctx, list, sizeof(*list));
461}
462
463void
464dns_transport_list_detach(dns_transport_list_t **listp) {
465	dns_transport_list_t *list = NULL;
466
467	REQUIRE(listp != NULL);
468	REQUIRE(VALID_TRANSPORT_LIST(*listp));
469
470	list = *listp;
471	*listp = NULL;
472
473	if (isc_refcount_decrement(&list->references) == 1) {
474		transport_list_destroy(list);
475	}
476}
477