1/*	$NetBSD: kex.c,v 1.34 2023/12/20 17:15:20 christos Exp $	*/
2/* $OpenBSD: kex.c,v 1.184 2023/12/18 14:45:49 djm Exp $ */
3
4/*
5 * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "includes.h"
29__RCSID("$NetBSD: kex.c,v 1.34 2023/12/20 17:15:20 christos Exp $");
30
31#include <sys/param.h>	/* MAX roundup */
32#include <sys/types.h>
33#include <errno.h>
34#include <signal.h>
35#include <stdio.h>
36#include <stdlib.h>
37#include <string.h>
38#include <unistd.h>
39#include <poll.h>
40
41#ifdef WITH_OPENSSL
42#include <openssl/crypto.h>
43#include <openssl/dh.h>
44#endif
45
46#include "ssh.h"
47#include "ssh2.h"
48#include "atomicio.h"
49#include "version.h"
50#include "packet.h"
51#include "compat.h"
52#include "cipher.h"
53#include "sshkey.h"
54#include "kex.h"
55#include "log.h"
56#include "mac.h"
57#include "match.h"
58#include "misc.h"
59#include "dispatch.h"
60#include "packet.h"
61#include "monitor.h"
62#include "myproposal.h"
63
64#include "ssherr.h"
65#include "sshbuf.h"
66#include "digest.h"
67#include "xmalloc.h"
68
/* Prototypes for file-local functions that are defined further below. */
static int kex_choose_conf(struct ssh *, uint32_t seq);
static int kex_input_newkeys(int, u_int32_t, struct ssh *);
72
73static const char * const proposal_names[PROPOSAL_MAX] = {
74	"KEX algorithms",
75	"host key algorithms",
76	"ciphers ctos",
77	"ciphers stoc",
78	"MACs ctos",
79	"MACs stoc",
80	"compression ctos",
81	"compression stoc",
82	"languages ctos",
83	"languages stoc",
84};
85
/* Description of one key exchange method known to this implementation. */
struct kexalg {
	const char *name;	/* method name as it appears in KEXINIT lists */
	u_int type;		/* KEX_* implementation type */
	int ec_nid;		/* curve NID for ECDH methods, 0 for the others */
	int hash_alg;		/* SSH_DIGEST_* hash used by the method */
};
/*
 * Table of supported KEX methods, terminated by an entry with a NULL name.
 * Diffie-Hellman and ECDH methods are only compiled in WITH_OPENSSL.
 * KEX_CURVE25519_SHA256 and KEX_CURVE25519_SHA256_OLD are two names for
 * the same implementation type (KEX_C25519_SHA256).
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
#endif
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
	{ NULL, 0, -1, -1},	/* terminator */
};
114
115char *
116kex_alg_list(char sep)
117{
118	char *ret = NULL, *tmp;
119	size_t nlen, rlen = 0;
120	const struct kexalg *k;
121
122	for (k = kexalgs; k->name != NULL; k++) {
123		if (ret != NULL)
124			ret[rlen++] = sep;
125		nlen = strlen(k->name);
126		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
127			free(ret);
128			return NULL;
129		}
130		ret = tmp;
131		memcpy(ret + rlen, k->name, nlen + 1);
132		rlen += nlen;
133	}
134	return ret;
135}
136
137static const struct kexalg *
138kex_alg_by_name(const char *name)
139{
140	const struct kexalg *k;
141
142	for (k = kexalgs; k->name != NULL; k++) {
143		if (strcmp(k->name, name) == 0)
144			return k;
145	}
146	return NULL;
147}
148
/*
 * Validate a comma-separated list of KEX method names. Returns 1 if every
 * name up to the first empty element is supported, and 0 otherwise. A NULL
 * or empty list, or an allocation failure, also returns 0.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *cursor, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	cursor = copy;
	/* NB. scanning stops at the first empty element */
	while ((name = strsep(&cursor, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
171
/* Return 1 if match_list() finds an algorithm common to both lists, else 0. */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
183
/*
 * Concatenate comma-separated algorithm lists 'a' and 'b', appending only
 * the names from 'b' that are not already present. Either argument may be
 * NULL or empty. Processing of 'b' stops at its first empty element.
 * Returns a newly allocated string that the caller must free. Returns NULL
 * on allocation failure, or when 'a' is non-empty and 'b' exceeds 1MB.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	/* Nothing to merge into: copy 'b' (or "" if it is NULL as well) */
	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of 'a', a comma, all of 'b', and the NUL */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
220
221/*
222 * Assemble a list of algorithms from a default list and a string from a
223 * configuration file. The user-provided string may begin with '+' to
224 * indicate that it should be appended to the default, '-' that the
225 * specified names should be removed, or '^' that they should be placed
226 * at the head.
227 */
228int
229kex_assemble_names(char **listp, const char *def, const char *all)
230{
231	char *cp, *tmp, *patterns;
232	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
233	int r = SSH_ERR_INTERNAL_ERROR;
234
235	if (listp == NULL || def == NULL || all == NULL)
236		return SSH_ERR_INVALID_ARGUMENT;
237
238	if (*listp == NULL || **listp == '\0') {
239		if ((*listp = strdup(def)) == NULL)
240			return SSH_ERR_ALLOC_FAIL;
241		return 0;
242	}
243
244	list = *listp;
245	*listp = NULL;
246	if (*list == '+') {
247		/* Append names to default list */
248		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
249			r = SSH_ERR_ALLOC_FAIL;
250			goto fail;
251		}
252		free(list);
253		list = tmp;
254	} else if (*list == '-') {
255		/* Remove names from default list */
256		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
257			r = SSH_ERR_ALLOC_FAIL;
258			goto fail;
259		}
260		free(list);
261		/* filtering has already been done */
262		return 0;
263	} else if (*list == '^') {
264		/* Place names at head of default list */
265		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
266			r = SSH_ERR_ALLOC_FAIL;
267			goto fail;
268		}
269		free(list);
270		list = tmp;
271	} else {
272		/* Explicit list, overrides default - just use "list" as is */
273	}
274
275	/*
276	 * The supplied names may be a pattern-list. For the -list case,
277	 * the patterns are applied above. For the +list and explicit list
278	 * cases we need to do it now.
279	 */
280	ret = NULL;
281	if ((patterns = opatterns = strdup(list)) == NULL) {
282		r = SSH_ERR_ALLOC_FAIL;
283		goto fail;
284	}
285	/* Apply positive (i.e. non-negated) patterns from the list */
286	while ((cp = strsep(&patterns, ",")) != NULL) {
287		if (*cp == '!') {
288			/* negated matches are not supported here */
289			r = SSH_ERR_INVALID_ARGUMENT;
290			goto fail;
291		}
292		free(matching);
293		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
294			r = SSH_ERR_ALLOC_FAIL;
295			goto fail;
296		}
297		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
298			r = SSH_ERR_ALLOC_FAIL;
299			goto fail;
300		}
301		free(ret);
302		ret = tmp;
303	}
304	if (ret == NULL || *ret == '\0') {
305		/* An empty name-list is an error */
306		/* XXX better error code? */
307		r = SSH_ERR_INVALID_ARGUMENT;
308		goto fail;
309	}
310
311	/* success */
312	*listp = ret;
313	ret = NULL;
314	r = 0;
315
316 fail:
317	free(matching);
318	free(opatterns);
319	free(list);
320	free(ret);
321	return r;
322}
323
324/*
325 * Fill out a proposal array with dynamically allocated values, which may
326 * be modified as required for compatibility reasons.
327 * Any of the options may be NULL, in which case the default is used.
328 * Array contents must be freed by calling kex_proposal_free_entries.
329 */
330void
331kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
332    const char *kexalgos, const char *ciphers, const char *macs,
333    const char *comp, const char *hkalgs)
334{
335	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
336	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
337	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
338	u_int i;
339	char *cp;
340
341	if (prop == NULL)
342		fatal_f("proposal missing");
343
344	/* Append EXT_INFO signalling to KexAlgorithms */
345	if (kexalgos == NULL)
346		kexalgos = defprop[PROPOSAL_KEX_ALGS];
347	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
348	    "ext-info-s,kex-strict-s-v00@openssh.com" :
349	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
350		fatal_f("kex_names_cat");
351
352	for (i = 0; i < PROPOSAL_MAX; i++) {
353		switch(i) {
354		case PROPOSAL_KEX_ALGS:
355			prop[i] = compat_kex_proposal(ssh, cp);
356			break;
357		case PROPOSAL_ENC_ALGS_CTOS:
358		case PROPOSAL_ENC_ALGS_STOC:
359			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
360			break;
361		case PROPOSAL_MAC_ALGS_CTOS:
362		case PROPOSAL_MAC_ALGS_STOC:
363			prop[i]  = xstrdup(macs ? macs : defprop[i]);
364			break;
365		case PROPOSAL_COMP_ALGS_CTOS:
366		case PROPOSAL_COMP_ALGS_STOC:
367			prop[i] = xstrdup(comp ? comp : defprop[i]);
368			break;
369		case PROPOSAL_SERVER_HOST_KEY_ALGS:
370			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
371			break;
372		default:
373			prop[i] = xstrdup(defprop[i]);
374		}
375	}
376	free(cp);
377}
378
379void
380kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
381{
382	u_int i;
383
384	for (i = 0; i < PROPOSAL_MAX; i++)
385		free(prop[i]);
386}
387
388/* put algorithm proposal into buffer */
389int
390kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
391{
392	u_int i;
393	int r;
394
395	sshbuf_reset(b);
396
397	/*
398	 * add a dummy cookie, the cookie will be overwritten by
399	 * kex_send_kexinit(), each time a kexinit is set
400	 */
401	for (i = 0; i < KEX_COOKIE_LEN; i++) {
402		if ((r = sshbuf_put_u8(b, 0)) != 0)
403			return r;
404	}
405	for (i = 0; i < PROPOSAL_MAX; i++) {
406		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
407			return r;
408	}
409	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
410	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
411		return r;
412	return 0;
413}
414
/*
 * Parse a KEXINIT body from 'raw' into a newly allocated array of
 * PROPOSAL_MAX strings, stored in *propp. Free the array with
 * kex_prop_free(). The body is the cookie, the name-lists, the
 * first_kex_follows byte and a reserved uint32. If 'first_kex_follows' is
 * non-NULL it receives that byte. Parsing reads through a buffer from
 * sshbuf_fromb(); presumably 'raw' itself is not consumed, since
 * kex_choose_conf() parses kex->my and kex->peer this way. Returns 0, or
 * an ssherr code with *propp left NULL.
 */
