1/*	$NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $	*/
2/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
3/*
4 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 *    notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 *    notice, this list of conditions and the following disclaimer in the
13 *    documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
17 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
18 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
19 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
20 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
24 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26#include "includes.h"
27__RCSID("$NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $");
28
29#include <sys/types.h>
30#include <sys/uio.h>
31
32#include <stdio.h>
33#include <string.h>
34#include <unistd.h>
35#include <fcntl.h>
36#include <errno.h>
37
38#include "ssh2.h"
39#include "ssherr.h"
40#include "sshbuf.h"
41#include "cipher.h"
42#include "sshkey.h"
43#include "sshkey-xmss.h"
44#include "atomicio.h"
45#include "log.h"
46
47#include "xmss_fast.h"
48
/*
 * Opaque internal XMSS state, attached to key->xmss_state by
 * sshkey_xmss_init().  XMSS_MAGIC prefixes the encrypted on-disk state
 * blob; XMSS_CIPHERNAME is the cipher used to protect that blob for newly
 * generated keys.
 */
#define XMSS_MAGIC		"xmss-state-v1"
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	/* parameter set filled in by xmss_set_params() from n/h/w/k */
	xmss_params	params;
	/*
	 * n: node size in bytes (32 for all supported SHA2-256 names),
	 * w: Winternitz parameter (16), h: tree height (10/16/20),
	 * k: BDS parameter (hardcoded to 2 in sshkey_xmss_init()).
	 */
	u_int32_t	n, w, h, k;

	/*
	 * BDS traversal state.  bds is pointed at the buffers below via
	 * xmss_set_bds_state(); their sizes come from the num_*() macros.
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
75
/*
 * Forward declarations for non-static helpers defined further down in this
 * file.  NOTE(review): these do not appear to be declared in
 * sshkey-xmss.h; confirm before changing their linkage.
 */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Log an error through sshlog() at SYSLOG_LEVEL_ERROR, but only when a
 * variable named `printerror' visible at the call site is non-zero.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
90
91int
92sshkey_xmss_init(struct sshkey *key, const char *name)
93{
94	struct ssh_xmss_state *state;
95
96	if (key->xmss_state != NULL)
97		return SSH_ERR_INVALID_FORMAT;
98	if (name == NULL)
99		return SSH_ERR_INVALID_FORMAT;
100	state = calloc(sizeof(struct ssh_xmss_state), 1);
101	if (state == NULL)
102		return SSH_ERR_ALLOC_FAIL;
103	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
104		state->n = 32;
105		state->w = 16;
106		state->h = 10;
107	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
108		state->n = 32;
109		state->w = 16;
110		state->h = 16;
111	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
112		state->n = 32;
113		state->w = 16;
114		state->h = 20;
115	} else {
116		free(state);
117		return SSH_ERR_KEY_TYPE_UNKNOWN;
118	}
119	if ((key->xmss_name = strdup(name)) == NULL) {
120		free(state);
121		return SSH_ERR_ALLOC_FAIL;
122	}
123	state->k = 2;	/* XXX hardcoded */
124	state->lockfd = -1;
125	if (xmss_set_params(&state->params, state->n, state->h, state->w,
126	    state->k) != 0) {
127		free(state);
128		return SSH_ERR_INVALID_FORMAT;
129	}
130	key->xmss_state = state;
131	return 0;
132}
133
134void
135sshkey_xmss_free_state(struct sshkey *key)
136{
137	struct ssh_xmss_state *state = key->xmss_state;
138
139	sshkey_xmss_free_bds(key);
140	if (state) {
141		if (state->enc_keyiv) {
142			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
143			free(state->enc_keyiv);
144		}
145		free(state->enc_ciphername);
146		free(state);
147	}
148	key->xmss_state = NULL;
149}
150
/* Tag written at the start of a serialized BDS state (k = 2 layout). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS traversal buffers for a state pointer `x'.  All are
 * byte counts except num_treehash(), which counts treehash_inst entries.
 * The argument is parenthesized so that expressions such as
 * num_stack(&st) expand correctly.
 */
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		(((x)->h >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
159
160int
161sshkey_xmss_init_bds_state(struct sshkey *key)
162{
163	struct ssh_xmss_state *state = key->xmss_state;
164	u_int32_t i;
165
166	state->stackoffset = 0;
167	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
168	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
169	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
170	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
171	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
172	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
173	    (state->treehash = calloc(num_treehash(state),
174	    sizeof(treehash_inst))) == NULL) {
175		sshkey_xmss_free_bds(key);
176		return SSH_ERR_ALLOC_FAIL;
177	}
178	for (i = 0; i < state->h - state->k; i++)
179		state->treehash[i].node = &state->th_nodes[state->n*i];
180	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
181	    state->stacklevels, state->auth, state->keep, state->treehash,
182	    state->retain, 0);
183	return 0;
184}
185
186void
187sshkey_xmss_free_bds(struct sshkey *key)
188{
189	struct ssh_xmss_state *state = key->xmss_state;
190
191	if (state == NULL)
192		return;
193	free(state->stack);
194	free(state->stacklevels);
195	free(state->auth);
196	free(state->keep);
197	free(state->th_nodes);
198	free(state->retain);
199	free(state->treehash);
200	state->stack = NULL;
201	state->stacklevels = NULL;
202	state->auth = NULL;
203	state->keep = NULL;
204	state->th_nodes = NULL;
205	state->retain = NULL;
206	state->treehash = NULL;
207}
208
209void *
210sshkey_xmss_params(const struct sshkey *key)
211{
212	struct ssh_xmss_state *state = key->xmss_state;
213
214	if (state == NULL)
215		return NULL;
216	return &state->params;
217}
218
219void *
220sshkey_xmss_bds_state(const struct sshkey *key)
221{
222	struct ssh_xmss_state *state = key->xmss_state;
223
224	if (state == NULL)
225		return NULL;
226	return &state->bds;
227}
228
229int
230sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
231{
232	struct ssh_xmss_state *state = key->xmss_state;
233
234	if (lenp == NULL)
235		return SSH_ERR_INVALID_ARGUMENT;
236	if (state == NULL)
237		return SSH_ERR_INVALID_FORMAT;
238	*lenp = 4 + state->n +
239	    state->params.wots_par.keysize +
240	    state->h * state->n;
241	return 0;
242}
243
244size_t
245sshkey_xmss_pklen(const struct sshkey *key)
246{
247	struct ssh_xmss_state *state = key->xmss_state;
248
249	if (state == NULL)
250		return 0;
251	return state->n * 2;
252}
253
254size_t
255sshkey_xmss_sklen(const struct sshkey *key)
256{
257	struct ssh_xmss_state *state = key->xmss_state;
258
259	if (state == NULL)
260		return 0;
261	return state->n * 4 + 4;
262}
263
264int
265sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
266{
267	struct ssh_xmss_state *state = k->xmss_state;
268	const struct sshcipher *cipher;
269	size_t keylen = 0, ivlen = 0;
270
271	if (state == NULL)
272		return SSH_ERR_INVALID_ARGUMENT;
273	if ((cipher = cipher_by_name(ciphername)) == NULL)
274		return SSH_ERR_INTERNAL_ERROR;
275	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
276		return SSH_ERR_ALLOC_FAIL;
277	keylen = cipher_keylen(cipher);
278	ivlen = cipher_ivlen(cipher);
279	state->enc_keyiv_len = keylen + ivlen;
280	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
281		free(state->enc_ciphername);
282		state->enc_ciphername = NULL;
283		return SSH_ERR_ALLOC_FAIL;
284	}
285	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
286	return 0;
287}
288
289int
290sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
291{
292	struct ssh_xmss_state *state = k->xmss_state;
293	int r;
294
295	if (state == NULL || state->enc_keyiv == NULL ||
296	    state->enc_ciphername == NULL)
297		return SSH_ERR_INVALID_ARGUMENT;
298	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
299	    (r = sshbuf_put_string(b, state->enc_keyiv,
300	    state->enc_keyiv_len)) != 0)
301		return r;
302	return 0;
303}
304
305int
306sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
307{
308	struct ssh_xmss_state *state = k->xmss_state;
309	size_t len;
310	int r;
311
312	if (state == NULL)
313		return SSH_ERR_INVALID_ARGUMENT;
314	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
315	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
316		return r;
317	state->enc_keyiv_len = len;
318	return 0;
319}
320
321int
322sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
323    enum sshkey_serialize_rep opts)
324{
325	struct ssh_xmss_state *state = k->xmss_state;
326	u_char have_info = 1;
327	u_int32_t idx;
328	int r;
329
330	if (state == NULL)
331		return SSH_ERR_INVALID_ARGUMENT;
332	if (opts != SSHKEY_SERIALIZE_INFO)
333		return 0;
334	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
335	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
336	    (r = sshbuf_put_u32(b, idx)) != 0 ||
337	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
338		return r;
339	return 0;
340}
341
342int
343sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
344{
345	struct ssh_xmss_state *state = k->xmss_state;
346	u_char have_info;
347	int r;
348
349	if (state == NULL)
350		return SSH_ERR_INVALID_ARGUMENT;
351	/* optional */
352	if (sshbuf_len(b) == 0)
353		return 0;
354	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
355		return r;
356	if (have_info != 1)
357		return SSH_ERR_INVALID_ARGUMENT;
358	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
359	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
360		return r;
361	return 0;
362}
363
364int
365sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
366{
367	int r;
368	const char *name;
369
370	if (bits == 10) {
371		name = XMSS_SHA2_256_W16_H10_NAME;
372	} else if (bits == 16) {
373		name = XMSS_SHA2_256_W16_H16_NAME;
374	} else if (bits == 20) {
375		name = XMSS_SHA2_256_W16_H20_NAME;
376	} else {
377		name = XMSS_DEFAULT_NAME;
378	}
379	if ((r = sshkey_xmss_init(k, name)) != 0 ||
380	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
381	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
382		return r;
383	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
384	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
385		return SSH_ERR_ALLOC_FAIL;
386	}
387	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
388	    sshkey_xmss_params(k));
389	return 0;
390}
391
392int
393sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
394    int *have_file, int printerror)
395{
396	struct sshbuf *b = NULL, *enc = NULL;
397	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
398	u_int32_t len;
399	unsigned char buf[4], *data = NULL;
400
401	*have_file = 0;
402	if ((fd = open(filename, O_RDONLY)) >= 0) {
403		*have_file = 1;
404		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
405			PRINT("corrupt state file: %s", filename);
406			goto done;
407		}
408		len = PEEK_U32(buf);
409		if ((data = calloc(len, 1)) == NULL) {
410			ret = SSH_ERR_ALLOC_FAIL;
411			goto done;
412		}
413		if (atomicio(read, fd, data, len) != len) {
414			PRINT("cannot read blob: %s", filename);
415			goto done;
416		}
417		if ((enc = sshbuf_from(data, len)) == NULL) {
418			ret = SSH_ERR_ALLOC_FAIL;
419			goto done;
420		}
421		sshkey_xmss_free_bds(k);
422		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
423			ret = r;
424			goto done;
425		}
426		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
427			ret = r;
428			goto done;
429		}
430		ret = 0;
431	}
432done:
433	if (fd != -1)
434		close(fd);
435	free(data);
436	sshbuf_free(enc);
437	sshbuf_free(b);
438	return ret;
439}
440
/*
 * Prepare `k' for signing and set state->allow_update.
 *
 * With state->maxidx set, the key carries its own signature budget and no
 * file is touched.  The call succeeds while the index in xmss_sk is below
 * maxidx.
 *
 * Otherwise:
 *  - take an exclusive flock on "<file>.lock" (up to 10 retries, sleeping
 *    100ms * attempt between them);
 *  - load "<file>.state";
 *  - if that fails, load "<file>.ostate" and advance it by one signature.
 *
 * On success the lock fd is kept in state->lockfd; sshkey_xmss_update_state()
 * closes it.  Returns 0 or an SSH_ERR_* code.
 */
