1323134Sdes/* 	$OpenBSD: kexfuzz.c,v 1.3 2016/10/11 21:49:54 djm Exp $ */
2313010Sdes/*
3313010Sdes * Fuzz harness for KEX code
4313010Sdes *
5313010Sdes * Placed in the public domain
6313010Sdes */
7313010Sdes
8313010Sdes#include "includes.h"
9313010Sdes
10313010Sdes#include <sys/types.h>
11313010Sdes#include <sys/param.h>
12313010Sdes#include <stdio.h>
13313010Sdes#ifdef HAVE_STDINT_H
14313010Sdes# include <stdint.h>
15313010Sdes#endif
16313010Sdes#include <stdlib.h>
17313010Sdes#include <string.h>
18313010Sdes#include <unistd.h>
19313010Sdes#include <fcntl.h>
20313010Sdes#ifdef HAVE_ERR_H
21313010Sdes# include <err.h>
22313010Sdes#endif
23313010Sdes
24313010Sdes#include "ssherr.h"
25313010Sdes#include "ssh_api.h"
26313010Sdes#include "sshbuf.h"
27313010Sdes#include "packet.h"
28313010Sdes#include "myproposal.h"
29313010Sdes#include "authfile.h"
30323134Sdes#include "log.h"
31313010Sdes
/* XXX - needed for linking */
struct ssh *active_state = NULL;

/* Set by the -v flag; gates the harness's verbose output. */
static int do_debug = 0;

/*
 * Direction of travel of a packet.  The numeric values (S2C == 0,
 * C2S == 1) are compared against plain ints elsewhere in this file.
 */
enum direction {
	S2C,
	C2S
};

void kex_tests(void);
38313010Sdes
/*
 * State shared with packet_hook() through its opaque context pointer.
 */
struct hook_ctx {
	/* Endpoints taking part in the exchange. */
	struct ssh *client;
	struct ssh *server;
	struct ssh *server2;
	/* Per-direction packet counters, owned by the caller. */
	int *c2s;
	int *s2c;
	/* Direction (S2C/C2S) and index of the packet to act on. */
	int trigger_direction;
	int packet_index;
	/* Dump mode: file to write the selected packet to, else NULL. */
	const char *dump_path;
	/* Replace mode: { type, body } to substitute, else NULL. */
	struct sshbuf *replace_data;
};
46323134Sdes
47313010Sdesstatic int
48323134Sdespacket_hook(struct ssh *ssh, struct sshbuf *packet, u_char *typep, void *_ctx)
49313010Sdes{
50323134Sdes	struct hook_ctx *ctx = (struct hook_ctx *)_ctx;
51323134Sdes	int mydirection = ssh == ctx->client ? S2C : C2S;
52323134Sdes	int *packet_count = mydirection == S2C ? ctx->s2c : ctx->c2s;
53323134Sdes	FILE *dumpfile;
54323134Sdes	int r;
55323134Sdes
56323134Sdes	if (do_debug) {
57323134Sdes		printf("%s packet %d type %u:\n",
58323134Sdes		    mydirection == S2C ? "s2c" : "c2s",
59323134Sdes		    *packet_count, *typep);
60323134Sdes		sshbuf_dump(packet, stdout);
61323134Sdes	}
62323134Sdes	if (mydirection == ctx->trigger_direction &&
63323134Sdes	    ctx->packet_index == *packet_count) {
64323134Sdes		if (ctx->replace_data != NULL) {
65323134Sdes			sshbuf_reset(packet);
66323134Sdes			/* Type is first byte of packet */
67323134Sdes			if ((r = sshbuf_get_u8(ctx->replace_data,
68323134Sdes			    typep)) != 0 ||
69323134Sdes			    (r = sshbuf_putb(packet, ctx->replace_data)) != 0)
70323134Sdes				return r;
71323134Sdes			if (do_debug) {
72323134Sdes				printf("***** replaced packet type %u\n",
73323134Sdes				    *typep);
74323134Sdes				sshbuf_dump(packet, stdout);
75323134Sdes			}
76323134Sdes		} else if (ctx->dump_path != NULL) {
77323134Sdes			if ((dumpfile = fopen(ctx->dump_path, "w+")) == NULL)
78323134Sdes				err(1, "fopen %s", ctx->dump_path);
79323134Sdes			/* Write { type, packet } */
80323134Sdes			if (fwrite(typep, 1, 1, dumpfile) != 1)
81323134Sdes				err(1, "fwrite type %s", ctx->dump_path);
82323134Sdes			if (sshbuf_len(packet) != 0 &&
83323134Sdes			    fwrite(sshbuf_ptr(packet), sshbuf_len(packet),
84323134Sdes			    1, dumpfile) != 1)
85323134Sdes				err(1, "fwrite body %s", ctx->dump_path);
86323134Sdes			if (do_debug) {
87323134Sdes				printf("***** dumped packet type %u len %zu\n",
88323134Sdes				    *typep, sshbuf_len(packet));
89323134Sdes			}
90323134Sdes			fclose(dumpfile);
91323134Sdes			/* No point in continuing */
92323134Sdes			exit(0);
93323134Sdes		}
94323134Sdes	}
95323134Sdes	(*packet_count)++;
96323134Sdes	return 0;
97323134Sdes}
98323134Sdes
99323134Sdesstatic int
100323134Sdesdo_send_and_receive(struct ssh *from, struct ssh *to)
101323134Sdes{
102313010Sdes	u_char type;
103323134Sdes	size_t len;
104313010Sdes	const u_char *buf;
105313010Sdes	int r;
106313010Sdes
107313010Sdes	for (;;) {
108313010Sdes		if ((r = ssh_packet_next(from, &type)) != 0) {
109313010Sdes			fprintf(stderr, "ssh_packet_next: %s\n", ssh_err(r));
110313010Sdes			return r;
111313010Sdes		}
112323134Sdes
113313010Sdes		if (type != 0)
114313010Sdes			return 0;
115313010Sdes		buf = ssh_output_ptr(from, &len);
116313010Sdes		if (len == 0)
117313010Sdes			return 0;
118323134Sdes		if ((r = ssh_input_append(to, buf, len)) != 0) {
119323134Sdes			debug("ssh_input_append: %s", ssh_err(r));
120313010Sdes			return r;
121323134Sdes		}
122323134Sdes		if ((r = ssh_output_consume(from, len)) != 0) {
123323134Sdes			debug("ssh_output_consume: %s", ssh_err(r));
124323134Sdes			return r;
125323134Sdes		}
126313010Sdes	}
127313010Sdes}
128313010Sdes
/* Minimal test_helper.c scaffolding to make this standalone */