int
kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
{
	struct sshbuf *b = NULL;
	u_char v;
	u_int i;
	char **proposal = NULL;
	int r;

	*propp = NULL;
	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if ((b = sshbuf_fromb(raw)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}
	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
		error_fr(r, "consume cookie");
		goto out;
	}
	/* extract kex init proposal strings */
	for (i = 0; i < PROPOSAL_MAX; i++) {
		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
			error_fr(r, "parse proposal %u", i);
			goto out;
		}
		debug2("%s: %s", proposal_names[i], proposal[i]);
	}
	/* first kex follows / reserved (NB. 'i' is reused for the latter) */
	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
		error_fr(r, "parse");
		goto out;
	}
	if (first_kex_follows != NULL)
		*first_kex_follows = v;
	debug2("first_kex_follows %d ", v);
	debug2("reserved %u ", i);
	r = 0;
	*propp = proposal;
 out:
	/* On error release any strings parsed so far */
	if (r != 0 && proposal != NULL)
		kex_prop_free(proposal);
	sshbuf_free(b);
	return r;
}
462
463void
464kex_prop_free(char **proposal)
465{
466	u_int i;
467
468	if (proposal == NULL)
469		return;
470	for (i = 0; i < PROPOSAL_MAX; i++)
471		free(proposal[i]);
472	free(proposal);
473}
474
475int
476kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
477{
478	int r;
479
480	/* If in strict mode, any unexpected message is an error */
481	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
482		ssh_packet_disconnect(ssh, "strict KEX violation: "
483		    "unexpected packet type %u (seqnr %u)", type, seq);
484	}
485	error_f("type %u seq %u", type, seq);
486	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
487	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
488	    (r = sshpkt_send(ssh)) != 0)
489		return r;
490	return 0;
491}
492
493static void
494kex_reset_dispatch(struct ssh *ssh)
495{
496	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
497	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
498}
499
/*
 * Replace kex->server_sig_algs with the signature algorithms permitted by
 * the comma-separated 'allowed_algs'. Each entry is mapped through
 * sshkey_sigalg_by_name(). Entries with no mapping, or whose signature
 * algorithm is not in sshkey_alg_list(0, 1, 1, ','), are skipped, and each
 * signature algorithm is added only once. Scanning stops at the first empty
 * element. If nothing matched, the result is "" (never NULL). A failure in
 * sshkey_alg_list() is fatal.
 */
