1/* $OpenBSD: sshkey-xmss.c,v 1.3 2018/07/09 21:59:10 markus Exp $ */
2/*
3 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "includes.h"
27#ifdef WITH_XMSS
28
29#include <sys/types.h>
30#include <sys/uio.h>
31
32#include <stdio.h>
33#include <string.h>
34#include <unistd.h>
35#include <fcntl.h>
36#include <errno.h>
37#ifdef HAVE_SYS_FILE_H
38# include <sys/file.h>
39#endif
40
41#include "ssh2.h"
42#include "ssherr.h"
43#include "sshbuf.h"
44#include "cipher.h"
45#include "sshkey.h"
46#include "sshkey-xmss.h"
47#include "atomicio.h"
48
49#include "xmss_fast.h"
50
51/* opaque internal XMSS state */
52#define XMSS_MAGIC		"xmss-state-v1"
53#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
54struct ssh_xmss_state {
55	xmss_params	params;
56	u_int32_t	n, w, h, k;
57
58	bds_state	bds;
59	u_char		*stack;
60	u_int32_t	stackoffset;
61	u_char		*stacklevels;
62	u_char		*auth;
63	u_char		*keep;
64	u_char		*th_nodes;
65	u_char		*retain;
66	treehash_inst	*treehash;
67
68	u_int32_t	idx;		/* state read from file */
69	u_int32_t	maxidx;		/* restricted # of signatures */
70	int		have_state;	/* .state file exists */
71	int		lockfd;		/* locked in sshkey_xmss_get_state() */
72	int		allow_update;	/* allow sshkey_xmss_update_state() */
73	char		*enc_ciphername;/* encrypt state with cipher */
74	u_char		*enc_keyiv;	/* encrypt state with key */
75	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
76};
77
78int	 sshkey_xmss_init_bds_state(struct sshkey *);
79int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
80void	 sshkey_xmss_free_bds(struct sshkey *);
81int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
82	    int *, sshkey_printfn *);
83int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
84	    struct sshbuf **);
85int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
86	    struct sshbuf **);
87int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
88int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
89
/*
 * Forward a printf-style message to the optional callback "pr", which must
 * be in scope where the macro is expanded.  A NULL callback drops the
 * message.  Uses standard C99 __VA_ARGS__ instead of the GNU-only named
 * variadic form ("s..."), so it also builds in strict ISO C modes.
 */
