1/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2/*
3 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include <sys/types.h>
27#include <sys/uio.h>
28
29#include <stdio.h>
30#include <string.h>
31#include <unistd.h>
32#include <fcntl.h>
33#include <errno.h>
34
35#include "ssh2.h"
36#include "ssherr.h"
37#include "sshbuf.h"
38#include "cipher.h"
39#include "sshkey.h"
40#include "sshkey-xmss.h"
41#include "atomicio.h"
42#include "log.h"
43
44#include "xmss_fast.h"
45
/*
 * Opaque internal XMSS state, attached to struct sshkey as xmss_state.
 *
 * XMSS_MAGIC prefixes every encrypted state blob written by
 * sshkey_xmss_encrypt_state(); it is emitted with sizeof(), so its
 * terminating NUL is part of the on-disk magic.  XMSS_CIPHERNAME is the
 * cipher used to protect that blob for freshly generated keys.
 */
#define XMSS_MAGIC		"xmss-state-v1"
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	xmss_params	params;		/* filled in by xmss_set_params() */
	u_int32_t	n, w, h, k;	/* n, w, h from key name; k fixed at 2 */

	/*
	 * BDS traversal state.  bds is wired to the buffers below with
	 * xmss_set_bds_state(); the buffers are sized by the num_*()
	 * macros and released by sshkey_xmss_free_bds().
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;	/* copy of bds.stackoffset for (de)serialization */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* backing store for treehash[i].node */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
72
/* Internal helpers shared by the functions in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *key);
int	 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername);
void	 sshkey_xmss_free_bds(struct sshkey *key);
int	 sshkey_xmss_get_state_from_file(struct sshkey *k,
	    const char *filename, int *have_file, int printerror);
int	 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
	    struct sshbuf **retp);
int	 sshkey_xmss_decrypt_state(const struct sshkey *k,
	    struct sshbuf *encoded, struct sshbuf **retp);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *k,
	    struct sshbuf *b);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b);

/*
 * Log an error-level message, but only if the calling function's
 * `printerror' argument/variable is non-zero.
 */
#define PRINT(...) do {							\
	if (printerror)							\
		sshlog(__FILE__, __func__, __LINE__, 0,			\
		    SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__);		\
} while (0)
87
88int
89sshkey_xmss_init(struct sshkey *key, const char *name)
90{
91	struct ssh_xmss_state *state;
92
93	if (key->xmss_state != NULL)
94		return SSH_ERR_INVALID_FORMAT;
95	if (name == NULL)
96		return SSH_ERR_INVALID_FORMAT;
97	state = calloc(sizeof(struct ssh_xmss_state), 1);
98	if (state == NULL)
99		return SSH_ERR_ALLOC_FAIL;
100	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
101		state->n = 32;
102		state->w = 16;
103		state->h = 10;
104	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
105		state->n = 32;
106		state->w = 16;
107		state->h = 16;
108	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
109		state->n = 32;
110		state->w = 16;
111		state->h = 20;
112	} else {
113		free(state);
114		return SSH_ERR_KEY_TYPE_UNKNOWN;
115	}
116	if ((key->xmss_name = strdup(name)) == NULL) {
117		free(state);
118		return SSH_ERR_ALLOC_FAIL;
119	}
120	state->k = 2;	/* XXX hardcoded */
121	state->lockfd = -1;
122	if (xmss_set_params(&state->params, state->n, state->h, state->w,
123	    state->k) != 0) {
124		free(state);
125		return SSH_ERR_INVALID_FORMAT;
126	}
127	key->xmss_state = state;
128	return 0;
129}
130
131void
132sshkey_xmss_free_state(struct sshkey *key)
133{
134	struct ssh_xmss_state *state = key->xmss_state;
135
136	sshkey_xmss_free_bds(key);
137	if (state) {
138		if (state->enc_keyiv) {
139			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
140			free(state->enc_keyiv);
141		}
142		free(state->enc_ciphername);
143		free(state);
144	}
145	key->xmss_state = NULL;
146}
147
/*
 * SSH_XMSS_K2_MAGIC tags a serialized BDS state (k = 2).
 *
 * The num_*() macros give the size of each BDS traversal buffer for a
 * state `x' with members h (tree height), n and k: bytes for the u_char
 * buffers, and the number of treehash_inst entries for num_treehash().
 * The argument is parenthesised so any pointer expression is safe.
 */