void
kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
{
	char *alg, *oalgs, *algs, *sigalgs;
	const char *sigalg;

	/*
	 * NB. allowed algorithms may contain certificate algorithms that
	 * map to a specific plain signature type, e.g.
	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
	 * We need to be careful here to match these, retain the mapping
	 * and only add each signature algorithm once.
	 */
	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
		fatal_f("sshkey_alg_list failed");
	oalgs = algs = xstrdup(allowed_algs);
	free(ssh->kex->server_sig_algs);
	ssh->kex->server_sig_algs = NULL;
	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
	    (alg = strsep(&algs, ","))) {
		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
			continue;
		if (!has_any_alg(sigalg, sigalgs))
			continue;
		/* Don't add an algorithm twice. */
		if (ssh->kex->server_sig_algs != NULL &&
		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
			continue;
		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
	}
	free(oalgs);
	free(sigalgs);
	/* Callers always get a string, even if nothing was allowed */
	if (ssh->kex->server_sig_algs == NULL)
		ssh->kex->server_sig_algs = xstrdup("");
}
535
536static int
537kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
538{
539	int r;
540
541	if (ssh->kex->server_sig_algs == NULL &&
542	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
543		return SSH_ERR_ALLOC_FAIL;
544	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
545	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
546	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
547	    (r = sshbuf_put_cstring(m,
548	    "publickey-hostbound@openssh.com")) != 0 ||
549	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
550	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
551	    (r = sshbuf_put_cstring(m, "0")) != 0) {
552		error_fr(r, "compose");
553		return r;
554	}
555	return 0;
556}
557
/*
 * Append the client's only EXT_INFO extension,
 * ext-info-in-auth@openssh.com (version "0"), to 'm'. 'ssh' is unused.
 * Returns 0 or an ssherr code.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	r = sshbuf_put_u32(m, 1);
	if (r == 0)
		r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com");
	if (r == 0)
		r = sshbuf_put_cstring(m, "0");
	if (r != 0)
		error_fr(r, "compose");
	return r;
}
574
575static int
576kex_maybe_send_ext_info(struct ssh *ssh)
577{
578	int r;
579	struct sshbuf *m = NULL;
580
581	if ((ssh->kex->flags & KEX_INITIAL) == 0)
582		return 0;
583	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
584		return 0;
585
586	/* Compose EXT_INFO packet. */
587	if ((m = sshbuf_new()) == NULL)
588		fatal_f("sshbuf_new failed");
589	if (ssh->kex->ext_info_c &&
590	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
591		goto fail;
592	if (ssh->kex->ext_info_s &&
593	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
594		goto fail;
595
596	/* Send the actual KEX_INFO packet */
597	debug("Sending SSH2_MSG_EXT_INFO");
598	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
599	    (r = sshpkt_putb(ssh, m)) != 0 ||
600	    (r = sshpkt_send(ssh)) != 0) {
601		error_f("send EXT_INFO");
602		goto fail;
603	}
604
605	r = 0;
606
607 fail:
608	sshbuf_free(m);
609	return r;
610}
611
612int
613kex_server_update_ext_info(struct ssh *ssh)
614{
615	int r;
616
617	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
618		return 0;
619
620	debug_f("Sending SSH2_MSG_EXT_INFO");
621	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
622	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
623	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
624	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
625	    (r = sshpkt_send(ssh)) != 0) {
626		error_f("send EXT_INFO");
627		return r;
628	}
629	return 0;
630}
631
/*
 * Send SSH2_MSG_NEWKEYS and prepare to receive the peer's NEWKEYS.
 * Transport handlers are reset to kex_protocol_error() before sending, and
 * the NEWKEYS handler is installed only after our NEWKEYS has been sent.
 * During the initial exchange this may also send SSH2_MSG_EXT_INFO.
 * Returns 0 or an ssherr code.
 */
int
kex_send_newkeys(struct ssh *ssh)
{
	int r;

	kex_reset_dispatch(ssh);
	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
	    (r = sshpkt_send(ssh)) != 0)
		return r;
	debug("SSH2_MSG_NEWKEYS sent");
	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
	/* EXT_INFO, if any, is sent right after our NEWKEYS */
	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
		return r;
	debug("expecting SSH2_MSG_NEWKEYS");
	return 0;
}
648
649/* Check whether an ext_info value contains the expected version string */
650static int
651kex_ext_info_check_ver(struct kex *kex, const char *name,
652    const u_char *val, size_t len, const char *want_ver, u_int flag)
653{
654	if (memchr(val, '\0', len) != NULL) {
655		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
656		return SSH_ERR_INVALID_FORMAT;
657	}
658	debug_f("%s=<%s>", name, val);
659	if (strcmp((const char *)val, want_ver) == 0)
660		kex->flags |= flag;
661	else
662		debug_f("unsupported version of %s extension", name);
663	return 0;
664}
665
666static int
667kex_ext_info_client_parse(struct ssh *ssh, const char *name,
668    const u_char *value, size_t vlen)
669{
670	int r;
671
672	/* NB. some messages are only accepted in the initial EXT_INFO */
673	if (strcmp(name, "server-sig-algs") == 0) {
674		/* Ensure no \0 lurking in value */
675		if (memchr(value, '\0', vlen) != NULL) {
676			error_f("nul byte in %s", name);
677			return SSH_ERR_INVALID_FORMAT;
678		}
679		debug_f("%s=<%s>", name, value);
680		free(ssh->kex->server_sig_algs);
681		ssh->kex->server_sig_algs = xstrdup((const char *)value);
682	} else if (ssh->kex->ext_info_received == 1 &&
683	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
684		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
685		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
686			return r;
687		}
688	} else if (ssh->kex->ext_info_received == 1 &&
689	    strcmp(name, "ping@openssh.com") == 0) {
690		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
691		    "0", KEX_HAS_PING)) != 0) {
692			return r;
693		}
694	} else
695		debug_f("%s (unrecognised)", name);
696
697	return 0;
698}
699
700static int
701kex_ext_info_server_parse(struct ssh *ssh, const char *name,
702    const u_char *value, size_t vlen)
703{
704	int r;
705
706	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
707		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
708		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
709			return r;
710		}
711	} else
712		debug_f("%s (unrecognised)", name);
713	return 0;
714}
715
716int
717kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
718{
719	struct kex *kex = ssh->kex;
720	const int max_ext_info = kex->server ? 1 : 2;
721	u_int32_t i, ninfo;
722	char *name;
723	u_char *val;
724	size_t vlen;
725	int r;
726
727	debug("SSH2_MSG_EXT_INFO received");
728	if (++kex->ext_info_received > max_ext_info) {
729		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
730		return dispatch_protocol_error(type, seq, ssh);
731	}
732	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
733	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
734		return r;
735	if (ninfo >= 1024) {
736		error("SSH2_MSG_EXT_INFO with too many entries, expected "
737		    "<=1024, received %u", ninfo);
738		return dispatch_protocol_error(type, seq, ssh);
739	}
740	for (i = 0; i < ninfo; i++) {
741		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
742			return r;
743		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
744			free(name);
745			return r;
746		}
747		debug3_f("extension %s", name);
748		if (kex->server) {
749			if ((r = kex_ext_info_server_parse(ssh, name,
750			    val, vlen)) != 0)
751				return r;
752		} else {
753			if ((r = kex_ext_info_client_parse(ssh, name,
754			    val, vlen)) != 0)
755				return r;
756		}
757		free(name);
758		free(val);
759	}
760	return sshpkt_get_end(ssh);
761}
762
/*
 * Handle the peer's SSH2_MSG_NEWKEYS. Incoming traffic switches to the new
 * keys (ssh_set_newkeys(MODE_IN)), the exchange is marked done, and
 * per-exchange state is reset so that a later KEXINIT can start a rekey.
 * Returns 0 or an ssherr code.
 */
