1/*	$NetBSD: nsec3hash.c,v 1.7 2024/02/21 22:51:41 christos Exp $	*/
2
3/*
4 * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
5 *
6 * SPDX-License-Identifier: MPL-2.0
7 *
8 * This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this
10 * file, you can obtain one at https://mozilla.org/MPL/2.0/.
11 *
12 * See the COPYRIGHT file distributed with this work for additional
13 * information regarding copyright ownership.
14 */
15
16#include <stdarg.h>
17#include <stdbool.h>
18#include <stdlib.h>
19
20#include <isc/attributes.h>
21#include <isc/base32.h>
22#include <isc/buffer.h>
23#include <isc/commandline.h>
24#include <isc/file.h>
25#include <isc/hex.h>
26#include <isc/iterated_hash.h>
27#include <isc/print.h>
28#include <isc/result.h>
29#include <isc/string.h>
30#include <isc/types.h>
31#include <isc/util.h>
32
33#include <dns/fixedname.h>
34#include <dns/name.h>
35#include <dns/nsec3.h>
36#include <dns/types.h>
37
38const char *program = "nsec3hash";
39
40noreturn static void
41fatal(const char *format, ...);
42
43static void
44fatal(const char *format, ...) {
45	va_list args;
46
47	fprintf(stderr, "%s: ", program);
48	va_start(args, format);
49	vfprintf(stderr, format, args);
50	va_end(args);
51	fprintf(stderr, "\n");
52	exit(1);
53}
54
55static void
56check_result(isc_result_t result, const char *message) {
57	if (result != ISC_R_SUCCESS) {
58		fatal("%s: %s", message, isc_result_totext(result));
59	}
60}
61
62static void
63usage(void) {
64	fprintf(stderr, "Usage: %s salt algorithm iterations domain\n",
65		program);
66	fprintf(stderr, "       %s -r algorithm flags iterations salt domain\n",
67		program);
68	exit(1);
69}
70
/*
 * Output callback invoked by nsec3hash() once the hash is computed.
 * This is a function type (not a pointer type); nsec3hash() takes a
 * pointer to a function of this type.
 *
 * algo, flags, iters: the parsed numeric NSEC3 parameters.
 * saltstr: the salt as given, or "-" when the salt is empty.
 * domain:  the domain name exactly as given on the command line.
 * digest:  NUL-terminated text produced by isc_base32hexnp_totext().
 */
typedef void
nsec3printer(unsigned algo, unsigned flags, unsigned iters, const char *saltstr,
	     const char *domain, const char *digest);
74
75static void
76nsec3hash(nsec3printer *nsec3print, const char *algostr, const char *flagstr,
77	  const char *iterstr, const char *saltstr, const char *domain) {
78	dns_fixedname_t fixed;
79	dns_name_t *name;
80	isc_buffer_t buffer;
81	isc_region_t region;
82	isc_result_t result;
83	unsigned char hash[NSEC3_MAX_HASH_LENGTH];
84	unsigned char salt[DNS_NSEC3_SALTSIZE];
85	unsigned char text[1024];
86	unsigned int hash_alg;
87	unsigned int flags;
88	unsigned int length;
89	unsigned int iterations;
90	unsigned int salt_length;
91	const char dash[] = "-";
92
93	if (strcmp(saltstr, "-") == 0) {
94		salt_length = 0;
95		salt[0] = 0;
96	} else {
97		isc_buffer_init(&buffer, salt, sizeof(salt));
98		result = isc_hex_decodestring(saltstr, &buffer);
99		check_result(result, "isc_hex_decodestring(salt)");
100		salt_length = isc_buffer_usedlength(&buffer);
101		if (salt_length > DNS_NSEC3_SALTSIZE) {
102			fatal("salt too long");
103		}
104		if (salt_length == 0) {
105			saltstr = dash;
106		}
107	}
108	hash_alg = atoi(algostr);
109	if (hash_alg > 255U) {
110		fatal("hash algorithm too large");
111	}
112	flags = flagstr == NULL ? 0 : atoi(flagstr);
113	if (flags > 255U) {
114		fatal("flags too large");
115	}
116	iterations = atoi(iterstr);
117	if (iterations > 0xffffU) {
118		fatal("iterations to large");
119	}
120
121	name = dns_fixedname_initname(&fixed);
122	isc_buffer_constinit(&buffer, domain, strlen(domain));
123	isc_buffer_add(&buffer, strlen(domain));
124	result = dns_name_fromtext(name, &buffer, dns_rootname, 0, NULL);
125	check_result(result, "dns_name_fromtext() failed");
126
127	dns_name_downcase(name, name, NULL);
128	length = isc_iterated_hash(hash, hash_alg, iterations, salt,
129				   salt_length, name->ndata, name->length);
130	if (length == 0) {
131		fatal("isc_iterated_hash failed");
132	}
133	region.base = hash;
134	region.length = length;
135	isc_buffer_init(&buffer, text, sizeof(text));
136	isc_base32hexnp_totext(&region, 1, "", &buffer);
137	isc_buffer_putuint8(&buffer, '\0');
138
139	nsec3print(hash_alg, flags, iterations, saltstr, domain, (char *)text);
140}
141
/*
 * Default output format: one line on stdout of the form
 * "<digest> (salt=<salt>, hash=<algo>, iterations=<iters>)".
 * The flags and the domain are not shown in this format.
 */