/* Name of the test currently running, or NULL between tests. */
const char *in_test = NULL;

/* Record the running test's name; announce it when do_debug is set. */
#define TEST_START(a)	\
	do { \
		in_test = (a); \
		if (do_debug) \
			fprintf(stderr, "test %s starting\n", in_test); \
	} while (0)

/* Clear the running test's name; announce completion when do_debug is set. */
#define TEST_DONE()	\
	do { \
		if (do_debug) \
			fprintf(stderr, "test %s done\n", \
			    in_test ? in_test : "???"); \
		in_test = NULL; \
	} while(0)

/*
 * The ASSERT_* macros below print a diagnostic and exit(2) on failure.
 * Each argument is evaluated exactly once and the captured value is
 * reused for the message, so an argument with side effects (e.g. a
 * function call) is not executed a second time on the failure path.
 */

/* Fail unless (int)a == (int)b. */
#define ASSERT_INT_EQ(a, b) \
	do { \
		int assert_lhs_ = (int)(a); \
		int assert_rhs_ = (int)(b); \
		if (assert_lhs_ != assert_rhs_) { \
			fprintf(stderr, "%s %s:%d " \
			    "%s (%d) != expected %s (%d)\n", \
			    in_test ? in_test : "(none)", \
			    __func__, __LINE__, #a, assert_lhs_, \
			    #b, assert_rhs_); \
			exit(2); \
		} \
	} while (0)

/* Fail unless (int)a >= (int)b. */
#define ASSERT_INT_GE(a, b) \
	do { \
		int assert_lhs_ = (int)(a); \
		int assert_rhs_ = (int)(b); \
		if (assert_lhs_ < assert_rhs_) { \
			fprintf(stderr, "%s %s:%d " \
			    "%s (%d) < expected %s (%d)\n", \
			    in_test ? in_test : "(none)", \
			    __func__, __LINE__, #a, assert_lhs_, \
			    #b, assert_rhs_); \
			exit(2); \
		} \
	} while (0)

/*
 * Fail if the object pointers a and b are equal.  The values are held as
 * const void * because that is the type the %p conversion requires.
 */
#define ASSERT_PTR_NE(a, b) \
	do { \
		const void *assert_lhs_ = (a); \
		const void *assert_rhs_ = (b); \
		if (assert_lhs_ == assert_rhs_) { \
			fprintf(stderr, "%s %s:%d " \
			    "%s (%p) == unexpected %s (%p)\n", \
			    in_test ? in_test : "(none)", \
			    __func__, __LINE__, #a, assert_lhs_, \
			    #b, assert_rhs_); \
			exit(2); \
		} \
	} while (0)

174313010Sdes
175313010Sdes
176313010Sdesstatic void
177323134Sdesrun_kex(struct ssh *client, struct ssh *server)
178313010Sdes{
179313010Sdes	int r = 0;
180313010Sdes
181313010Sdes	while (!server->kex->done || !client->kex->done) {
182323134Sdes		if ((r = do_send_and_receive(server, client)) != 0) {
183323134Sdes			debug("do_send_and_receive S2C: %s", ssh_err(r));
184313010Sdes			break;
185323134Sdes		}
186323134Sdes		if ((r = do_send_and_receive(client, server)) != 0) {
187323134Sdes			debug("do_send_and_receive C2S: %s", ssh_err(r));
188313010Sdes			break;
189323134Sdes		}
190313010Sdes	}
191313010Sdes	if (do_debug)
192313010Sdes		printf("done: %s\n", ssh_err(r));
193313010Sdes	ASSERT_INT_EQ(r, 0);
194313010Sdes	ASSERT_INT_EQ(server->kex->done, 1);
195313010Sdes	ASSERT_INT_EQ(client->kex->done, 1);
196313010Sdes}
197313010Sdes
198313010Sdesstatic void
199313010Sdesdo_kex_with_key(const char *kex, struct sshkey *prvkey, int *c2s, int *s2c,
200313010Sdes    int direction, int packet_index,
201313010Sdes    const char *dump_path, struct sshbuf *replace_data)
202313010Sdes{
203313010Sdes	struct ssh *client = NULL, *server = NULL, *server2 = NULL;
204313010Sdes	struct sshkey *pubkey = NULL;
205313010Sdes	struct sshbuf *state;
206313010Sdes	struct kex_params kex_params;
207313010Sdes	char *myproposal[PROPOSAL_MAX] = { KEX_CLIENT };
208313010Sdes	char *keyname = NULL;
209323134Sdes	struct hook_ctx hook_ctx;
210313010Sdes
211313010Sdes	TEST_START("sshkey_from_private");
212313010Sdes	ASSERT_INT_EQ(sshkey_from_private(prvkey, &pubkey), 0);
213313010Sdes	TEST_DONE();
214313010Sdes
215313010Sdes	TEST_START("ssh_init");
216313010Sdes	memcpy(kex_params.proposal, myproposal, sizeof(myproposal));
217313010Sdes	if (kex != NULL)
218313010Sdes		kex_params.proposal[PROPOSAL_KEX_ALGS] = strdup(kex);
219313010Sdes	keyname = strdup(sshkey_ssh_name(prvkey));
220313010Sdes	ASSERT_PTR_NE(keyname, NULL);
221313010Sdes	kex_params.proposal[PROPOSAL_SERVER_HOST_KEY_ALGS] = keyname;
222313010Sdes	ASSERT_INT_EQ(ssh_init(&client, 0, &kex_params), 0);
223313010Sdes	ASSERT_INT_EQ(ssh_init(&server, 1, &kex_params), 0);
224323134Sdes	ASSERT_INT_EQ(ssh_init(&server2, 1, NULL), 0);
225313010Sdes	ASSERT_PTR_NE(client, NULL);
226313010Sdes	ASSERT_PTR_NE(server, NULL);
227323134Sdes	ASSERT_PTR_NE(server2, NULL);
228313010Sdes	TEST_DONE();
229313010Sdes
230323134Sdes	hook_ctx.c2s = c2s;
231323134Sdes	hook_ctx.s2c = s2c;
232323134Sdes	hook_ctx.trigger_direction = direction;
233323134Sdes	hook_ctx.packet_index = packet_index;
234323134Sdes	hook_ctx.dump_path = dump_path;
235323134Sdes	hook_ctx.replace_data = replace_data;
236323134Sdes	hook_ctx.client = client;
237323134Sdes	hook_ctx.server = server;
238323134Sdes	hook_ctx.server2 = server2;
239323134Sdes	ssh_packet_set_input_hook(client, packet_hook, &hook_ctx);
240323134Sdes	ssh_packet_set_input_hook(server, packet_hook, &hook_ctx);
241323134Sdes	ssh_packet_set_input_hook(server2, packet_hook, &hook_ctx);
242323134Sdes
243313010Sdes	TEST_START("ssh_add_hostkey");
244313010Sdes	ASSERT_INT_EQ(ssh_add_hostkey(server, prvkey), 0);
245313010Sdes	ASSERT_INT_EQ(ssh_add_hostkey(client, pubkey), 0);
246313010Sdes	TEST_DONE();
247313010Sdes
248313010Sdes	TEST_START("kex");
249323134Sdes	run_kex(client, server);
250313010Sdes	TEST_DONE();
251313010Sdes
252313010Sdes	TEST_START("rekeying client");
253313010Sdes	ASSERT_INT_EQ(kex_send_kexinit(client), 0);
254323134Sdes	run_kex(client, server);
255313010Sdes	TEST_DONE();
256313010Sdes
257313010Sdes	TEST_START("rekeying server");
258313010Sdes	ASSERT_INT_EQ(kex_send_kexinit(server), 0);
259323134Sdes	run_kex(client, server);
260313010Sdes	TEST_DONE();
261313010Sdes
262313010Sdes	TEST_START("ssh_packet_get_state");
263313010Sdes	state = sshbuf_new();
264313010Sdes	ASSERT_PTR_NE(state, NULL);
265313010Sdes	ASSERT_INT_EQ(ssh_packet_get_state(server, state), 0);
266313010Sdes	ASSERT_INT_GE(sshbuf_len(state), 1);
267313010Sdes	TEST_DONE();
268313010Sdes
269313010Sdes	TEST_START("ssh_packet_set_state");
270313010Sdes	ASSERT_INT_EQ(ssh_add_hostkey(server2, prvkey), 0);
271313010Sdes	kex_free(server2->kex);	/* XXX or should ssh_packet_set_state()? */
272313010Sdes	ASSERT_INT_EQ(ssh_packet_set_state(server2, state), 0);
273313010Sdes	ASSERT_INT_EQ(sshbuf_len(state), 0);
274313010Sdes	sshbuf_free(state);
275313010Sdes	ASSERT_PTR_NE(server2->kex, NULL);
276313010Sdes	/* XXX we need to set the callbacks */
277323134Sdes#ifdef WITH_OPENSSL
278313010Sdes	server2->kex->kex[KEX_DH_GRP1_SHA1] = kexdh_server;
279313010Sdes	server2->kex->kex[KEX_DH_GRP14_SHA1] = kexdh_server;
280323134Sdes	server2->kex->kex[KEX_DH_GRP14_SHA256] = kexdh_server;
281323134Sdes	server2->kex->kex[KEX_DH_GRP16_SHA512] = kexdh_server;
282323134Sdes	server2->kex->kex[KEX_DH_GRP18_SHA512] = kexdh_server;
283313010Sdes	server2->kex->kex[KEX_DH_GEX_SHA1] = kexgex_server;
284313010Sdes	server2->kex->kex[KEX_DH_GEX_SHA256] = kexgex_server;
285323134Sdes# ifdef OPENSSL_HAS_ECC
286313010Sdes	server2->kex->kex[KEX_ECDH_SHA2] = kexecdh_server;
287323134Sdes# endif
288313010Sdes#endif
289313010Sdes	server2->kex->kex[KEX_C25519_SHA256] = kexc25519_server;
290313010Sdes	server2->kex->load_host_public_key = server->kex->load_host_public_key;
291313010Sdes	server2->kex->load_host_private_key = server->kex->load_host_private_key;
292313010Sdes	server2->kex->sign = server->kex->sign;
293313010Sdes	TEST_DONE();
294313010Sdes
295313010Sdes	TEST_START("rekeying server2");
296313010Sdes	ASSERT_INT_EQ(kex_send_kexinit(server2), 0);
297323134Sdes	run_kex(client, server2);
298313010Sdes	ASSERT_INT_EQ(kex_send_kexinit(client), 0);
299323134Sdes	run_kex(client, server2);
300313010Sdes	TEST_DONE();
301313010Sdes
302313010Sdes	TEST_START("cleanup");
303313010Sdes	sshkey_free(pubkey);
304313010Sdes	ssh_free(client);
305313010Sdes	ssh_free(server);
306313010Sdes	ssh_free(server2);
307313010Sdes	free(keyname);
308313010Sdes	TEST_DONE();
309313010Sdes}
310313010Sdes
/*
 * Write the command-line synopsis, the option summary and the list of
 * available KEX algorithm names to stderr.
 */
static void
usage(void)
{
	/* Fixed part of the help text. */
	fputs(
	    "Usage: kexfuzz [-hcdrv] [-D direction] [-f data_file]\n"
	    "               [-K kex_alg] [-k private_key] [-i packet_index]\n"
	    "\n"
	    "Options:\n"
	    "    -h               Display this help\n"
	    "    -c               Count packets sent during KEX\n"
	    "    -d               Dump mode: record KEX packet to data file\n"
	    "    -r               Replace mode: replace packet with data file\n"
	    "    -v               Turn on verbose logging\n"
	    "    -D S2C|C2S       Packet direction for replacement or dump\n"
	    "    -f data_file     Path to data file for replacement or dump\n"
	    "    -K kex_alg       Name of KEX algorithm to test (see below)\n"
	    "    -k private_key   Path to private key file\n"
	    "    -i packet_index  Index of packet to replace or dump (from 0)\n"
	    "\n",
	    stderr);

	/* Variable part: algorithm names separated by spaces. */
	fprintf(stderr, "Available KEX algorithms: %s\n", kex_alg_list(' '));
}
332313010Sdes
/*
 * Report a command-line error: print a fixed banner followed by the
 * caller's explanation `bad`, show the usage text, then terminate the
 * process with exit status 1.  Does not return.
 */
static void
badusage(const char *bad)
{
	fprintf(stderr, "Invalid options\n%s\n", bad);
	usage();
	exit(1);
}
341313010Sdes
342313010Sdesint
343313010Sdesmain(int argc, char **argv)
344313010Sdes{
345313010Sdes	int ch, fd, r;
346313010Sdes	int count_flag = 0, dump_flag = 0, replace_flag = 0;
347313010Sdes	int packet_index = -1, direction = -1;
348313010Sdes	int s2c = 0, c2s = 0; /* packet counts */
349313010Sdes	const char *kex = NULL, *kpath = NULL, *data_path = NULL;
350313010Sdes	struct sshkey *key = NULL;
351313010Sdes	struct sshbuf *replace_data = NULL;
352313010Sdes
353313010Sdes	setvbuf(stdout, NULL, _IONBF, 0);
354313010Sdes	while ((ch = getopt(argc, argv, "hcdrvD:f:K:k:i:")) != -1) {
355313010Sdes		switch (ch) {
356313010Sdes		case 'h':
357313010Sdes			usage();
358313010Sdes			return 0;
359313010Sdes		case 'c':
360313010Sdes			count_flag = 1;
361313010Sdes			break;
362313010Sdes		case 'd':
363313010Sdes			dump_flag = 1;
364313010Sdes			break;
365313010Sdes		case 'r':
366313010Sdes			replace_flag = 1;
367313010Sdes			break;
368313010Sdes		case 'v':
369313010Sdes			do_debug = 1;
370313010Sdes			break;
371313010Sdes
372313010Sdes		case 'D':
373313010Sdes			if (strcasecmp(optarg, "s2c") == 0)
374313010Sdes				direction = S2C;
375313010Sdes			else if (strcasecmp(optarg, "c2s") == 0)
376313010Sdes				direction = C2S;
377313010Sdes			else
378313010Sdes				badusage("Invalid direction (-D)");
379313010Sdes			break;
380313010Sdes		case 'f':
381313010Sdes			data_path = optarg;
382313010Sdes			break;
383313010Sdes		case 'K':
384313010Sdes			kex = optarg;
385313010Sdes			break;
386313010Sdes		case 'k':
387313010Sdes			kpath = optarg;
388313010Sdes			break;
389313010Sdes		case 'i':
390313010Sdes			packet_index = atoi(optarg);
391313010Sdes			if (packet_index < 0)
392313010Sdes				badusage("Invalid packet index");
393313010Sdes			break;
394313010Sdes		default:
395313010Sdes			badusage("unsupported flag");
396313010Sdes		}
397313010Sdes	}
398313010Sdes	argc -= optind;
399313010Sdes	argv += optind;
400313010Sdes
401323134Sdes	log_init(argv[0], do_debug ? SYSLOG_LEVEL_DEBUG3 : SYSLOG_LEVEL_INFO,
402323134Sdes	    SYSLOG_FACILITY_USER, 1);
403323134Sdes
404313010Sdes	/* Must select a single mode */
405313010Sdes	if ((count_flag + dump_flag + replace_flag) != 1)
406313010Sdes		badusage("Must select one mode: -c, -d or -r");
407313010Sdes	/* KEX type is mandatory */
408313010Sdes	if (kex == NULL || !kex_names_valid(kex) || strchr(kex, ',') != NULL)
409313010Sdes		badusage("Missing or invalid kex type (-K flag)");
410313010Sdes	/* Valid key is mandatory */
411313010Sdes	if (kpath == NULL)
412313010Sdes		badusage("Missing private key (-k flag)");
413313010Sdes	if ((fd = open(kpath, O_RDONLY)) == -1)
414313010Sdes		err(1, "open %s", kpath);
415313010Sdes	if ((r = sshkey_load_private_type_fd(fd, KEY_UNSPEC, NULL,
416313010Sdes	    &key, NULL)) != 0)
417313010Sdes		errx(1, "Unable to load key %s: %s", kpath, ssh_err(r));
418313010Sdes	close(fd);
419313010Sdes	/* XXX check that it is a private key */
420313010Sdes	/* XXX support certificates */
421313010Sdes	if (key == NULL || key->type == KEY_UNSPEC || key->type == KEY_RSA1)
422313010Sdes		badusage("Invalid key file (-k flag)");
423313010Sdes
424313010Sdes	/* Replace (fuzz) mode */
425313010Sdes	if (replace_flag) {
426313010Sdes		if (packet_index == -1 || direction == -1 || data_path == NULL)
427313010Sdes			badusage("Replace (-r) mode must specify direction "
428313010Sdes			    "(-D) packet index (-i) and data path (-f)");
429313010Sdes		if ((fd = open(data_path, O_RDONLY)) == -1)
430313010Sdes			err(1, "open %s", data_path);
431313010Sdes		replace_data = sshbuf_new();
432313010Sdes		if ((r = sshkey_load_file(fd, replace_data)) != 0)
433313010Sdes			errx(1, "read %s: %s", data_path, ssh_err(r));
434313010Sdes		close(fd);
435313010Sdes	}
436313010Sdes
437313010Sdes	/* Dump mode */
438313010Sdes	if (dump_flag) {
439313010Sdes		if (packet_index == -1 || direction == -1 || data_path == NULL)
440313010Sdes			badusage("Dump (-d) mode must specify direction "
441313010Sdes			    "(-D), packet index (-i) and data path (-f)");
442313010Sdes	}
443313010Sdes
444313010Sdes	/* Count mode needs no further flags */
445313010Sdes
446313010Sdes	do_kex_with_key(kex, key, &c2s, &s2c,
447313010Sdes	    direction, packet_index,
448313010Sdes	    dump_flag ? data_path : NULL,
449313010Sdes	    replace_flag ? replace_data : NULL);
450313010Sdes	sshkey_free(key);
451313010Sdes	sshbuf_free(replace_data);
452313010Sdes
453313010Sdes	if (count_flag) {
454313010Sdes		printf("S2C: %d\n", s2c);
455313010Sdes		printf("C2S: %d\n", c2s);
456313010Sdes	}
457313010Sdes
458313010Sdes	return 0;
459313010Sdes}
460