static int
kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	int r;

	debug("SSH2_MSG_NEWKEYS received");
	/*
	 * If the peer (a client) signalled ext-info-c, accept its EXT_INFO
	 * after the initial NEWKEYS.
	 */
	if (kex->ext_info_c && (kex->flags & KEX_INITIAL) != 0)
		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
	/* No more NEWKEYS expected; a new KEXINIT may start a rekey */
	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
	if ((r = sshpkt_get_end(ssh)) != 0)
		return r;
	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
		return r;
	kex->done = 1;
	kex->flags &= ~KEX_INITIAL;
	/* Drop the peer's KEXINIT; ours in kex->my is resent on rekey */
	sshbuf_reset(kex->peer);
	/* sshbuf_reset(kex->my); */
	kex->flags &= ~KEX_INIT_SENT;
	free(kex->name);
	kex->name = NULL;
	return 0;
}
787
/*
 * Send our SSH2_MSG_KEXINIT, the proposal pre-serialised in kex->my, after
 * writing a fresh random cookie over its first KEX_COOKIE_LEN bytes.
 * Does nothing if a KEXINIT was already sent for this exchange. Otherwise
 * it clears kex->done, even if sending then fails. Returns 0 or an ssherr
 * code.
 */
int
kex_send_kexinit(struct ssh *ssh)
{
	u_char *cookie;
	struct kex *kex = ssh->kex;
	int r;

	if (kex == NULL) {
		error_f("no kex");
		return SSH_ERR_INTERNAL_ERROR;
	}
	if (kex->flags & KEX_INIT_SENT)
		return 0;
	kex->done = 0;

	/* generate a random cookie */
	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
		error_f("bad kex length: %zu < %d",
		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
		return SSH_ERR_INVALID_FORMAT;
	}
	/* The cookie is the leading bytes of the serialised proposal */
	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
		error_f("buffer error");
		return SSH_ERR_INTERNAL_ERROR;
	}
	arc4random_buf(cookie, KEX_COOKIE_LEN);

	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
	    (r = sshpkt_send(ssh)) != 0) {
		error_fr(r, "compose reply");
		return r;
	}
	debug("SSH2_MSG_KEXINIT sent");
	kex->flags |= KEX_INIT_SENT;
	return 0;
}
825
/*
 * Handle the peer's SSH2_MSG_KEXINIT. The raw body is saved in kex->peer
 * and its structure is checked. Our own KEXINIT is sent if it has not been
 * sent yet. Algorithms are then negotiated (kex_choose_conf()) and the
 * handler for the selected KEX method runs. Further KEXINITs are routed to
 * kex_protocol_error() until kex_input_newkeys() re-arms this handler.
 * Returns 0 or an ssherr code.
 */
int
kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	const u_char *ptr;
	u_int i;
	size_t dlen;
	int r;

	debug("SSH2_MSG_KEXINIT received");
	if (kex == NULL) {
		error_f("no kex");
		return SSH_ERR_INTERNAL_ERROR;
	}
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
	/* Keep a verbatim copy of the body; kex_choose_conf() parses it */
	ptr = sshpkt_ptr(ssh, &dlen);
	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
		return r;

	/* discard packet */
	for (i = 0; i < KEX_COOKIE_LEN; i++) {
		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
			error_fr(r, "discard cookie");
			return r;
		}
	}
	for (i = 0; i < PROPOSAL_MAX; i++) {
		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
			error_fr(r, "discard proposal");
			return r;
		}
	}
	/*
	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
	 * KEX method has the server move first, but a server might be using
	 * a custom method or one that we otherwise don't support. We should
	 * be prepared to remember first_kex_follows here so we can eat a
	 * packet later.
	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
	 * for cases where the server *doesn't* go first. I guess we should
	 * ignore it when it is set for these cases, which is what we do now.
	 */
	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
	    (r = sshpkt_get_end(ssh)) != 0)
			return r;

	/* Peer started the exchange: answer with our KEXINIT first */
	if (!(kex->flags & KEX_INIT_SENT))
		if ((r = kex_send_kexinit(ssh)) != 0)
			return r;
	if ((r = kex_choose_conf(ssh, seq)) != 0)
		return r;

	/* Run the handler registered for the negotiated method */
	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
		return (kex->kex[kex->kex_type])(ssh);

	error_f("unknown kex type %u", kex->kex_type);
	return SSH_ERR_INTERNAL_ERROR;
}
885
886struct kex *
887kex_new(void)
888{
889	struct kex *kex;
890
891	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
892	    (kex->peer = sshbuf_new()) == NULL ||
893	    (kex->my = sshbuf_new()) == NULL ||
894	    (kex->client_version = sshbuf_new()) == NULL ||
895	    (kex->server_version = sshbuf_new()) == NULL ||
896	    (kex->session_id = sshbuf_new()) == NULL) {
897		kex_free(kex);
898		return NULL;
899	}
900	return kex;
901}
902
/*
 * Dispose of a newkeys structure. Cipher key/IV and MAC key bytes are
 * zeroed with explicit_bzero() before being freed, each sub-structure is
 * wiped, and the structure itself is released with freezero(). NULL is
 * accepted.
 */