int
sshkey_xmss_get_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx = 0;
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
	int lockfd = -1, have_state = 0, have_ostate = 0, tries = 0;
	int ret = SSH_ERR_INVALID_ARGUMENT, r;

	if (state == NULL)
		goto done;
	/*
	 * If maxidx is set, then we are allowed a limited number
	 * of signatures, but don't need to access the disk.
	 * Otherwise we need to deal with the on-disk state.
	 */
	if (state->maxidx) {
		/* xmss_sk always contains the current state */
		idx = PEEK_U32(k->xmss_sk);
		if (idx < state->maxidx) {
			state->allow_update = 1;
			return 0;
		}
		return SSH_ERR_INVALID_ARGUMENT;
	}
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
	    asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("cannot open/create: %s", lockfile);
		goto done;
	}
	/* non-blocking lock attempts with linearly growing back-off */
	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
		if (errno != EWOULDBLOCK) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("cannot lock: %s", lockfile);
			goto done;
		}
		if (++tries > 10) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("giving up on: %s", lockfile);
			goto done;
		}
		usleep(1000*100*tries);
	}
	/* XXX no longer const */
	if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
	    statefile, &have_state, printerror)) != 0) {
		/*
		 * .state missing or unusable: fall back to the backup.
		 * sshkey_xmss_update_state() links .state to .ostate before
		 * installing the new state, so the backup is presumably one
		 * index behind.  Skip one index so that it is never reused.
		 */
		if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
		    ostatefile, &have_ostate, printerror)) == 0) {
			state->allow_update = 1;
			r = sshkey_xmss_forward_state(k, 1);
			state->idx = PEEK_U32(k->xmss_sk);
			state->allow_update = 0;
		}
	}
	if (!have_state && !have_ostate) {
		/*
		 * Neither file exists: only acceptable when BDS buffers are
		 * already present in memory (e.g. a freshly generated key).
		 */
		/* check that bds state is initialized */
		if (state->bds.auth == NULL)
			goto done;
		PRINT("start from scratch idx 0: %u", state->idx);
	} else if (r != 0) {
		ret = r;
		goto done;
	}
	/* refuse to continue if the next index would wrap around */
	if (state->idx + 1 < state->idx) {
		PRINT("state wrap: %u", state->idx);
		goto done;
	}
	/* success: hand the held lock over to the key state */
	state->have_state = have_state;
	state->lockfd = lockfd;
	state->allow_update = 1;
	lockfd = -1;
	ret = 0;