#define SSH_XMSS_K2_MAGIC	"k=2"
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
156
157int
158sshkey_xmss_init_bds_state(struct sshkey *key)
159{
160	struct ssh_xmss_state *state = key->xmss_state;
161	u_int32_t i;
162
163	state->stackoffset = 0;
164	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
165	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
166	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
167	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
168	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
169	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
170	    (state->treehash = calloc(num_treehash(state),
171	    sizeof(treehash_inst))) == NULL) {
172		sshkey_xmss_free_bds(key);
173		return SSH_ERR_ALLOC_FAIL;
174	}
175	for (i = 0; i < state->h - state->k; i++)
176		state->treehash[i].node = &state->th_nodes[state->n*i];
177	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
178	    state->stacklevels, state->auth, state->keep, state->treehash,
179	    state->retain, 0);
180	return 0;
181}
182
183void
184sshkey_xmss_free_bds(struct sshkey *key)
185{
186	struct ssh_xmss_state *state = key->xmss_state;
187
188	if (state == NULL)
189		return;
190	free(state->stack);
191	free(state->stacklevels);
192	free(state->auth);
193	free(state->keep);
194	free(state->th_nodes);
195	free(state->retain);
196	free(state->treehash);
197	state->stack = NULL;
198	state->stacklevels = NULL;
199	state->auth = NULL;
200	state->keep = NULL;
201	state->th_nodes = NULL;
202	state->retain = NULL;
203	state->treehash = NULL;
204}
205
206void *
207sshkey_xmss_params(const struct sshkey *key)
208{
209	struct ssh_xmss_state *state = key->xmss_state;
210
211	if (state == NULL)
212		return NULL;
213	return &state->params;
214}
215
216void *
217sshkey_xmss_bds_state(const struct sshkey *key)
218{
219	struct ssh_xmss_state *state = key->xmss_state;
220
221	if (state == NULL)
222		return NULL;
223	return &state->bds;
224}
225
226int
227sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
228{
229	struct ssh_xmss_state *state = key->xmss_state;
230
231	if (lenp == NULL)
232		return SSH_ERR_INVALID_ARGUMENT;
233	if (state == NULL)
234		return SSH_ERR_INVALID_FORMAT;
235	*lenp = 4 + state->n +
236	    state->params.wots_par.keysize +
237	    state->h * state->n;
238	return 0;
239}
240
241size_t
242sshkey_xmss_pklen(const struct sshkey *key)
243{
244	struct ssh_xmss_state *state = key->xmss_state;
245
246	if (state == NULL)
247		return 0;
248	return state->n * 2;
249}
250
251size_t
252sshkey_xmss_sklen(const struct sshkey *key)
253{
254	struct ssh_xmss_state *state = key->xmss_state;
255
256	if (state == NULL)
257		return 0;
258	return state->n * 4 + 4;
259}
260
261int
262sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
263{
264	struct ssh_xmss_state *state = k->xmss_state;
265	const struct sshcipher *cipher;
266	size_t keylen = 0, ivlen = 0;
267
268	if (state == NULL)
269		return SSH_ERR_INVALID_ARGUMENT;
270	if ((cipher = cipher_by_name(ciphername)) == NULL)
271		return SSH_ERR_INTERNAL_ERROR;
272	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
273		return SSH_ERR_ALLOC_FAIL;
274	keylen = cipher_keylen(cipher);
275	ivlen = cipher_ivlen(cipher);
276	state->enc_keyiv_len = keylen + ivlen;
277	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
278		free(state->enc_ciphername);
279		state->enc_ciphername = NULL;
280		return SSH_ERR_ALLOC_FAIL;
281	}
282	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
283	return 0;
284}
285
286int
287sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
288{
289	struct ssh_xmss_state *state = k->xmss_state;
290	int r;
291
292	if (state == NULL || state->enc_keyiv == NULL ||
293	    state->enc_ciphername == NULL)
294		return SSH_ERR_INVALID_ARGUMENT;
295	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
296	    (r = sshbuf_put_string(b, state->enc_keyiv,
297	    state->enc_keyiv_len)) != 0)
298		return r;
299	return 0;
300}
301
302int
303sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
304{
305	struct ssh_xmss_state *state = k->xmss_state;
306	size_t len;
307	int r;
308
309	if (state == NULL)
310		return SSH_ERR_INVALID_ARGUMENT;
311	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
312	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
313		return r;
314	state->enc_keyiv_len = len;
315	return 0;
316}
317
318int
319sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
320    enum sshkey_serialize_rep opts)
321{
322	struct ssh_xmss_state *state = k->xmss_state;
323	u_char have_info = 1;
324	u_int32_t idx;
325	int r;
326
327	if (state == NULL)
328		return SSH_ERR_INVALID_ARGUMENT;
329	if (opts != SSHKEY_SERIALIZE_INFO)
330		return 0;
331	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
332	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
333	    (r = sshbuf_put_u32(b, idx)) != 0 ||
334	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
335		return r;
336	return 0;
337}
338
339int
340sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
341{
342	struct ssh_xmss_state *state = k->xmss_state;
343	u_char have_info;
344	int r;
345
346	if (state == NULL)
347		return SSH_ERR_INVALID_ARGUMENT;
348	/* optional */
349	if (sshbuf_len(b) == 0)
350		return 0;
351	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
352		return r;
353	if (have_info != 1)
354		return SSH_ERR_INVALID_ARGUMENT;
355	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
356	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
357		return r;
358	return 0;
359}
360
361int
362sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
363{
364	int r;
365	const char *name;
366
367	if (bits == 10) {
368		name = XMSS_SHA2_256_W16_H10_NAME;
369	} else if (bits == 16) {
370		name = XMSS_SHA2_256_W16_H16_NAME;
371	} else if (bits == 20) {
372		name = XMSS_SHA2_256_W16_H20_NAME;
373	} else {
374		name = XMSS_DEFAULT_NAME;
375	}
376	if ((r = sshkey_xmss_init(k, name)) != 0 ||
377	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
378	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
379		return r;
380	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
381	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
382		return SSH_ERR_ALLOC_FAIL;
383	}
384	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
385	    sshkey_xmss_params(k));
386	return 0;
387}
388
389int
390sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
391    int *have_file, int printerror)
392{
393	struct sshbuf *b = NULL, *enc = NULL;
394	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
395	u_int32_t len;
396	unsigned char buf[4], *data = NULL;
397
398	*have_file = 0;
399	if ((fd = open(filename, O_RDONLY)) >= 0) {
400		*have_file = 1;
401		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
402			PRINT("corrupt state file: %s", filename);
403			goto done;
404		}
405		len = PEEK_U32(buf);
406		if ((data = calloc(len, 1)) == NULL) {
407			ret = SSH_ERR_ALLOC_FAIL;
408			goto done;
409		}
410		if (atomicio(read, fd, data, len) != len) {
411			PRINT("cannot read blob: %s", filename);
412			goto done;
413		}
414		if ((enc = sshbuf_from(data, len)) == NULL) {
415			ret = SSH_ERR_ALLOC_FAIL;
416			goto done;
417		}
418		sshkey_xmss_free_bds(k);
419		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
420			ret = r;
421			goto done;
422		}
423		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
424			ret = r;
425			goto done;
426		}
427		ret = 0;
428	}
429done:
430	if (fd != -1)
431		close(fd);
432	free(data);
433	sshbuf_free(enc);
434	sshbuf_free(b);
435	return ret;
436}
437
438int
439sshkey_xmss_get_state(const struct sshkey *k, int printerror)
440{
441	struct ssh_xmss_state *state = k->xmss_state;
442	u_int32_t idx = 0;
443	char *filename = NULL;
444	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
445	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
446	int ret = SSH_ERR_INVALID_ARGUMENT, r;
447
448	if (state == NULL)
449		goto done;
450	/*
451	 * If maxidx is set, then we are allowed a limited number
452	 * of signatures, but don't need to access the disk.
453	 * Otherwise we need to deal with the on-disk state.
454	 */
455	if (state->maxidx) {
456		/* xmss_sk always contains the current state */
457		idx = PEEK_U32(k->xmss_sk);
458		if (idx < state->maxidx) {
459			state->allow_update = 1;
460			return 0;
461		}
462		return SSH_ERR_INVALID_ARGUMENT;
463	}
464	if ((filename = k->xmss_filename) == NULL)
465		goto done;
466	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
467	    asprintf(&statefile, "%s.state", filename) == -1 ||
468	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
469		ret = SSH_ERR_ALLOC_FAIL;
470		goto done;
471	}
472	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
473		ret = SSH_ERR_SYSTEM_ERROR;
474		PRINT("cannot open/create: %s", lockfile);
475		goto done;
476	}
477	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
478		if (errno != EWOULDBLOCK) {
479			ret = SSH_ERR_SYSTEM_ERROR;
480			PRINT("cannot lock: %s", lockfile);
481			goto done;
482		}
483		if (++tries > 10) {
484			ret = SSH_ERR_SYSTEM_ERROR;
485			PRINT("giving up on: %s", lockfile);
486			goto done;
487		}
488		usleep(1000*100*tries);
489	}
490	/* XXX no longer const */
491	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
492	    statefile, &have_state, printerror)) != 0) {
493		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
494		    ostatefile, &have_ostate, printerror)) == 0) {
495			state->allow_update = 1;
496			r = sshkey_xmss_forward_state(k, 1);
497			state->idx = PEEK_U32(k->xmss_sk);
498			state->allow_update = 0;
499		}
500	}
501	if (!have_state && !have_ostate) {
502		/* check that bds state is initialized */
503		if (state->bds.auth == NULL)
504			goto done;
505		PRINT("start from scratch idx 0: %u", state->idx);
506	} else if (r != 0) {
507		ret = r;
508		goto done;
509	}
510	if (state->idx + 1 < state->idx) {
511		PRINT("state wrap: %u", state->idx);
512		goto done;
513	}
514	state->have_state = have_state;
515	state->lockfd = lockfd;
516	state->allow_update = 1;
517	lockfd = -1;
518	ret = 0;
519done:
520	if (lockfd != -1)
521		close(lockfd);
522	free(lockfile);
523	free(statefile);
524	free(ostatefile);
525	return ret;
526}
527
528int
529sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
530{
531	struct ssh_xmss_state *state = k->xmss_state;
532	u_char *sig = NULL;
533	size_t required_siglen;
534	unsigned long long smlen;
535	u_char data;
536	int ret, r;
537
538	if (state == NULL || !state->allow_update)
539		return SSH_ERR_INVALID_ARGUMENT;
540	if (reserve == 0)
541		return SSH_ERR_INVALID_ARGUMENT;
542	if (state->idx + reserve <= state->idx)
543		return SSH_ERR_INVALID_ARGUMENT;
544	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
545		return r;
546	if ((sig = malloc(required_siglen)) == NULL)
547		return SSH_ERR_ALLOC_FAIL;
548	while (reserve-- > 0) {
549		state->idx = PEEK_U32(k->xmss_sk);
550		smlen = required_siglen;
551		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
552		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
553			r = SSH_ERR_INVALID_ARGUMENT;
554			break;
555		}
556	}
557	free(sig);
558	return r;
559}
560
561int
562sshkey_xmss_update_state(const struct sshkey *k, int printerror)
563{
564	struct ssh_xmss_state *state = k->xmss_state;
565	struct sshbuf *b = NULL, *enc = NULL;
566	u_int32_t idx = 0;
567	unsigned char buf[4];
568	char *filename = NULL;
569	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
570	int fd = -1;
571	int ret = SSH_ERR_INVALID_ARGUMENT;
572
573	if (state == NULL || !state->allow_update)
574		return ret;
575	if (state->maxidx) {
576		/* no update since the number of signatures is limited */
577		ret = 0;
578		goto done;
579	}
580	idx = PEEK_U32(k->xmss_sk);
581	if (idx == state->idx) {
582		/* no signature happened, no need to update */
583		ret = 0;
584		goto done;
585	} else if (idx != state->idx + 1) {
586		PRINT("more than one signature happened: idx %u state %u",
587		    idx, state->idx);
588		goto done;
589	}
590	state->idx = idx;
591	if ((filename = k->xmss_filename) == NULL)
592		goto done;
593	if (asprintf(&statefile, "%s.state", filename) == -1 ||
594	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
595	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
596		ret = SSH_ERR_ALLOC_FAIL;
597		goto done;
598	}
599	unlink(nstatefile);
600	if ((b = sshbuf_new()) == NULL) {
601		ret = SSH_ERR_ALLOC_FAIL;
602		goto done;
603	}
604	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
605		PRINT("SERLIALIZE FAILED: %d", ret);
606		goto done;
607	}
608	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
609		PRINT("ENCRYPT FAILED: %d", ret);
610		goto done;
611	}
612	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
613		ret = SSH_ERR_SYSTEM_ERROR;
614		PRINT("open new state file: %s", nstatefile);
615		goto done;
616	}
617	POKE_U32(buf, sshbuf_len(enc));
618	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
619		ret = SSH_ERR_SYSTEM_ERROR;
620		PRINT("write new state file hdr: %s", nstatefile);
621		close(fd);
622		goto done;
623	}
624	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
625	    sshbuf_len(enc)) {
626		ret = SSH_ERR_SYSTEM_ERROR;
627		PRINT("write new state file data: %s", nstatefile);
628		close(fd);
629		goto done;
630	}
631	if (fsync(fd) == -1) {
632		ret = SSH_ERR_SYSTEM_ERROR;
633		PRINT("sync new state file: %s", nstatefile);
634		close(fd);
635		goto done;
636	}
637	if (close(fd) == -1) {
638		ret = SSH_ERR_SYSTEM_ERROR;
639		PRINT("close new state file: %s", nstatefile);
640		goto done;
641	}
642	if (state->have_state) {
643		unlink(ostatefile);
644		if (link(statefile, ostatefile)) {
645			ret = SSH_ERR_SYSTEM_ERROR;
646			PRINT("backup state %s to %s", statefile, ostatefile);
647			goto done;
648		}
649	}
650	if (rename(nstatefile, statefile) == -1) {
651		ret = SSH_ERR_SYSTEM_ERROR;
652		PRINT("rename %s to %s", nstatefile, statefile);
653		goto done;
654	}
655	ret = 0;
656done:
657	if (state->lockfd != -1) {
658		close(state->lockfd);
659		state->lockfd = -1;
660	}
661	if (nstatefile)
662		unlink(nstatefile);
663	free(statefile);
664	free(ostatefile);
665	free(nstatefile);
666	sshbuf_free(b);
667	sshbuf_free(enc);
668	return ret;
669}
670
/*
 * Append key k's BDS traversal state to b, in this order:
 *	cstring	SSH_XMSS_K2_MAGIC
 *	u32	idx
 *	string	stack
 *	u32	stackoffset (first copied back from bds.stackoffset)
 *	string	stacklevels, auth, keep, th_nodes, retain
 *	u32	number of treehash instances, then for each instance:
 *		u32 h, u32 next_idx, u32 stackusage, u8 completed,
 *		u32 byte offset of its node within th_nodes
 * sshkey_xmss_deserialize_state() parses the same layout.  Returns
 * SSH_ERR_INVALID_ARGUMENT if there is no state or no BDS buffers.
 */