void
kex_free_newkeys(struct newkeys *newkeys)
{
	if (newkeys == NULL)
		return;
	if (newkeys->enc.key) {
		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
		free(newkeys->enc.key);
		newkeys->enc.key = NULL;
	}
	if (newkeys->enc.iv) {
		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
		free(newkeys->enc.iv);
		newkeys->enc.iv = NULL;
	}
	free(newkeys->enc.name);
	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
	free(newkeys->comp.name);
	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
	/* presumably releases MAC implementation state; must precede the wipe */
	mac_clear(&newkeys->mac);
	if (newkeys->mac.key) {
		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
		free(newkeys->mac.key);
		newkeys->mac.key = NULL;
	}
	free(newkeys->mac.name);
	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
	freezero(newkeys, sizeof(*newkeys));
}
932
933void
934kex_free(struct kex *kex)
935{
936	u_int mode;
937
938	if (kex == NULL)
939		return;
940
941#ifdef WITH_OPENSSL
942	DH_free(kex->dh);
943	EC_KEY_free(kex->ec_client_key);
944#endif
945	for (mode = 0; mode < MODE_MAX; mode++) {
946		kex_free_newkeys(kex->newkeys[mode]);
947		kex->newkeys[mode] = NULL;
948	}
949	sshbuf_free(kex->peer);
950	sshbuf_free(kex->my);
951	sshbuf_free(kex->client_version);
952	sshbuf_free(kex->server_version);
953	sshbuf_free(kex->client_pub);
954	sshbuf_free(kex->session_id);
955	sshbuf_free(kex->initial_sig);
956	sshkey_free(kex->initial_hostkey);
957	free(kex->failed_choice);
958	free(kex->hostkey_alg);
959	free(kex->name);
960	free(kex);
961}
962
963int
964kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
965{
966	int r;
967
968	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
969		return r;
970	ssh->kex->flags = KEX_INITIAL;
971	kex_reset_dispatch(ssh);
972	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
973	return 0;
974}
975
976int
977kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
978{
979	int r;
980
981	if ((r = kex_ready(ssh, proposal)) != 0)
982		return r;
983	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
984		kex_free(ssh->kex);
985		ssh->kex = NULL;
986		return r;
987	}
988	return 0;
989}
990
991/*
992 * Request key re-exchange, returns 0 on success or a ssherr.h error
993 * code otherwise. Must not be called if KEX is incomplete or in-progress.
994 */
995int
996kex_start_rekex(struct ssh *ssh)
997{
998	if (ssh->kex == NULL) {
999		error_f("no kex");
1000		return SSH_ERR_INTERNAL_ERROR;
1001	}
1002	if (ssh->kex->done == 0) {
1003		error_f("requested twice");
1004		return SSH_ERR_INTERNAL_ERROR;
1005	}
1006	ssh->kex->done = 0;
1007	return kex_send_kexinit(ssh);
1008}
1009
1010static int
1011choose_enc(struct sshenc *enc, char *client, char *server)
1012{
1013	char *name = match_list(client, server, NULL);
1014
1015	if (name == NULL)
1016		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1017	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1018		error_f("unsupported cipher %s", name);
1019		free(name);
1020		return SSH_ERR_INTERNAL_ERROR;
1021	}
1022	enc->name = name;
1023	enc->enabled = 0;
1024	enc->iv = NULL;
1025	enc->iv_len = cipher_ivlen(enc->cipher);
1026	enc->key = NULL;
1027	enc->key_len = cipher_keylen(enc->cipher);
1028	enc->block_size = cipher_blocksize(enc->cipher);
1029	return 0;
1030}
1031
1032static int
1033choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1034{
1035	char *name = match_list(client, server, NULL);
1036
1037	if (name == NULL)
1038		return SSH_ERR_NO_MAC_ALG_MATCH;
1039	if (mac_setup(mac, name) < 0) {
1040		error_f("unsupported MAC %s", name);
1041		free(name);
1042		return SSH_ERR_INTERNAL_ERROR;
1043	}
1044	mac->name = name;
1045	mac->key = NULL;
1046	mac->enabled = 0;
1047	return 0;
1048}
1049
/*
 * Pick a compression method from the client and server lists with
 * match_list() and set comp->type. "zlib@openssh.com" (COMP_DELAYED) and
 * "zlib" (COMP_ZLIB) are recognised only when built WITH_ZLIB; "none" is
 * always accepted. On success comp->name owns the allocated name.
 */