static void
nsec3hash_print(unsigned algo, unsigned flags, unsigned iters,
		const char *saltstr, const char *domain, const char *digest) {
	UNUSED(domain);
	UNUSED(flags);

	printf("%s (salt=%s, hash=%u, iterations=%u)\n", digest, saltstr,
	       algo, iters);
}
151
/*
 * "-r" output format: one line on stdout laid out like an NSEC3 record,
 * "<domain> NSEC3 <algo> <flags> <iters> <salt> <digest>".
 */
static void
nsec3hash_rdata_print(unsigned algo, unsigned flags, unsigned iters,
		      const char *saltstr, const char *domain,
		      const char *digest) {
	printf("%s NSEC3 %u %u %u %s %s\n", domain, algo, flags, iters,
	       saltstr, digest);
}
159
/*
 * Entry point.  Accepts the two invocation forms printed by usage():
 *
 *   nsec3hash salt algorithm iterations domain
 *   nsec3hash -r algorithm flags iterations salt domain
 *
 * The "-r" form prints the result laid out like an NSEC3 record; the
 * default form prints the digest followed by the parameters.  Wrong
 * positional-argument counts go to usage(), which exits.
 */
int
main(int argc, char *argv[]) {
	bool rdata_format = false;
	int ch;

	/*
	 * '-' appears in the option string alongside 'r'.  NOTE(review): this
	 * appears to make isc_commandline_parse() report a lone "-" argument
	 * (the empty-salt placeholder) as option '-' instead of rejecting it;
	 * confirm against lib/isc/commandline.c.
	 */
	while ((ch = isc_commandline_parse(argc, argv, "-r")) != -1) {
		switch (ch) {
		case 'r':
			rdata_format = true;
			break;
		case '-':
			/*
			 * Step isc_commandline_index back by one and stop option
			 * parsing.  Presumably the parser had already moved past
			 * the "-" argument, and this restores it as the first
			 * positional argument (the salt) — TODO confirm.
			 */
			isc_commandline_index -= 1;
			goto skip;
		default:
			/* Any other returned value is ignored here. */
			break;
		}
	}

skip:
	/* Drop the consumed options; argv[0] is now the first positional. */
	argc -= isc_commandline_index;
	argv += isc_commandline_index;

	if (rdata_format) {
		/* algorithm flags iterations salt domain */
		if (argc != 5) {
			usage();
		}
		nsec3hash(nsec3hash_rdata_print, argv[0], argv[1], argv[2],
			  argv[3], argv[4]);
	} else {
		/* salt algorithm iterations domain; flags are not supplied */
		if (argc != 4) {
			usage();
		}
		nsec3hash(nsec3hash_print, argv[1], NULL, argv[2], argv[0],
			  argv[3]);
	}
	return (0);
}
197