int
sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *state = k->xmss_state;
	treehash_inst *th;
	u_int32_t i, node;
	int r;

	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (state->stack == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	state->stackoffset = state->bds.stackoffset;	/* copy back */
	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
		return r;
	for (i = 0; i < num_treehash(state); i++) {
		th = &state->treehash[i];
		/* NOTE(review): assumes th->node is non-NULL and inside th_nodes */
		node = th->node - state->th_nodes;
		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
		    (r = sshbuf_put_u32(b, node)) != 0)
			return r;
	}
	return 0;
}
707
/*
 * Append key k's XMSS state to b according to opts, preceded by opts
 * itself as a single byte:
 *	STATE	- the BDS state (sshkey_xmss_serialize_state())
 *	FULL	- the state-encryption cipher/key, then the BDS state
 *	SHIELD	- each of BDS state (state->idx refreshed from xmss_sk
 *		  first), filename and encryption key behind a presence
 *		  byte, followed by u32 maxidx and u8 allow_update
 *	DEFAULT	- nothing further
 * Any other opts value is still written, then SSH_ERR_INVALID_ARGUMENT
 * is returned.
 */
int
sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
    enum sshkey_serialize_rep opts)
{
	struct ssh_xmss_state *state = k->xmss_state;
	int r = SSH_ERR_INVALID_ARGUMENT;
	u_char have_stack, have_filename, have_enc;

	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((r = sshbuf_put_u8(b, opts)) != 0)
		return r;
	switch (opts) {
	case SSHKEY_SERIALIZE_STATE:
		r = sshkey_xmss_serialize_state(k, b);
		break;
	case SSHKEY_SERIALIZE_FULL:
		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
			return r;
		r = sshkey_xmss_serialize_state(k, b);
		break;
	case SSHKEY_SERIALIZE_SHIELD:
		/* all of stack/filename/enc are optional */
		have_stack = state->stack != NULL;
		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
			return r;
		if (have_stack) {
			state->idx = PEEK_U32(k->xmss_sk);	/* update */
			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
				return r;
		}
		have_filename = k->xmss_filename != NULL;
		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
			return r;
		if (have_filename &&
		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
			return r;
		have_enc = state->enc_keyiv != NULL;
		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
			return r;
		if (have_enc &&
		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
			return r;
		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
			return r;
		break;
	case SSHKEY_SERIALIZE_DEFAULT:
		r = 0;
		break;
	default:
		r = SSH_ERR_INVALID_ARGUMENT;
		break;
	}
	return r;
}
764
765int
766sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
767{
768	struct ssh_xmss_state *state = k->xmss_state;
769	treehash_inst *th;
770	u_int32_t i, lh, node;
771	size_t ls, lsl, la, lk, ln, lr;
772	char *magic;
773	int r = SSH_ERR_INTERNAL_ERROR;
774
775	if (state == NULL)
776		return SSH_ERR_INVALID_ARGUMENT;
777	if (k->xmss_sk == NULL)
778		return SSH_ERR_INVALID_ARGUMENT;
779	if ((state->treehash = calloc(num_treehash(state),
780	    sizeof(treehash_inst))) == NULL)
781		return SSH_ERR_ALLOC_FAIL;
782	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
783	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
784	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
785	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
786	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
787	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
788	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
789	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
790	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
791	    (r = sshbuf_get_u32(b, &lh)) != 0)
792		goto out;
793	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
794		r = SSH_ERR_INVALID_ARGUMENT;
795		goto out;
796	}
797	/* XXX check stackoffset */
798	if (ls != num_stack(state) ||
799	    lsl != num_stacklevels(state) ||
800	    la != num_auth(state) ||
801	    lk != num_keep(state) ||
802	    ln != num_th_nodes(state) ||
803	    lr != num_retain(state) ||
804	    lh != num_treehash(state)) {
805		r = SSH_ERR_INVALID_ARGUMENT;
806		goto out;
807	}
808	for (i = 0; i < num_treehash(state); i++) {
809		th = &state->treehash[i];
810		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
811		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
812		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
813		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
814		    (r = sshbuf_get_u32(b, &node)) != 0)
815			goto out;
816		if (node < num_th_nodes(state))
817			th->node = &state->th_nodes[node];
818	}
819	POKE_U32(k->xmss_sk, state->idx);
820	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
821	    state->stacklevels, state->auth, state->keep, state->treehash,
822	    state->retain, 0);
823	/* success */
824	r = 0;
825 out:
826	free(magic);
827	return r;
828}
829
830int
831sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
832{
833	struct ssh_xmss_state *state = k->xmss_state;
834	enum sshkey_serialize_rep opts;
835	u_char have_state, have_stack, have_filename, have_enc;
836	int r;
837
838	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
839		return r;
840
841	opts = have_state;
842	switch (opts) {
843	case SSHKEY_SERIALIZE_DEFAULT:
844		r = 0;
845		break;
846	case SSHKEY_SERIALIZE_SHIELD:
847		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
848			return r;
849		if (have_stack &&
850		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
851			return r;
852		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
853			return r;
854		if (have_filename &&
855		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
856			return r;
857		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
858			return r;
859		if (have_enc &&
860		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
861			return r;
862		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
863		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
864			return r;
865		break;
866	case SSHKEY_SERIALIZE_STATE:
867		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
868			return r;
869		break;
870	case SSHKEY_SERIALIZE_FULL:
871		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
872		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
873			return r;
874		break;
875	default:
876		r = SSH_ERR_INVALID_FORMAT;
877		break;
878	}
879	return r;
880}
881
882int
883sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
884   struct sshbuf **retp)
885{
886	struct ssh_xmss_state *state = k->xmss_state;
887	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
888	struct sshcipher_ctx *ciphercontext = NULL;
889	const struct sshcipher *cipher;
890	u_char *cp, *key, *iv = NULL;
891	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
892	int r = SSH_ERR_INTERNAL_ERROR;
893
894	if (retp != NULL)
895		*retp = NULL;
896	if (state == NULL ||
897	    state->enc_keyiv == NULL ||
898	    state->enc_ciphername == NULL)
899		return SSH_ERR_INTERNAL_ERROR;
900	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
901		r = SSH_ERR_INTERNAL_ERROR;
902		goto out;
903	}
904	blocksize = cipher_blocksize(cipher);
905	keylen = cipher_keylen(cipher);
906	ivlen = cipher_ivlen(cipher);
907	authlen = cipher_authlen(cipher);
908	if (state->enc_keyiv_len != keylen + ivlen) {
909		r = SSH_ERR_INVALID_FORMAT;
910		goto out;
911	}
912	key = state->enc_keyiv;
913	if ((encrypted = sshbuf_new()) == NULL ||
914	    (encoded = sshbuf_new()) == NULL ||
915	    (padded = sshbuf_new()) == NULL ||
916	    (iv = malloc(ivlen)) == NULL) {
917		r = SSH_ERR_ALLOC_FAIL;
918		goto out;
919	}
920
921	/* replace first 4 bytes of IV with index to ensure uniqueness */
922	memcpy(iv, key + keylen, ivlen);
923	POKE_U32(iv, state->idx);
924
925	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
926	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
927		goto out;
928
929	/* padded state will be encrypted */
930	if ((r = sshbuf_putb(padded, b)) != 0)
931		goto out;
932	i = 0;
933	while (sshbuf_len(padded) % blocksize) {
934		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
935			goto out;
936	}
937	encrypted_len = sshbuf_len(padded);
938
939	/* header including the length of state is used as AAD */
940	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
941		goto out;
942	aadlen = sshbuf_len(encoded);
943
944	/* concat header and state */
945	if ((r = sshbuf_putb(encoded, padded)) != 0)
946		goto out;
947
948	/* reserve space for encryption of encoded data plus auth tag */
949	/* encrypt at offset addlen */
950	if ((r = sshbuf_reserve(encrypted,
951	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
952	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
953	    iv, ivlen, 1)) != 0 ||
954	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
955	    encrypted_len, aadlen, authlen)) != 0)
956		goto out;
957
958	/* success */
959	r = 0;
960 out:
961	if (retp != NULL) {
962		*retp = encrypted;
963		encrypted = NULL;
964	}
965	sshbuf_free(padded);
966	sshbuf_free(encoded);
967	sshbuf_free(encrypted);
968	cipher_free(ciphercontext);
969	free(iv);
970	return r;
971}
972
/*
 * Decrypt and authenticate a state blob produced by
 * sshkey_xmss_encrypt_state(), using key k's state-encryption cipher.
 *
 * Expected layout of `encoded':
 *	XMSS_MAGIC		(sizeof(), i.e. including the NUL)
 *	u32 index		(replaces the first 4 bytes of the stored IV)
 *	u32 encrypted_len	(non-zero multiple of the cipher block size)
 *	encrypted_len bytes of ciphertext
 *	authlen bytes of authentication tag
 * Magic, index and encrypted_len are the AAD.  `encoded' is consumed and
 * must hold no trailing data.  On success *retp (if non-NULL) receives
 * the decrypted plaintext, which still includes the encryption padding.
 * Returns 0 or an SSH_ERR_* code.
 */