static int
choose_comp(struct sshcomp *comp, char *client, char *server)
{
	char *name = match_list(client, server, NULL);

	if (name == NULL)
		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
	/* NB. with WITH_ZLIB the trailing 'else' chains into the "none" test */
#ifdef WITH_ZLIB
	if (strcmp(name, "zlib@openssh.com") == 0) {
		comp->type = COMP_DELAYED;
	} else if (strcmp(name, "zlib") == 0) {
		comp->type = COMP_ZLIB;
	} else
#endif	/* WITH_ZLIB */
	if (strcmp(name, "none") == 0) {
		comp->type = COMP_NONE;
	} else {
		error_f("unsupported compression scheme %s", name);
		free(name);
		return SSH_ERR_INTERNAL_ERROR;
	}
	comp->name = name;
	return 0;
}
1074
1075static int
1076choose_kex(struct kex *k, char *client, char *server)
1077{
1078	const struct kexalg *kexalg;
1079
1080	k->name = match_list(client, server, NULL);
1081
1082	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1083	if (k->name == NULL)
1084		return SSH_ERR_NO_KEX_ALG_MATCH;
1085	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1086		error_f("unsupported KEX method %s", k->name);
1087		return SSH_ERR_INTERNAL_ERROR;
1088	}
1089	k->kex_type = kexalg->type;
1090	k->hash_alg = kexalg->hash_alg;
1091	k->ec_nid = kexalg->ec_nid;
1092	return 0;
1093}
1094
1095static int
1096choose_hostkeyalg(struct kex *k, char *client, char *server)
1097{
1098	free(k->hostkey_alg);
1099	k->hostkey_alg = match_list(client, server, NULL);
1100
1101	debug("kex: host key algorithm: %s",
1102	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1103	if (k->hostkey_alg == NULL)
1104		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1105	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1106	if (k->hostkey_type == KEY_UNSPEC) {
1107		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1108		return SSH_ERR_INTERNAL_ERROR;
1109	}
1110	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1111	return 0;
1112}
1113
1114static int
1115proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1116{
1117	static int check[] = {
1118		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1119	};
1120	int *idx;
1121	char *p;
1122
1123	for (idx = &check[0]; *idx != -1; idx++) {
1124		if ((p = strchr(my[*idx], ',')) != NULL)
1125			*p = '\0';
1126		if ((p = strchr(peer[*idx], ',')) != NULL)
1127			*p = '\0';
1128		if (strcmp(my[*idx], peer[*idx]) != 0) {
1129			debug2("proposal mismatch: my %s peer %s",
1130			    my[*idx], peer[*idx]);
1131			return (0);
1132		}
1133	}
1134	debug2("proposals match");
1135	return (1);
1136}
1137
1138static int
1139kexalgs_contains(char **peer, const char *ext)
1140{
1141	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1142}
1143
1144static int
1145kex_choose_conf(struct ssh *ssh, uint32_t seq)
1146{
1147	struct kex *kex = ssh->kex;
1148	struct newkeys *newkeys;
1149	char **my = NULL, **peer = NULL;
1150	char **cprop, **sprop;
1151	int nenc, nmac, ncomp;
1152	u_int mode, ctos, need, dh_need, authlen;
1153	int log_flag = 0;
1154	int r, first_kex_follows;
1155
1156	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1157	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1158		goto out;
1159	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1160	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1161		goto out;
1162
1163	if (kex->server) {
1164		cprop=peer;
1165		sprop=my;
1166	} else {
1167		cprop=my;
1168		sprop=peer;
1169	}
1170
1171	/* Check whether peer supports ext_info/kex_strict */
1172	if ((kex->flags & KEX_INITIAL) != 0) {
1173		if (kex->server) {
1174			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1175			kex->kex_strict = kexalgs_contains(peer,
1176			    "kex-strict-c-v00@openssh.com");
1177		} else {
1178			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
1179			kex->kex_strict = kexalgs_contains(peer,
1180			    "kex-strict-s-v00@openssh.com");
1181		}
1182		if (kex->kex_strict) {
1183			debug3_f("will use strict KEX ordering");
1184			if (seq != 0)
1185				ssh_packet_disconnect(ssh,
1186				    "strict KEX violation: "
1187				    "KEXINIT was not the first packet");
1188		}
1189	}
1190
1191	/* Check whether client supports rsa-sha2 algorithms */
1192	if (kex->server && (kex->flags & KEX_INITIAL)) {
1193		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1194		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1195			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1196		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1197		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1198			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1199	}
1200
1201	/* Algorithm Negotiation */
1202	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1203	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1204		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1205		peer[PROPOSAL_KEX_ALGS] = NULL;
1206		goto out;
1207	}
1208	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1209	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1210		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1211		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1212		goto out;
1213	}
1214	for (mode = 0; mode < MODE_MAX; mode++) {
1215		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1216			r = SSH_ERR_ALLOC_FAIL;
1217			goto out;
1218		}
1219		kex->newkeys[mode] = newkeys;
1220		ctos = (!kex->server && mode == MODE_OUT) ||
1221		    (kex->server && mode == MODE_IN);
1222		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1223		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1224		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1225		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1226		    sprop[nenc])) != 0) {
1227			kex->failed_choice = peer[nenc];
1228			peer[nenc] = NULL;
1229			goto out;
1230		}
1231		authlen = cipher_authlen(newkeys->enc.cipher);
1232		/* ignore mac for authenticated encryption */
1233		if (authlen == 0 &&
1234		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1235		    sprop[nmac])) != 0) {
1236			kex->failed_choice = peer[nmac];
1237			peer[nmac] = NULL;
1238			goto out;
1239		}
1240		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1241		    sprop[ncomp])) != 0) {
1242			kex->failed_choice = peer[ncomp];
1243			peer[ncomp] = NULL;
1244			goto out;
1245		}
1246		debug("REQUESTED ENC.NAME is '%s'", newkeys->enc.name);
1247		if (strcmp(newkeys->enc.name, "none") == 0) {
1248			int auth_flag;
1249
1250			auth_flag = ssh_packet_authentication_state(ssh);
1251			debug("Requesting NONE. Authflag is %d", auth_flag);
1252			if (auth_flag == 1) {
1253				debug("None requested post authentication.");
1254			} else {
1255				fatal("Pre-authentication none cipher requests are not allowed.");
1256			}
1257		}
1258		debug("kex: %s cipher: %s MAC: %s compression: %s",
1259		    ctos ? "client->server" : "server->client",
1260		    newkeys->enc.name,
1261		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1262		    newkeys->comp.name);
1263		/* client starts withctos = 0 && log flag = 0 and no log*/
1264		/* 2nd client pass ctos=1 and flag = 1 so no log*/
1265		/* server starts with ctos =1 && log_flag = 0 so log */
1266		/* 2nd sever pass ctos = 1 && log flag = 1 so no log*/
1267		/* -cjr*/
1268		if (ctos && !log_flag) {
1269			logit("SSH: Server;Ltype: Kex;Remote: %s-%d;Enc: %s;MAC: %s;Comp: %s",
1270			      ssh_remote_ipaddr(ssh),
1271			      ssh_remote_port(ssh),
1272			      newkeys->enc.name,
1273			      newkeys->mac.name,
1274			      newkeys->comp.name);
1275		}
1276		log_flag = 1;
1277	}
1278	need = dh_need = 0;
1279	for (mode = 0; mode < MODE_MAX; mode++) {
1280		newkeys = kex->newkeys[mode];
1281		need = MAXIMUM(need, newkeys->enc.key_len);
1282		need = MAXIMUM(need, newkeys->enc.block_size);
1283		need = MAXIMUM(need, newkeys->enc.iv_len);
1284		need = MAXIMUM(need, newkeys->mac.key_len);
1285		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1286		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1287		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1288		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1289	}
1290	/* XXX need runden? */
1291	kex->we_need = need;
1292	kex->dh_need = dh_need;
1293
1294	/* ignore the next message if the proposals do not match */
1295	if (first_kex_follows && !proposals_match(my, peer))
1296		ssh->dispatch_skip_packets = 1;
1297	r = 0;
1298 out:
1299	kex_prop_free(my);
1300	kex_prop_free(peer);
1301	return r;
1302}
1303
1304static int
1305derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1306    const struct sshbuf *shared_secret, u_char **keyp)
1307{
1308	struct kex *kex = ssh->kex;
1309	struct ssh_digest_ctx *hashctx = NULL;
1310	char c = id;
1311	u_int have;
1312	size_t mdsz;
1313	u_char *digest;
1314	int r;
1315
1316	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1317		return SSH_ERR_INVALID_ARGUMENT;
1318	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1319		r = SSH_ERR_ALLOC_FAIL;
1320		goto out;
1321	}
1322
1323	/* K1 = HASH(K || H || "A" || session_id) */
1324	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1325	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1326	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1327	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1328	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1329	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1330		r = SSH_ERR_LIBCRYPTO_ERROR;
1331		error_f("KEX hash failed");
1332		goto out;
1333	}
1334	ssh_digest_free(hashctx);
1335	hashctx = NULL;
1336
1337	/*
1338	 * expand key:
1339	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1340	 * Key = K1 || K2 || ... || Kn
1341	 */
1342	for (have = mdsz; need > have; have += mdsz) {
1343		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1344		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1345		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1346		    ssh_digest_update(hashctx, digest, have) != 0 ||
1347		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1348			error_f("KDF failed");
1349			r = SSH_ERR_LIBCRYPTO_ERROR;
1350			goto out;
1351		}
1352		ssh_digest_free(hashctx);
1353		hashctx = NULL;
1354	}
1355#ifdef DEBUG_KEX
1356	fprintf(stderr, "key '%c'== ", c);
1357	dump_digest("key", digest, need);
1358#endif
1359	*keyp = digest;
1360	digest = NULL;
1361	r = 0;
1362 out:
1363	free(digest);
1364	ssh_digest_free(hashctx);
1365	return r;
1366}
1367
1368#define NKEYS	6
1369int
1370kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1371    const struct sshbuf *shared_secret)
1372{
1373	struct kex *kex = ssh->kex;
1374	u_char *keys[NKEYS];
1375	u_int i, j, mode, ctos;
1376	int r;
1377
1378	/* save initial hash as session id */
1379	if ((kex->flags & KEX_INITIAL) != 0) {
1380		if (sshbuf_len(kex->session_id) != 0) {
1381			error_f("already have session ID at kex");
1382			return SSH_ERR_INTERNAL_ERROR;
1383		}
1384		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1385			return r;
1386	} else if (sshbuf_len(kex->session_id) == 0) {
1387		error_f("no session ID in rekex");
1388		return SSH_ERR_INTERNAL_ERROR;
1389	}
1390	for (i = 0; i < NKEYS; i++) {
1391		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1392		    shared_secret, &keys[i])) != 0) {
1393			for (j = 0; j < i; j++)
1394				free(keys[j]);
1395			return r;
1396		}
1397	}
1398	for (mode = 0; mode < MODE_MAX; mode++) {
1399		ctos = (!kex->server && mode == MODE_OUT) ||
1400		    (kex->server && mode == MODE_IN);
1401		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1402		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1403		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1404	}
1405	return 0;
1406}
1407
1408int
1409kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1410{
1411	struct kex *kex = ssh->kex;
1412
1413	*pubp = NULL;
1414	*prvp = NULL;
1415	if (kex->load_host_public_key == NULL ||
1416	    kex->load_host_private_key == NULL) {
1417		error_f("missing hostkey loader");
1418		return SSH_ERR_INVALID_ARGUMENT;
1419	}
1420	*pubp = kex->load_host_public_key(kex->hostkey_type,
1421	    kex->hostkey_nid, ssh);
1422	*prvp = kex->load_host_private_key(kex->hostkey_type,
1423	    kex->hostkey_nid, ssh);
1424	if (*pubp == NULL)
1425		return SSH_ERR_NO_HOSTKEY_LOADED;
1426	return 0;
1427}
1428
1429int
1430kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1431{
1432	struct kex *kex = ssh->kex;
1433
1434	if (kex->verify_host_key == NULL) {
1435		error_f("missing hostkey verifier");
1436		return SSH_ERR_INVALID_ARGUMENT;
1437	}
1438	if (server_host_key->type != kex->hostkey_type ||
1439	    (kex->hostkey_type == KEY_ECDSA &&
1440	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1441		return SSH_ERR_KEY_TYPE_MISMATCH;
1442	if (kex->verify_host_key(server_host_key, ssh) == -1)
1443		return  SSH_ERR_SIGNATURE_INVALID;
1444	return 0;
1445}
1446
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: write 'msg' on its own line to stderr, then the 'len'
 * bytes at 'digest' as formatted by sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1455
1456/*
1457 * Send a plaintext error message to the peer, suffixed by \r\n.
1458 * Only used during banner exchange, and there only for the server.
1459 */
1460static void
1461send_error(struct ssh *ssh, const char *msg)
1462{
1463	const char *crnl = "\r\n";
1464
1465	if (!ssh->kex->server)
1466		return;
1467
1468	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1469	    __UNCONST(msg), strlen(msg)) != strlen(msg) ||
1470	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1471	    __UNCONST(crnl), strlen(crnl)) != strlen(crnl))
1472		error_f("write: %.100s", strerror(errno));
1473}
1474
1475/*
1476 * Sends our identification string and waits for the peer's. Will block for
1477 * up to timeout_ms (or indefinitely if timeout_ms <= 0).
 * Returns 0 on success or an ssherr.h code on failure.
1479 */
1480int
1481kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1482    const char *version_addendum)
1483{
1484	int remote_major, remote_minor, mismatch, oerrno = 0;
1485	size_t len, n;
1486	int r, expect_nl;
1487	u_char c;
1488	struct sshbuf *our_version = ssh->kex->server ?
1489	    ssh->kex->server_version : ssh->kex->client_version;
1490	struct sshbuf *peer_version = ssh->kex->server ?
1491	    ssh->kex->client_version : ssh->kex->server_version;
1492	char *our_version_string = NULL, *peer_version_string = NULL;
1493	char *cp, *remote_version = NULL;
1494
1495	/* Prepare and send our banner */
1496	sshbuf_reset(our_version);
1497	if (version_addendum != NULL && *version_addendum == '\0')
1498		version_addendum = NULL;
1499	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1500	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1501	    version_addendum == NULL ? "" : " ",
1502	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1503		oerrno = errno;
1504		error_fr(r, "sshbuf_putf");
1505		goto out;
1506	}
1507
1508	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1509	    sshbuf_mutable_ptr(our_version),
1510	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1511		oerrno = errno;
1512		debug_f("write: %.100s", strerror(errno));
1513		r = SSH_ERR_SYSTEM_ERROR;
1514		goto out;
1515	}
1516	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1517		oerrno = errno;
1518		error_fr(r, "sshbuf_consume_end");
1519		goto out;
1520	}
1521	our_version_string = sshbuf_dup_string(our_version);
1522	if (our_version_string == NULL) {
1523		error_f("sshbuf_dup_string failed");
1524		r = SSH_ERR_ALLOC_FAIL;
1525		goto out;
1526	}
1527	debug("Local version string %.100s", our_version_string);
1528
1529	/* Read other side's version identification. */
1530	for (n = 0; ; n++) {
1531		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1532			send_error(ssh, "No SSH identification string "
1533			    "received.");
1534			error_f("No SSH version received in first %u lines "
1535			    "from server", SSH_MAX_PRE_BANNER_LINES);
1536			r = SSH_ERR_INVALID_FORMAT;
1537			goto out;
1538		}
1539		sshbuf_reset(peer_version);
1540		expect_nl = 0;
1541		for (;;) {
1542			if (timeout_ms > 0) {
1543				r = waitrfd(ssh_packet_get_connection_in(ssh),
1544				    &timeout_ms, NULL);
1545				if (r == -1 && errno == ETIMEDOUT) {
1546					send_error(ssh, "Timed out waiting "
1547					    "for SSH identification string.");
1548					error("Connection timed out during "
1549					    "banner exchange");
1550					r = SSH_ERR_CONN_TIMEOUT;
1551					goto out;
1552				} else if (r == -1) {
1553					oerrno = errno;
1554					error_f("%s", strerror(errno));
1555					r = SSH_ERR_SYSTEM_ERROR;
1556					goto out;
1557				}
1558			}
1559
1560			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1561			    &c, 1);
1562			if (len != 1 && errno == EPIPE) {
1563				verbose_f("Connection closed by remote host");
1564				r = SSH_ERR_CONN_CLOSED;
1565				goto out;
1566			} else if (len != 1) {
1567				oerrno = errno;
1568				error_f("read: %.100s", strerror(errno));
1569				r = SSH_ERR_SYSTEM_ERROR;
1570				goto out;
1571			}
1572			if (c == '\r') {
1573				expect_nl = 1;
1574				continue;
1575			}
1576			if (c == '\n')
1577				break;
1578			if (c == '\0' || expect_nl) {
1579				verbose_f("banner line contains invalid "
1580				    "characters");
1581				goto invalid;
1582			}
1583			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1584				oerrno = errno;
1585				error_fr(r, "sshbuf_put");
1586				goto out;
1587			}
1588			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1589				verbose_f("banner line too long");
1590				goto invalid;
1591			}
1592		}
1593		/* Is this an actual protocol banner? */
1594		if (sshbuf_len(peer_version) > 4 &&
1595		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1596			break;
1597		/* If not, then just log the line and continue */
1598		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1599			error_f("sshbuf_dup_string failed");
1600			r = SSH_ERR_ALLOC_FAIL;
1601			goto out;
1602		}
1603		/* Do not accept lines before the SSH ident from a client */
1604		if (ssh->kex->server) {
1605			verbose_f("client sent invalid protocol identifier "
1606			    "\"%.256s\"", cp);
1607			free(cp);
1608			goto invalid;
1609		}
1610		debug_f("banner line %zu: %s", n, cp);
1611		free(cp);
1612	}
1613	peer_version_string = sshbuf_dup_string(peer_version);
1614	if (peer_version_string == NULL)
1615		fatal_f("sshbuf_dup_string failed");
1616	/* XXX must be same size for sscanf */
1617	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1618		error_f("calloc failed");
1619		r = SSH_ERR_ALLOC_FAIL;
1620		goto out;
1621	}
1622
1623	/*
1624	 * Check that the versions match.  In future this might accept
1625	 * several versions and set appropriate flags to handle them.
1626	 */
1627	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1628	    &remote_major, &remote_minor, remote_version) != 3) {
1629		error("Bad remote protocol version identification: '%.100s'",
1630		    peer_version_string);
1631 invalid:
1632		send_error(ssh, "Invalid SSH identification string.");
1633		r = SSH_ERR_INVALID_FORMAT;
1634		goto out;
1635	}
1636	debug("Remote protocol version %d.%d, remote software version %.100s",
1637	    remote_major, remote_minor, remote_version);
1638	compat_banner(ssh, remote_version);
1639
1640	mismatch = 0;
1641	switch (remote_major) {
1642	case 2:
1643		break;
1644	case 1:
1645		if (remote_minor != 99)
1646			mismatch = 1;
1647		break;
1648	default:
1649		mismatch = 1;
1650		break;
1651	}
1652	if (mismatch) {
1653		error("Protocol major versions differ: %d vs. %d",
1654		    PROTOCOL_MAJOR_2, remote_major);
1655		send_error(ssh, "Protocol major versions differ.");
1656		r = SSH_ERR_NO_PROTOCOL_VERSION;
1657		goto out;
1658	}
1659
1660	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1661		logit("probed from %s port %d with %s.  Don't panic.",
1662		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1663		    peer_version_string);
1664		r = SSH_ERR_CONN_CLOSED; /* XXX */
1665		goto out;
1666	}
1667	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1668		logit("scanned from %s port %d with %s.  Don't panic.",
1669		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1670		    peer_version_string);
1671		r = SSH_ERR_CONN_CLOSED; /* XXX */
1672		goto out;
1673	}
1674	/* success */
1675	r = 0;
1676 out:
1677	free(our_version_string);
1678	free(peer_version_string);
1679	free(remote_version);
1680	if (r == SSH_ERR_SYSTEM_ERROR)
1681		errno = oerrno;
1682	return r;
1683}
1684
1685