#define PRINT(...) do { if (pr) pr(__VA_ARGS__); } while (0)
91
92int
93sshkey_xmss_init(struct sshkey *key, const char *name)
94{
95	struct ssh_xmss_state *state;
96
97	if (key->xmss_state != NULL)
98		return SSH_ERR_INVALID_FORMAT;
99	if (name == NULL)
100		return SSH_ERR_INVALID_FORMAT;
101	state = calloc(sizeof(struct ssh_xmss_state), 1);
102	if (state == NULL)
103		return SSH_ERR_ALLOC_FAIL;
104	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
105		state->n = 32;
106		state->w = 16;
107		state->h = 10;
108	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
109		state->n = 32;
110		state->w = 16;
111		state->h = 16;
112	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
113		state->n = 32;
114		state->w = 16;
115		state->h = 20;
116	} else {
117		free(state);
118		return SSH_ERR_KEY_TYPE_UNKNOWN;
119	}
120	if ((key->xmss_name = strdup(name)) == NULL) {
121		free(state);
122		return SSH_ERR_ALLOC_FAIL;
123	}
124	state->k = 2;	/* XXX hardcoded */
125	state->lockfd = -1;
126	if (xmss_set_params(&state->params, state->n, state->h, state->w,
127	    state->k) != 0) {
128		free(state);
129		return SSH_ERR_INVALID_FORMAT;
130	}
131	key->xmss_state = state;
132	return 0;
133}
134
135void
136sshkey_xmss_free_state(struct sshkey *key)
137{
138	struct ssh_xmss_state *state = key->xmss_state;
139
140	sshkey_xmss_free_bds(key);
141	if (state) {
142		if (state->enc_keyiv) {
143			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
144			free(state->enc_keyiv);
145		}
146		free(state->enc_ciphername);
147		free(state);
148	}
149	key->xmss_state = NULL;
150}
151
/* Tag written at the start of a serialized state; only k == 2 is used. */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Buffer sizes for the BDS state of an XMSS state "x".  x->n is the size
 * in bytes of one tree node (treehash nodes sit n bytes apart in
 * th_nodes), x->h the tree height and x->k the BDS parameter.
 * num_treehash() counts treehash_inst elements; the others are byte
 * counts.  Every argument use is parenthesised, so any pointer expression
 * (e.g. "&s" or "p + i") can be passed safely.
 */
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
160
161int
162sshkey_xmss_init_bds_state(struct sshkey *key)
163{
164	struct ssh_xmss_state *state = key->xmss_state;
165	u_int32_t i;
166
167	state->stackoffset = 0;
168	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
169	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
170	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
171	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
172	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
173	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
174	    (state->treehash = calloc(num_treehash(state),
175	    sizeof(treehash_inst))) == NULL) {
176		sshkey_xmss_free_bds(key);
177		return SSH_ERR_ALLOC_FAIL;
178	}
179	for (i = 0; i < state->h - state->k; i++)
180		state->treehash[i].node = &state->th_nodes[state->n*i];
181	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
182	    state->stacklevels, state->auth, state->keep, state->treehash,
183	    state->retain, 0);
184	return 0;
185}
186
187void
188sshkey_xmss_free_bds(struct sshkey *key)
189{
190	struct ssh_xmss_state *state = key->xmss_state;
191
192	if (state == NULL)
193		return;
194	free(state->stack);
195	free(state->stacklevels);
196	free(state->auth);
197	free(state->keep);
198	free(state->th_nodes);
199	free(state->retain);
200	free(state->treehash);
201	state->stack = NULL;
202	state->stacklevels = NULL;
203	state->auth = NULL;
204	state->keep = NULL;
205	state->th_nodes = NULL;
206	state->retain = NULL;
207	state->treehash = NULL;
208}
209
210void *
211sshkey_xmss_params(const struct sshkey *key)
212{
213	struct ssh_xmss_state *state = key->xmss_state;
214
215	if (state == NULL)
216		return NULL;
217	return &state->params;
218}
219
220void *
221sshkey_xmss_bds_state(const struct sshkey *key)
222{
223	struct ssh_xmss_state *state = key->xmss_state;
224
225	if (state == NULL)
226		return NULL;
227	return &state->bds;
228}
229
230int
231sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
232{
233	struct ssh_xmss_state *state = key->xmss_state;
234
235	if (lenp == NULL)
236		return SSH_ERR_INVALID_ARGUMENT;
237	if (state == NULL)
238		return SSH_ERR_INVALID_FORMAT;
239	*lenp = 4 + state->n +
240	    state->params.wots_par.keysize +
241	    state->h * state->n;
242	return 0;
243}
244
245size_t
246sshkey_xmss_pklen(const struct sshkey *key)
247{
248	struct ssh_xmss_state *state = key->xmss_state;
249
250	if (state == NULL)
251		return 0;
252	return state->n * 2;
253}
254
255size_t
256sshkey_xmss_sklen(const struct sshkey *key)
257{
258	struct ssh_xmss_state *state = key->xmss_state;
259
260	if (state == NULL)
261		return 0;
262	return state->n * 4 + 4;
263}
264
265int
266sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
267{
268	struct ssh_xmss_state *state = k->xmss_state;
269	const struct sshcipher *cipher;
270	size_t keylen = 0, ivlen = 0;
271
272	if (state == NULL)
273		return SSH_ERR_INVALID_ARGUMENT;
274	if ((cipher = cipher_by_name(ciphername)) == NULL)
275		return SSH_ERR_INTERNAL_ERROR;
276	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
277		return SSH_ERR_ALLOC_FAIL;
278	keylen = cipher_keylen(cipher);
279	ivlen = cipher_ivlen(cipher);
280	state->enc_keyiv_len = keylen + ivlen;
281	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
282		free(state->enc_ciphername);
283		state->enc_ciphername = NULL;
284		return SSH_ERR_ALLOC_FAIL;
285	}
286	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
287	return 0;
288}
289
290int
291sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
292{
293	struct ssh_xmss_state *state = k->xmss_state;
294	int r;
295
296	if (state == NULL || state->enc_keyiv == NULL ||
297	    state->enc_ciphername == NULL)
298		return SSH_ERR_INVALID_ARGUMENT;
299	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
300	    (r = sshbuf_put_string(b, state->enc_keyiv,
301	    state->enc_keyiv_len)) != 0)
302		return r;
303	return 0;
304}
305
306int
307sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
308{
309	struct ssh_xmss_state *state = k->xmss_state;
310	size_t len;
311	int r;
312
313	if (state == NULL)
314		return SSH_ERR_INVALID_ARGUMENT;
315	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
316	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
317		return r;
318	state->enc_keyiv_len = len;
319	return 0;
320}
321
322int
323sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
324    enum sshkey_serialize_rep opts)
325{
326	struct ssh_xmss_state *state = k->xmss_state;
327	u_char have_info = 1;
328	u_int32_t idx;
329	int r;
330
331	if (state == NULL)
332		return SSH_ERR_INVALID_ARGUMENT;
333	if (opts != SSHKEY_SERIALIZE_INFO)
334		return 0;
335	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
336	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
337	    (r = sshbuf_put_u32(b, idx)) != 0 ||
338	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
339		return r;
340	return 0;
341}
342
343int
344sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
345{
346	struct ssh_xmss_state *state = k->xmss_state;
347	u_char have_info;
348	int r;
349
350	if (state == NULL)
351		return SSH_ERR_INVALID_ARGUMENT;
352	/* optional */
353	if (sshbuf_len(b) == 0)
354		return 0;
355	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
356		return r;
357	if (have_info != 1)
358		return SSH_ERR_INVALID_ARGUMENT;
359	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
360	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
361		return r;
362	return 0;
363}
364
365int
366sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
367{
368	int r;
369	const char *name;
370
371	if (bits == 10) {
372		name = XMSS_SHA2_256_W16_H10_NAME;
373	} else if (bits == 16) {
374		name = XMSS_SHA2_256_W16_H16_NAME;
375	} else if (bits == 20) {
376		name = XMSS_SHA2_256_W16_H20_NAME;
377	} else {
378		name = XMSS_DEFAULT_NAME;
379	}
380	if ((r = sshkey_xmss_init(k, name)) != 0 ||
381	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
382	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
383		return r;
384	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
385	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
386		return SSH_ERR_ALLOC_FAIL;
387	}
388	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
389	    sshkey_xmss_params(k));
390	return 0;
391}
392
393int
394sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
395    int *have_file, sshkey_printfn *pr)
396{
397	struct sshbuf *b = NULL, *enc = NULL;
398	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
399	u_int32_t len;
400	unsigned char buf[4], *data = NULL;
401
402	*have_file = 0;
403	if ((fd = open(filename, O_RDONLY)) >= 0) {
404		*have_file = 1;
405		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
406			PRINT("%s: corrupt state file: %s", __func__, filename);
407			goto done;
408		}
409		len = PEEK_U32(buf);
410		if ((data = calloc(len, 1)) == NULL) {
411			ret = SSH_ERR_ALLOC_FAIL;
412			goto done;
413		}
414		if (atomicio(read, fd, data, len) != len) {
415			PRINT("%s: cannot read blob: %s", __func__, filename);
416			goto done;
417		}
418		if ((enc = sshbuf_from(data, len)) == NULL) {
419			ret = SSH_ERR_ALLOC_FAIL;
420			goto done;
421		}
422		sshkey_xmss_free_bds(k);
423		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
424			ret = r;
425			goto done;
426		}
427		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
428			ret = r;
429			goto done;
430		}
431		ret = 0;
432	}
433done:
434	if (fd != -1)
435		close(fd);
436	free(data);
437	sshbuf_free(enc);
438	sshbuf_free(b);
439	return ret;
440}
441
442int
443sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
444{
445	struct ssh_xmss_state *state = k->xmss_state;
446	u_int32_t idx = 0;
447	char *filename = NULL;
448	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
449	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
450	int ret = SSH_ERR_INVALID_ARGUMENT, r;
451
452	if (state == NULL)
453		goto done;
454	/*
455	 * If maxidx is set, then we are allowed a limited number
456	 * of signatures, but don't need to access the disk.
457	 * Otherwise we need to deal with the on-disk state.
458	 */
459	if (state->maxidx) {
460		/* xmss_sk always contains the current state */
461		idx = PEEK_U32(k->xmss_sk);
462		if (idx < state->maxidx) {
463			state->allow_update = 1;
464			return 0;
465		}
466		return SSH_ERR_INVALID_ARGUMENT;
467	}
468	if ((filename = k->xmss_filename) == NULL)
469		goto done;
470	if (asprintf(&lockfile, "%s.lock", filename) < 0 ||
471	    asprintf(&statefile, "%s.state", filename) < 0 ||
472	    asprintf(&ostatefile, "%s.ostate", filename) < 0) {
473		ret = SSH_ERR_ALLOC_FAIL;
474		goto done;
475	}
476	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) < 0) {
477		ret = SSH_ERR_SYSTEM_ERROR;
478		PRINT("%s: cannot open/create: %s", __func__, lockfile);
479		goto done;
480	}
481	while (flock(lockfd, LOCK_EX|LOCK_NB) < 0) {
482		if (errno != EWOULDBLOCK) {
483			ret = SSH_ERR_SYSTEM_ERROR;
484			PRINT("%s: cannot lock: %s", __func__, lockfile);
485			goto done;
486		}
487		if (++tries > 10) {
488			ret = SSH_ERR_SYSTEM_ERROR;
489			PRINT("%s: giving up on: %s", __func__, lockfile);
490			goto done;
491		}
492		usleep(1000*100*tries);
493	}
494	/* XXX no longer const */
495	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
496	    statefile, &have_state, pr)) != 0) {
497		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498		    ostatefile, &have_ostate, pr)) == 0) {
499			state->allow_update = 1;
500			r = sshkey_xmss_forward_state(k, 1);
501			state->idx = PEEK_U32(k->xmss_sk);
502			state->allow_update = 0;
503		}
504	}
505	if (!have_state && !have_ostate) {
506		/* check that bds state is initialized */
507		if (state->bds.auth == NULL)
508			goto done;
509		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
510	} else if (r != 0) {
511		ret = r;
512		goto done;
513	}
514	if (state->idx + 1 < state->idx) {
515		PRINT("%s: state wrap: %u", __func__, state->idx);
516		goto done;
517	}
518	state->have_state = have_state;
519	state->lockfd = lockfd;
520	state->allow_update = 1;
521	lockfd = -1;
522	ret = 0;
523done:
524	if (lockfd != -1)
525		close(lockfd);
526	free(lockfile);
527	free(statefile);
528	free(ostatefile);
529	return ret;
530}
531
532int
533sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
534{
535	struct ssh_xmss_state *state = k->xmss_state;
536	u_char *sig = NULL;
537	size_t required_siglen;
538	unsigned long long smlen;
539	u_char data;
540	int ret, r;
541
542	if (state == NULL || !state->allow_update)
543		return SSH_ERR_INVALID_ARGUMENT;
544	if (reserve == 0)
545		return SSH_ERR_INVALID_ARGUMENT;
546	if (state->idx + reserve <= state->idx)
547		return SSH_ERR_INVALID_ARGUMENT;
548	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
549		return r;
550	if ((sig = malloc(required_siglen)) == NULL)
551		return SSH_ERR_ALLOC_FAIL;
552	while (reserve-- > 0) {
553		state->idx = PEEK_U32(k->xmss_sk);
554		smlen = required_siglen;
555		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
556		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
557			r = SSH_ERR_INVALID_ARGUMENT;
558			break;
559		}
560	}
561	free(sig);
562	return r;
563}
564
565int
566sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
567{
568	struct ssh_xmss_state *state = k->xmss_state;
569	struct sshbuf *b = NULL, *enc = NULL;
570	u_int32_t idx = 0;
571	unsigned char buf[4];
572	char *filename = NULL;
573	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
574	int fd = -1;
575	int ret = SSH_ERR_INVALID_ARGUMENT;
576
577	if (state == NULL || !state->allow_update)
578		return ret;
579	if (state->maxidx) {
580		/* no update since the number of signatures is limited */
581		ret = 0;
582		goto done;
583	}
584	idx = PEEK_U32(k->xmss_sk);
585	if (idx == state->idx) {
586		/* no signature happened, no need to update */
587		ret = 0;
588		goto done;
589	} else if (idx != state->idx + 1) {
590		PRINT("%s: more than one signature happened: idx %u state %u",
591		     __func__, idx, state->idx);
592		goto done;
593	}
594	state->idx = idx;
595	if ((filename = k->xmss_filename) == NULL)
596		goto done;
597	if (asprintf(&statefile, "%s.state", filename) < 0 ||
598	    asprintf(&ostatefile, "%s.ostate", filename) < 0 ||
599	    asprintf(&nstatefile, "%s.nstate", filename) < 0) {
600		ret = SSH_ERR_ALLOC_FAIL;
601		goto done;
602	}
603	unlink(nstatefile);
604	if ((b = sshbuf_new()) == NULL) {
605		ret = SSH_ERR_ALLOC_FAIL;
606		goto done;
607	}
608	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
609		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
610		goto done;
611	}
612	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
613		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
614		goto done;
615	}
616	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) < 0) {
617		ret = SSH_ERR_SYSTEM_ERROR;
618		PRINT("%s: open new state file: %s", __func__, nstatefile);
619		goto done;
620	}
621	POKE_U32(buf, sshbuf_len(enc));
622	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
623		ret = SSH_ERR_SYSTEM_ERROR;
624		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
625		close(fd);
626		goto done;
627	}
628	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
629	    sshbuf_len(enc)) {
630		ret = SSH_ERR_SYSTEM_ERROR;
631		PRINT("%s: write new state file data: %s", __func__, nstatefile);
632		close(fd);
633		goto done;
634	}
635	if (fsync(fd) < 0) {
636		ret = SSH_ERR_SYSTEM_ERROR;
637		PRINT("%s: sync new state file: %s", __func__, nstatefile);
638		close(fd);
639		goto done;
640	}
641	if (close(fd) < 0) {
642		ret = SSH_ERR_SYSTEM_ERROR;
643		PRINT("%s: close new state file: %s", __func__, nstatefile);
644		goto done;
645	}
646	if (state->have_state) {
647		unlink(ostatefile);
648		if (link(statefile, ostatefile)) {
649			ret = SSH_ERR_SYSTEM_ERROR;
650			PRINT("%s: backup state %s to %s", __func__, statefile,
651			    ostatefile);
652			goto done;
653		}
654	}
655	if (rename(nstatefile, statefile) < 0) {
656		ret = SSH_ERR_SYSTEM_ERROR;
657		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
658		goto done;
659	}
660	ret = 0;
661done:
662	if (state->lockfd != -1) {
663		close(state->lockfd);
664		state->lockfd = -1;
665	}
666	if (nstatefile)
667		unlink(nstatefile);
668	free(statefile);
669	free(ostatefile);
670	free(nstatefile);
671	sshbuf_free(b);
672	sshbuf_free(enc);
673	return ret;
674}
675
676int
677sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
678{
679	struct ssh_xmss_state *state = k->xmss_state;
680	treehash_inst *th;
681	u_int32_t i, node;
682	int r;
683
684	if (state == NULL)
685		return SSH_ERR_INVALID_ARGUMENT;
686	if (state->stack == NULL)
687		return SSH_ERR_INVALID_ARGUMENT;
688	state->stackoffset = state->bds.stackoffset;	/* copy back */
689	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
690	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
691	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
692	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
693	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
694	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
695	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
696	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
697	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
698	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
699		return r;
700	for (i = 0; i < num_treehash(state); i++) {
701		th = &state->treehash[i];
702		node = th->node - state->th_nodes;
703		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
704		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
705		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
706		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
707		    (r = sshbuf_put_u32(b, node)) != 0)
708			return r;
709	}
710	return 0;
711}
712
713int
714sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
715    enum sshkey_serialize_rep opts)
716{
717	struct ssh_xmss_state *state = k->xmss_state;
718	int r = SSH_ERR_INVALID_ARGUMENT;
719
720	if (state == NULL)
721		return SSH_ERR_INVALID_ARGUMENT;
722	if ((r = sshbuf_put_u8(b, opts)) != 0)
723		return r;
724	switch (opts) {
725	case SSHKEY_SERIALIZE_STATE:
726		r = sshkey_xmss_serialize_state(k, b);
727		break;
728	case SSHKEY_SERIALIZE_FULL:
729		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
730			break;
731		r = sshkey_xmss_serialize_state(k, b);
732		break;
733	case SSHKEY_SERIALIZE_DEFAULT:
734		r = 0;
735		break;
736	default:
737		r = SSH_ERR_INVALID_ARGUMENT;
738		break;
739	}
740	return r;
741}
742
743int
744sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
745{
746	struct ssh_xmss_state *state = k->xmss_state;
747	treehash_inst *th;
748	u_int32_t i, lh, node;
749	size_t ls, lsl, la, lk, ln, lr;
750	char *magic;
751	int r;
752
753	if (state == NULL)
754		return SSH_ERR_INVALID_ARGUMENT;
755	if (k->xmss_sk == NULL)
756		return SSH_ERR_INVALID_ARGUMENT;
757	if ((state->treehash = calloc(num_treehash(state),
758	    sizeof(treehash_inst))) == NULL)
759		return SSH_ERR_ALLOC_FAIL;
760	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
761	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
762	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
763	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
764	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
765	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
766	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
767	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
768	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
769	    (r = sshbuf_get_u32(b, &lh)) != 0)
770		return r;
771	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0)
772		return SSH_ERR_INVALID_ARGUMENT;
773	/* XXX check stackoffset */
774	if (ls != num_stack(state) ||
775	    lsl != num_stacklevels(state) ||
776	    la != num_auth(state) ||
777	    lk != num_keep(state) ||
778	    ln != num_th_nodes(state) ||
779	    lr != num_retain(state) ||
780	    lh != num_treehash(state))
781		return SSH_ERR_INVALID_ARGUMENT;
782	for (i = 0; i < num_treehash(state); i++) {
783		th = &state->treehash[i];
784		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
785		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
786		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
787		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
788		    (r = sshbuf_get_u32(b, &node)) != 0)
789			return r;
790		if (node < num_th_nodes(state))
791			th->node = &state->th_nodes[node];
792	}
793	POKE_U32(k->xmss_sk, state->idx);
794	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
795	    state->stacklevels, state->auth, state->keep, state->treehash,
796	    state->retain, 0);
797	return 0;
798}
799
800int
801sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
802{
803	enum sshkey_serialize_rep opts;
804	u_char have_state;
805	int r;
806
807	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
808		return r;
809
810	opts = have_state;
811	switch (opts) {
812	case SSHKEY_SERIALIZE_DEFAULT:
813		r = 0;
814		break;
815	case SSHKEY_SERIALIZE_STATE:
816		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
817			return r;
818		break;
819	case SSHKEY_SERIALIZE_FULL:
820		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
821		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
822			return r;
823		break;
824	default:
825		r = SSH_ERR_INVALID_FORMAT;
826		break;
827	}
828	return r;
829}
830
831int
832sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
833   struct sshbuf **retp)
834{
835	struct ssh_xmss_state *state = k->xmss_state;
836	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
837	struct sshcipher_ctx *ciphercontext = NULL;
838	const struct sshcipher *cipher;
839	u_char *cp, *key, *iv = NULL;
840	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
841	int r = SSH_ERR_INTERNAL_ERROR;
842
843	if (retp != NULL)
844		*retp = NULL;
845	if (state == NULL ||
846	    state->enc_keyiv == NULL ||
847	    state->enc_ciphername == NULL)
848		return SSH_ERR_INTERNAL_ERROR;
849	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
850		r = SSH_ERR_INTERNAL_ERROR;
851		goto out;
852	}
853	blocksize = cipher_blocksize(cipher);
854	keylen = cipher_keylen(cipher);
855	ivlen = cipher_ivlen(cipher);
856	authlen = cipher_authlen(cipher);
857	if (state->enc_keyiv_len != keylen + ivlen) {
858		r = SSH_ERR_INVALID_FORMAT;
859		goto out;
860	}
861	key = state->enc_keyiv;
862	if ((encrypted = sshbuf_new()) == NULL ||
863	    (encoded = sshbuf_new()) == NULL ||
864	    (padded = sshbuf_new()) == NULL ||
865	    (iv = malloc(ivlen)) == NULL) {
866		r = SSH_ERR_ALLOC_FAIL;
867		goto out;
868	}
869
870	/* replace first 4 bytes of IV with index to ensure uniqueness */
871	memcpy(iv, key + keylen, ivlen);
872	POKE_U32(iv, state->idx);
873
874	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
875	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
876		goto out;
877
878	/* padded state will be encrypted */
879	if ((r = sshbuf_putb(padded, b)) != 0)
880		goto out;
881	i = 0;
882	while (sshbuf_len(padded) % blocksize) {
883		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
884			goto out;
885	}
886	encrypted_len = sshbuf_len(padded);
887
888	/* header including the length of state is used as AAD */
889	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
890		goto out;
891	aadlen = sshbuf_len(encoded);
892
893	/* concat header and state */
894	if ((r = sshbuf_putb(encoded, padded)) != 0)
895		goto out;
896
897	/* reserve space for encryption of encoded data plus auth tag */
898	/* encrypt at offset addlen */
899	if ((r = sshbuf_reserve(encrypted,
900	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
901	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
902	    iv, ivlen, 1)) != 0 ||
903	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
904	    encrypted_len, aadlen, authlen)) != 0)
905		goto out;
906
907	/* success */
908	r = 0;
909 out:
910	if (retp != NULL) {
911		*retp = encrypted;
912		encrypted = NULL;
913	}
914	sshbuf_free(padded);
915	sshbuf_free(encoded);
916	sshbuf_free(encrypted);
917	cipher_free(ciphercontext);
918	free(iv);
919	return r;
920}
921
922int
923sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
924   struct sshbuf **retp)
925{
926	struct ssh_xmss_state *state = k->xmss_state;
927	struct sshbuf *copy = NULL, *decrypted = NULL;
928	struct sshcipher_ctx *ciphercontext = NULL;
929	const struct sshcipher *cipher = NULL;
930	u_char *key, *iv = NULL, *dp;
931	size_t keylen, ivlen, authlen, aadlen;
932	u_int blocksize, encrypted_len, index;
933	int r = SSH_ERR_INTERNAL_ERROR;
934
935	if (retp != NULL)
936		*retp = NULL;
937	if (state == NULL ||
938	    state->enc_keyiv == NULL ||
939	    state->enc_ciphername == NULL)
940		return SSH_ERR_INTERNAL_ERROR;
941	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
942		r = SSH_ERR_INVALID_FORMAT;
943		goto out;
944	}
945	blocksize = cipher_blocksize(cipher);
946	keylen = cipher_keylen(cipher);
947	ivlen = cipher_ivlen(cipher);
948	authlen = cipher_authlen(cipher);
949	if (state->enc_keyiv_len != keylen + ivlen) {
950		r = SSH_ERR_INTERNAL_ERROR;
951		goto out;
952	}
953	key = state->enc_keyiv;
954
955	if ((copy = sshbuf_fromb(encoded)) == NULL ||
956	    (decrypted = sshbuf_new()) == NULL ||
957	    (iv = malloc(ivlen)) == NULL) {
958		r = SSH_ERR_ALLOC_FAIL;
959		goto out;
960	}
961
962	/* check magic */
963	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
964	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
965		r = SSH_ERR_INVALID_FORMAT;
966		goto out;
967	}
968	/* parse public portion */
969	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
970	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
971	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
972		goto out;
973
974	/* check size of encrypted key blob */
975	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
976		r = SSH_ERR_INVALID_FORMAT;
977		goto out;
978	}
979	/* check that an appropriate amount of auth data is present */
980	if (sshbuf_len(encoded) < encrypted_len + authlen) {
981		r = SSH_ERR_INVALID_FORMAT;
982		goto out;
983	}
984
985	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
986
987	/* replace first 4 bytes of IV with index to ensure uniqueness */
988	memcpy(iv, key + keylen, ivlen);
989	POKE_U32(iv, index);
990
991	/* decrypt private state of key */
992	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
993	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
994	    iv, ivlen, 0)) != 0 ||
995	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
996	    encrypted_len, aadlen, authlen)) != 0)
997		goto out;
998
999	/* there should be no trailing data */
1000	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1001		goto out;
1002	if (sshbuf_len(encoded) != 0) {
1003		r = SSH_ERR_INVALID_FORMAT;
1004		goto out;
1005	}
1006
1007	/* remove AAD */
1008	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1009		goto out;
1010	/* XXX encrypted includes unchecked padding */
1011
1012	/* success */
1013	r = 0;
1014	if (retp != NULL) {
1015		*retp = decrypted;
1016		decrypted = NULL;
1017	}
1018 out:
1019	cipher_free(ciphercontext);
1020	sshbuf_free(copy);
1021	sshbuf_free(decrypted);
1022	free(iv);
1023	return r;
1024}
1025
1026u_int32_t
1027sshkey_xmss_signatures_left(const struct sshkey *k)
1028{
1029	struct ssh_xmss_state *state = k->xmss_state;
1030	u_int32_t idx;
1031
1032	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1033	    state->maxidx) {
1034		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1035		if (idx < state->maxidx)
1036			return state->maxidx - idx;
1037	}
1038	return 0;
1039}
1040
1041int
1042sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1043{
1044	struct ssh_xmss_state *state = k->xmss_state;
1045
1046	if (sshkey_type_plain(k->type) != KEY_XMSS)
1047		return SSH_ERR_INVALID_ARGUMENT;
1048	if (maxsign == 0)
1049		return 0;
1050	if (state->idx + maxsign < state->idx)
1051		return SSH_ERR_INVALID_ARGUMENT;
1052	state->maxidx = state->idx + maxsign;
1053	return 0;
1054}
1055#endif /* WITH_XMSS */
1056