int
sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
   struct sshbuf **retp)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *copy = NULL, *decrypted = NULL;
	struct sshcipher_ctx *ciphercontext = NULL;
	const struct sshcipher *cipher = NULL;
	u_char *key, *iv = NULL, *dp;
	size_t keylen, ivlen, authlen, aadlen;
	u_int blocksize, encrypted_len, index;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (retp != NULL)
		*retp = NULL;
	if (state == NULL ||
	    state->enc_keyiv == NULL ||
	    state->enc_ciphername == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	blocksize = cipher_blocksize(cipher);
	keylen = cipher_keylen(cipher);
	ivlen = cipher_ivlen(cipher);
	authlen = cipher_authlen(cipher);
	if (state->enc_keyiv_len != keylen + ivlen) {
		r = SSH_ERR_INTERNAL_ERROR;
		goto out;
	}
	key = state->enc_keyiv;

	/*
	 * `copy' keeps a view of the whole blob (starting at the magic) that
	 * is used as cipher input after `encoded' has been consumed.
	 */
	if ((copy = sshbuf_fromb(encoded)) == NULL ||
	    (decrypted = sshbuf_new()) == NULL ||
	    (iv = malloc(ivlen)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* check magic */
	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* parse public portion */
	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
		goto out;

	/* check size of encrypted key blob */
	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* check that an appropriate amount of auth data is present */
	if (sshbuf_len(encoded) < authlen ||
	    sshbuf_len(encoded) - authlen < encrypted_len) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* AAD = everything consumed so far (magic, index, length) */
	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);

	/* replace first 4 bytes of IV with index to ensure uniqueness */
	memcpy(iv, key + keylen, ivlen);
	POKE_U32(iv, index);

	/* decrypt private state of key */
	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
	    iv, ivlen, 0)) != 0 ||
	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
	    encrypted_len, aadlen, authlen)) != 0)
		goto out;

	/* there should be no trailing data */
	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
		goto out;
	if (sshbuf_len(encoded) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* remove AAD */
	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
		goto out;
	/* XXX encrypted includes unchecked padding */

	/* success */
	r = 0;
	if (retp != NULL) {
		*retp = decrypted;
		decrypted = NULL;
	}
 out:
	cipher_free(ciphercontext);
	sshbuf_free(copy);
	sshbuf_free(decrypted);
	free(iv);
	return r;
}
1077
1078u_int32_t
1079sshkey_xmss_signatures_left(const struct sshkey *k)
1080{
1081	struct ssh_xmss_state *state = k->xmss_state;
1082	u_int32_t idx;
1083
1084	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1085	    state->maxidx) {
1086		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1087		if (idx < state->maxidx)
1088			return state->maxidx - idx;
1089	}
1090	return 0;
1091}
1092
1093int
1094sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1095{
1096	struct ssh_xmss_state *state = k->xmss_state;
1097
1098	if (sshkey_type_plain(k->type) != KEY_XMSS)
1099		return SSH_ERR_INVALID_ARGUMENT;
1100	if (maxsign == 0)
1101		return 0;
1102	if (state->idx + maxsign < state->idx)
1103		return SSH_ERR_INVALID_ARGUMENT;
1104	state->maxidx = state->idx + maxsign;
1105	return 0;
1106}
1107