done:
	if (lockfd != -1)
		close(lockfd);
	free(lockfile);
	free(statefile);
	free(ostatefile);
	return ret;
}
530
531int
532sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
533{
534	struct ssh_xmss_state *state = k->xmss_state;
535	u_char *sig = NULL;
536	size_t required_siglen;
537	unsigned long long smlen;
538	u_char data;
539	int ret, r;
540
541	if (state == NULL || !state->allow_update)
542		return SSH_ERR_INVALID_ARGUMENT;
543	if (reserve == 0)
544		return SSH_ERR_INVALID_ARGUMENT;
545	if (state->idx + reserve <= state->idx)
546		return SSH_ERR_INVALID_ARGUMENT;
547	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
548		return r;
549	if ((sig = malloc(required_siglen)) == NULL)
550		return SSH_ERR_ALLOC_FAIL;
551	while (reserve-- > 0) {
552		state->idx = PEEK_U32(k->xmss_sk);
553		smlen = required_siglen;
554		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
555		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
556			r = SSH_ERR_INVALID_ARGUMENT;
557			break;
558		}
559	}
560	free(sig);
561	return r;
562}
563
564int
565sshkey_xmss_update_state(const struct sshkey *k, int printerror)
566{
567	struct ssh_xmss_state *state = k->xmss_state;
568	struct sshbuf *b = NULL, *enc = NULL;
569	u_int32_t idx = 0;
570	unsigned char buf[4];
571	char *filename = NULL;
572	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
573	int fd = -1;
574	int ret = SSH_ERR_INVALID_ARGUMENT;
575
576	if (state == NULL || !state->allow_update)
577		return ret;
578	if (state->maxidx) {
579		/* no update since the number of signatures is limited */
580		ret = 0;
581		goto done;
582	}
583	idx = PEEK_U32(k->xmss_sk);
584	if (idx == state->idx) {
585		/* no signature happened, no need to update */
586		ret = 0;
587		goto done;
588	} else if (idx != state->idx + 1) {
589		PRINT("more than one signature happened: idx %u state %u",
590		    idx, state->idx);
591		goto done;
592	}
593	state->idx = idx;
594	if ((filename = k->xmss_filename) == NULL)
595		goto done;
596	if (asprintf(&statefile, "%s.state", filename) == -1 ||
597	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
598	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
599		ret = SSH_ERR_ALLOC_FAIL;
600		goto done;
601	}
602	unlink(nstatefile);
603	if ((b = sshbuf_new()) == NULL) {
604		ret = SSH_ERR_ALLOC_FAIL;
605		goto done;
606	}
607	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
608		PRINT("SERLIALIZE FAILED: %d", ret);
609		goto done;
610	}
611	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
612		PRINT("ENCRYPT FAILED: %d", ret);
613		goto done;
614	}
615	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
616		ret = SSH_ERR_SYSTEM_ERROR;
617		PRINT("open new state file: %s", nstatefile);
618		goto done;
619	}
620	POKE_U32(buf, sshbuf_len(enc));
621	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
622		ret = SSH_ERR_SYSTEM_ERROR;
623		PRINT("write new state file hdr: %s", nstatefile);
624		close(fd);
625		goto done;
626	}
627	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
628	    sshbuf_len(enc)) {
629		ret = SSH_ERR_SYSTEM_ERROR;
630		PRINT("write new state file data: %s", nstatefile);
631		close(fd);
632		goto done;
633	}
634	if (fsync(fd) == -1) {
635		ret = SSH_ERR_SYSTEM_ERROR;
636		PRINT("sync new state file: %s", nstatefile);
637		close(fd);
638		goto done;
639	}
640	if (close(fd) == -1) {
641		ret = SSH_ERR_SYSTEM_ERROR;
642		PRINT("close new state file: %s", nstatefile);
643		goto done;
644	}
645	if (state->have_state) {
646		unlink(ostatefile);
647		if (link(statefile, ostatefile)) {
648			ret = SSH_ERR_SYSTEM_ERROR;
649			PRINT("backup state %s to %s", statefile, ostatefile);
650			goto done;
651		}
652	}
653	if (rename(nstatefile, statefile) == -1) {
654		ret = SSH_ERR_SYSTEM_ERROR;
655		PRINT("rename %s to %s", nstatefile, statefile);
656		goto done;
657	}
658	ret = 0;
659done:
660	if (state->lockfd != -1) {
661		close(state->lockfd);
662		state->lockfd = -1;
663	}
664	if (nstatefile)
665		unlink(nstatefile);
666	free(statefile);
667	free(ostatefile);
668	free(nstatefile);
669	sshbuf_free(b);
670	sshbuf_free(enc);
671	return ret;
672}
673
674int
675sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
676{
677	struct ssh_xmss_state *state = k->xmss_state;
678	treehash_inst *th;
679	u_int32_t i, node;
680	int r;
681
682	if (state == NULL)
683		return SSH_ERR_INVALID_ARGUMENT;
684	if (state->stack == NULL)
685		return SSH_ERR_INVALID_ARGUMENT;
686	state->stackoffset = state->bds.stackoffset;	/* copy back */
687	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
688	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
689	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
690	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
691	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
692	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
693	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
694	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
695	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
696	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
697		return r;
698	for (i = 0; i < num_treehash(state); i++) {
699		th = &state->treehash[i];
700		node = th->node - state->th_nodes;
701		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
702		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
703		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
704		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
705		    (r = sshbuf_put_u32(b, node)) != 0)
706			return r;
707	}
708	return 0;
709}
710
711int
712sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
713    enum sshkey_serialize_rep opts)
714{
715	struct ssh_xmss_state *state = k->xmss_state;
716	int r = SSH_ERR_INVALID_ARGUMENT;
717	u_char have_stack, have_filename, have_enc;
718
719	if (state == NULL)
720		return SSH_ERR_INVALID_ARGUMENT;
721	if ((r = sshbuf_put_u8(b, opts)) != 0)
722		return r;
723	switch (opts) {
724	case SSHKEY_SERIALIZE_STATE:
725		r = sshkey_xmss_serialize_state(k, b);
726		break;
727	case SSHKEY_SERIALIZE_FULL:
728		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
729			return r;
730		r = sshkey_xmss_serialize_state(k, b);
731		break;
732	case SSHKEY_SERIALIZE_SHIELD:
733		/* all of stack/filename/enc are optional */
734		have_stack = state->stack != NULL;
735		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
736			return r;
737		if (have_stack) {
738			state->idx = PEEK_U32(k->xmss_sk);	/* update */
739			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
740				return r;
741		}
742		have_filename = k->xmss_filename != NULL;
743		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
744			return r;
745		if (have_filename &&
746		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
747			return r;
748		have_enc = state->enc_keyiv != NULL;
749		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
750			return r;
751		if (have_enc &&
752		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
753			return r;
754		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
755		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
756			return r;
757		break;
758	case SSHKEY_SERIALIZE_DEFAULT:
759		r = 0;
760		break;
761	default:
762		r = SSH_ERR_INVALID_ARGUMENT;
763		break;
764	}
765	return r;
766}
767
768int
769sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
770{
771	struct ssh_xmss_state *state = k->xmss_state;
772	treehash_inst *th;
773	u_int32_t i, lh, node;
774	size_t ls, lsl, la, lk, ln, lr;
775	char *magic;
776	int r = SSH_ERR_INTERNAL_ERROR;
777
778	if (state == NULL)
779		return SSH_ERR_INVALID_ARGUMENT;
780	if (k->xmss_sk == NULL)
781		return SSH_ERR_INVALID_ARGUMENT;
782	if ((state->treehash = calloc(num_treehash(state),
783	    sizeof(treehash_inst))) == NULL)
784		return SSH_ERR_ALLOC_FAIL;
785	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
786	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
787	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
788	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
789	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
790	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
791	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
792	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
793	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
794	    (r = sshbuf_get_u32(b, &lh)) != 0)
795		goto out;
796	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
797		r = SSH_ERR_INVALID_ARGUMENT;
798		goto out;
799	}
800	/* XXX check stackoffset */
801	if (ls != num_stack(state) ||
802	    lsl != num_stacklevels(state) ||
803	    la != num_auth(state) ||
804	    lk != num_keep(state) ||
805	    ln != num_th_nodes(state) ||
806	    lr != num_retain(state) ||
807	    lh != num_treehash(state)) {
808		r = SSH_ERR_INVALID_ARGUMENT;
809		goto out;
810	}
811	for (i = 0; i < num_treehash(state); i++) {
812		th = &state->treehash[i];
813		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
814		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
815		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
816		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
817		    (r = sshbuf_get_u32(b, &node)) != 0)
818			goto out;
819		if (node < num_th_nodes(state))
820			th->node = &state->th_nodes[node];
821	}
822	POKE_U32(k->xmss_sk, state->idx);
823	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
824	    state->stacklevels, state->auth, state->keep, state->treehash,
825	    state->retain, 0);
826	/* success */
827	r = 0;
828 out:
829	free(magic);
830	return r;
831}
832
833int
834sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
835{
836	struct ssh_xmss_state *state = k->xmss_state;
837	enum sshkey_serialize_rep opts;
838	u_char have_state, have_stack, have_filename, have_enc;
839	int r;
840
841	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
842		return r;
843
844	opts = have_state;
845	switch (opts) {
846	case SSHKEY_SERIALIZE_DEFAULT:
847		r = 0;
848		break;
849	case SSHKEY_SERIALIZE_SHIELD:
850		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
851			return r;
852		if (have_stack &&
853		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
854			return r;
855		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
856			return r;
857		if (have_filename &&
858		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
859			return r;
860		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
861			return r;
862		if (have_enc &&
863		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
864			return r;
865		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
866		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
867			return r;
868		break;
869	case SSHKEY_SERIALIZE_STATE:
870		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
871			return r;
872		break;
873	case SSHKEY_SERIALIZE_FULL:
874		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
875		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
876			return r;
877		break;
878	default:
879		r = SSH_ERR_INVALID_FORMAT;
880		break;
881	}
882	return r;
883}
884
885int
886sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
887   struct sshbuf **retp)
888{
889	struct ssh_xmss_state *state = k->xmss_state;
890	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
891	struct sshcipher_ctx *ciphercontext = NULL;
892	const struct sshcipher *cipher;
893	u_char *cp, *key, *iv = NULL;
894	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
895	int r = SSH_ERR_INTERNAL_ERROR;
896
897	if (retp != NULL)
898		*retp = NULL;
899	if (state == NULL ||
900	    state->enc_keyiv == NULL ||
901	    state->enc_ciphername == NULL)
902		return SSH_ERR_INTERNAL_ERROR;
903	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
904		r = SSH_ERR_INTERNAL_ERROR;
905		goto out;
906	}
907	blocksize = cipher_blocksize(cipher);
908	keylen = cipher_keylen(cipher);
909	ivlen = cipher_ivlen(cipher);
910	authlen = cipher_authlen(cipher);
911	if (state->enc_keyiv_len != keylen + ivlen) {
912		r = SSH_ERR_INVALID_FORMAT;
913		goto out;
914	}
915	key = state->enc_keyiv;
916	if ((encrypted = sshbuf_new()) == NULL ||
917	    (encoded = sshbuf_new()) == NULL ||
918	    (padded = sshbuf_new()) == NULL ||
919	    (iv = malloc(ivlen)) == NULL) {
920		r = SSH_ERR_ALLOC_FAIL;
921		goto out;
922	}
923
924	/* replace first 4 bytes of IV with index to ensure uniqueness */
925	memcpy(iv, key + keylen, ivlen);
926	POKE_U32(iv, state->idx);
927
928	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
929	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
930		goto out;
931
932	/* padded state will be encrypted */
933	if ((r = sshbuf_putb(padded, b)) != 0)
934		goto out;
935	i = 0;
936	while (sshbuf_len(padded) % blocksize) {
937		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
938			goto out;
939	}
940	encrypted_len = sshbuf_len(padded);
941
942	/* header including the length of state is used as AAD */
943	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
944		goto out;
945	aadlen = sshbuf_len(encoded);
946
947	/* concat header and state */
948	if ((r = sshbuf_putb(encoded, padded)) != 0)
949		goto out;
950
951	/* reserve space for encryption of encoded data plus auth tag */
952	/* encrypt at offset addlen */
953	if ((r = sshbuf_reserve(encrypted,
954	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
955	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
956	    iv, ivlen, 1)) != 0 ||
957	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
958	    encrypted_len, aadlen, authlen)) != 0)
959		goto out;
960
961	/* success */
962	r = 0;
963 out:
964	if (retp != NULL) {
965		*retp = encrypted;
966		encrypted = NULL;
967	}
968	sshbuf_free(padded);
969	sshbuf_free(encoded);
970	sshbuf_free(encrypted);
971	cipher_free(ciphercontext);
972	free(iv);
973	return r;
974}
975
976int
977sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
978   struct sshbuf **retp)
979{
980	struct ssh_xmss_state *state = k->xmss_state;
981	struct sshbuf *copy = NULL, *decrypted = NULL;
982	struct sshcipher_ctx *ciphercontext = NULL;
983	const struct sshcipher *cipher = NULL;
984	u_char *key, *iv = NULL, *dp;
985	size_t keylen, ivlen, authlen, aadlen;
986	u_int blocksize, encrypted_len, index;
987	int r = SSH_ERR_INTERNAL_ERROR;
988
989	if (retp != NULL)
990		*retp = NULL;
991	if (state == NULL ||
992	    state->enc_keyiv == NULL ||
993	    state->enc_ciphername == NULL)
994		return SSH_ERR_INTERNAL_ERROR;
995	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
996		r = SSH_ERR_INVALID_FORMAT;
997		goto out;
998	}
999	blocksize = cipher_blocksize(cipher);
1000	keylen = cipher_keylen(cipher);
1001	ivlen = cipher_ivlen(cipher);
1002	authlen = cipher_authlen(cipher);
1003	if (state->enc_keyiv_len != keylen + ivlen) {
1004		r = SSH_ERR_INTERNAL_ERROR;
1005		goto out;
1006	}
1007	key = state->enc_keyiv;
1008
1009	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1010	    (decrypted = sshbuf_new()) == NULL ||
1011	    (iv = malloc(ivlen)) == NULL) {
1012		r = SSH_ERR_ALLOC_FAIL;
1013		goto out;
1014	}
1015
1016	/* check magic */
1017	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1018	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1019		r = SSH_ERR_INVALID_FORMAT;
1020		goto out;
1021	}
1022	/* parse public portion */
1023	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1024	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1025	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1026		goto out;
1027
1028	/* check size of encrypted key blob */
1029	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1030		r = SSH_ERR_INVALID_FORMAT;
1031		goto out;
1032	}
1033	/* check that an appropriate amount of auth data is present */
1034	if (sshbuf_len(encoded) < authlen ||
1035	    sshbuf_len(encoded) - authlen < encrypted_len) {
1036		r = SSH_ERR_INVALID_FORMAT;
1037		goto out;
1038	}
1039
1040	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1041
1042	/* replace first 4 bytes of IV with index to ensure uniqueness */
1043	memcpy(iv, key + keylen, ivlen);
1044	POKE_U32(iv, index);
1045
1046	/* decrypt private state of key */
1047	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1048	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1049	    iv, ivlen, 0)) != 0 ||
1050	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1051	    encrypted_len, aadlen, authlen)) != 0)
1052		goto out;
1053
1054	/* there should be no trailing data */
1055	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1056		goto out;
1057	if (sshbuf_len(encoded) != 0) {
1058		r = SSH_ERR_INVALID_FORMAT;
1059		goto out;
1060	}
1061
1062	/* remove AAD */
1063	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1064		goto out;
1065	/* XXX encrypted includes unchecked padding */
1066
1067	/* success */
1068	r = 0;
1069	if (retp != NULL) {
1070		*retp = decrypted;
1071		decrypted = NULL;
1072	}
1073 out:
1074	cipher_free(ciphercontext);
1075	sshbuf_free(copy);
1076	sshbuf_free(decrypted);
1077	free(iv);
1078	return r;
1079}
1080
1081u_int32_t
1082sshkey_xmss_signatures_left(const struct sshkey *k)
1083{
1084	struct ssh_xmss_state *state = k->xmss_state;
1085	u_int32_t idx;
1086
1087	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1088	    state->maxidx) {
1089		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1090		if (idx < state->maxidx)
1091			return state->maxidx - idx;
1092	}
1093	return 0;
1094}
1095
1096int
1097sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1098{
1099	struct ssh_xmss_state *state = k->xmss_state;
1100
1101	if (sshkey_type_plain(k->type) != KEY_XMSS)
1102		return SSH_ERR_INVALID_ARGUMENT;
1103	if (maxsign == 0)
1104		return 0;
1105	if (state->idx + maxsign < state->idx)
1106		return SSH_ERR_INVALID_ARGUMENT;
1107	state->maxidx = state->idx + maxsign;
1108	return 0;
1109}
1110