138032Speter/*	$NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $	*/
238032Speter/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
338032Speter/*
438032Speter * Copyright (c) 2017 Markus Friedl.  All rights reserved.
538032Speter *
638032Speter * Redistribution and use in source and binary forms, with or without
738032Speter * modification, are permitted provided that the following conditions
838032Speter * are met:
938032Speter * 1. Redistributions of source code must retain the above copyright
1038032Speter *    notice, this list of conditions and the following disclaimer.
1138032Speter * 2. Redistributions in binary form must reproduce the above copyright
1238032Speter *    notice, this list of conditions and the following disclaimer in the
1338032Speter *    documentation and/or other materials provided with the distribution.
1438032Speter *
1538032Speter * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
1638032Speter * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
1738032Speter * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
1838032Speter * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
1938032Speter * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
2038032Speter * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
2138032Speter * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
2238032Speter * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
2338032Speter * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
2438032Speter * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2538032Speter */
2638032Speter#include "includes.h"
2738032Speter__RCSID("$NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $");
2838032Speter
2938032Speter#include <sys/types.h>
3038032Speter#include <sys/uio.h>
3138032Speter
3238032Speter#include <stdio.h>
3338032Speter#include <string.h>
3438032Speter#include <unistd.h>
3538032Speter#include <fcntl.h>
3638032Speter#include <errno.h>
3738032Speter
3838032Speter#include "ssh2.h"
3938032Speter#include "ssherr.h"
4038032Speter#include "sshbuf.h"
4138032Speter#include "cipher.h"
4238032Speter#include "sshkey.h"
4338032Speter#include "sshkey-xmss.h"
4438032Speter#include "atomicio.h"
4538032Speter#include "log.h"
4638032Speter
4738032Speter#include "xmss_fast.h"
4838032Speter
/* opaque internal XMSS state */
/*
 * Magic string for the state format.  NOTE(review): presumably used by the
 * state encryption/decryption code; no use is visible in this part of the
 * file.
 */
#define XMSS_MAGIC		"xmss-state-v1"
/* Cipher used to encrypt the on-disk state of newly generated keys. */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	/* filled in by xmss_set_params() from n/h/w/k in sshkey_xmss_init() */
	xmss_params	params;
	/* n, w, h chosen from the key name in sshkey_xmss_init(); k fixed at 2 */
	u_int32_t	n, w, h, k;

	/*
	 * BDS traversal state.  bds points into the buffers below, which are
	 * allocated by sshkey_xmss_init_bds_state() or
	 * sshkey_xmss_deserialize_state() and released by
	 * sshkey_xmss_free_bds().  Their sizes come from the num_*() macros.
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;	/* copy of bds.stackoffset, synced on save */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* backing store for treehash[i].node */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv (keylen + ivlen) */
};
7538032Speter
/* Internal helpers shared between the functions of this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *key);
int	 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername);
void	 sshkey_xmss_free_bds(struct sshkey *key);
int	 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
	    int *have_file, int printerror);
int	 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
	    struct sshbuf **retp);
int	 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *b,
	    struct sshbuf **retp);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b);

/*
 * Emit an error-level log line through sshlog(), but only when the
 * enclosing function's "printerror" variable is non-zero.  Any function
 * using PRINT() must therefore have a variable of that name in scope.
 */
#define PRINT(...) do {							\
	if (printerror)							\
		sshlog(__FILE__, __func__, __LINE__, 0,			\
		    SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__);		\
} while (0)
9038032Speter
9138032Speterint
9238032Spetersshkey_xmss_init(struct sshkey *key, const char *name)
9338032Speter{
9438032Speter	struct ssh_xmss_state *state;
9538032Speter
9638032Speter	if (key->xmss_state != NULL)
9738032Speter		return SSH_ERR_INVALID_FORMAT;
9838032Speter	if (name == NULL)
9938032Speter		return SSH_ERR_INVALID_FORMAT;
10038032Speter	state = calloc(sizeof(struct ssh_xmss_state), 1);
10138032Speter	if (state == NULL)
10238032Speter		return SSH_ERR_ALLOC_FAIL;
10338032Speter	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
10438032Speter		state->n = 32;
10538032Speter		state->w = 16;
10638032Speter		state->h = 10;
10738032Speter	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
10838032Speter		state->n = 32;
10938032Speter		state->w = 16;
11038032Speter		state->h = 16;
11138032Speter	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
11238032Speter		state->n = 32;
11338032Speter		state->w = 16;
11438032Speter		state->h = 20;
11538032Speter	} else {
11638032Speter		free(state);
11738032Speter		return SSH_ERR_KEY_TYPE_UNKNOWN;
11838032Speter	}
11938032Speter	if ((key->xmss_name = strdup(name)) == NULL) {
12038032Speter		free(state);
12138032Speter		return SSH_ERR_ALLOC_FAIL;
12238032Speter	}
12338032Speter	state->k = 2;	/* XXX hardcoded */
12438032Speter	state->lockfd = -1;
12538032Speter	if (xmss_set_params(&state->params, state->n, state->h, state->w,
12638032Speter	    state->k) != 0) {
12738032Speter		free(state);
12838032Speter		return SSH_ERR_INVALID_FORMAT;
12938032Speter	}
13038032Speter	key->xmss_state = state;
13138032Speter	return 0;
13238032Speter}
13338032Speter
13438032Spetervoid
13538032Spetersshkey_xmss_free_state(struct sshkey *key)
13638032Speter{
13738032Speter	struct ssh_xmss_state *state = key->xmss_state;
13838032Speter
13938032Speter	sshkey_xmss_free_bds(key);
14038032Speter	if (state) {
14138032Speter		if (state->enc_keyiv) {
14238032Speter			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
14338032Speter			free(state->enc_keyiv);
14438032Speter		}
14538032Speter		free(state->enc_ciphername);
14638032Speter		free(state);
14738032Speter	}
14838032Speter	key->xmss_state = NULL;
14938032Speter}
15038032Speter
/*
 * Magic string ("k=2", matching the hardcoded state->k) that starts every
 * blob written by sshkey_xmss_serialize_state().
 */
#define SSH_XMSS_K2_MAGIC	"k=2"

/*
 * Sizes of the BDS traversal buffers for a state object x with fields
 * h (tree height), n (node size in bytes) and k.  Results are byte counts
 * except num_stacklevels (entries of one byte each) and num_treehash
 * (number of treehash_inst records).  Every use of the argument is
 * parenthesized, so any pointer expression (&st, p + 1, ...) is safe.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
15938032Speter
16038032Speterint
16138032Spetersshkey_xmss_init_bds_state(struct sshkey *key)
16238032Speter{
16338032Speter	struct ssh_xmss_state *state = key->xmss_state;
16438032Speter	u_int32_t i;
16538032Speter
16638032Speter	state->stackoffset = 0;
16738032Speter	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
16838032Speter	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
16938032Speter	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
17038032Speter	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
17138032Speter	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
17238032Speter	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
17338032Speter	    (state->treehash = calloc(num_treehash(state),
17438032Speter	    sizeof(treehash_inst))) == NULL) {
17538032Speter		sshkey_xmss_free_bds(key);
17638032Speter		return SSH_ERR_ALLOC_FAIL;
17738032Speter	}
17838032Speter	for (i = 0; i < state->h - state->k; i++)
17938032Speter		state->treehash[i].node = &state->th_nodes[state->n*i];
18038032Speter	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
18138032Speter	    state->stacklevels, state->auth, state->keep, state->treehash,
18238032Speter	    state->retain, 0);
18338032Speter	return 0;
18438032Speter}
18538032Speter
18638032Spetervoid
18738032Spetersshkey_xmss_free_bds(struct sshkey *key)
18838032Speter{
18938032Speter	struct ssh_xmss_state *state = key->xmss_state;
19038032Speter
19138032Speter	if (state == NULL)
19238032Speter		return;
19338032Speter	free(state->stack);
19438032Speter	free(state->stacklevels);
19538032Speter	free(state->auth);
19638032Speter	free(state->keep);
19738032Speter	free(state->th_nodes);
19838032Speter	free(state->retain);
19938032Speter	free(state->treehash);
20038032Speter	state->stack = NULL;
20138032Speter	state->stacklevels = NULL;
20238032Speter	state->auth = NULL;
20338032Speter	state->keep = NULL;
20438032Speter	state->th_nodes = NULL;
20538032Speter	state->retain = NULL;
20638032Speter	state->treehash = NULL;
20738032Speter}
20838032Speter
20938032Spetervoid *
21038032Spetersshkey_xmss_params(const struct sshkey *key)
21138032Speter{
21238032Speter	struct ssh_xmss_state *state = key->xmss_state;
21338032Speter
21438032Speter	if (state == NULL)
21538032Speter		return NULL;
21638032Speter	return &state->params;
21738032Speter}
21838032Speter
21938032Spetervoid *
22038032Spetersshkey_xmss_bds_state(const struct sshkey *key)
22138032Speter{
22238032Speter	struct ssh_xmss_state *state = key->xmss_state;
22338032Speter
22438032Speter	if (state == NULL)
22538032Speter		return NULL;
22638032Speter	return &state->bds;
22738032Speter}
22838032Speter
22938032Speterint
23038032Spetersshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
23138032Speter{
23238032Speter	struct ssh_xmss_state *state = key->xmss_state;
23338032Speter
23438032Speter	if (lenp == NULL)
23538032Speter		return SSH_ERR_INVALID_ARGUMENT;
23638032Speter	if (state == NULL)
23738032Speter		return SSH_ERR_INVALID_FORMAT;
23838032Speter	*lenp = 4 + state->n +
23938032Speter	    state->params.wots_par.keysize +
24038032Speter	    state->h * state->n;
24138032Speter	return 0;
24238032Speter}
24338032Speter
24438032Spetersize_t
24538032Spetersshkey_xmss_pklen(const struct sshkey *key)
24638032Speter{
24738032Speter	struct ssh_xmss_state *state = key->xmss_state;
24838032Speter
24938032Speter	if (state == NULL)
25038032Speter		return 0;
25138032Speter	return state->n * 2;
25238032Speter}
25338032Speter
25438032Spetersize_t
25538032Spetersshkey_xmss_sklen(const struct sshkey *key)
25638032Speter{
25738032Speter	struct ssh_xmss_state *state = key->xmss_state;
25838032Speter
25938032Speter	if (state == NULL)
26038032Speter		return 0;
26138032Speter	return state->n * 4 + 4;
26238032Speter}
26338032Speter
26438032Speterint
26538032Spetersshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
26638032Speter{
26738032Speter	struct ssh_xmss_state *state = k->xmss_state;
26838032Speter	const struct sshcipher *cipher;
26938032Speter	size_t keylen = 0, ivlen = 0;
27038032Speter
27138032Speter	if (state == NULL)
27238032Speter		return SSH_ERR_INVALID_ARGUMENT;
27338032Speter	if ((cipher = cipher_by_name(ciphername)) == NULL)
27438032Speter		return SSH_ERR_INTERNAL_ERROR;
27538032Speter	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
27638032Speter		return SSH_ERR_ALLOC_FAIL;
27738032Speter	keylen = cipher_keylen(cipher);
27838032Speter	ivlen = cipher_ivlen(cipher);
27938032Speter	state->enc_keyiv_len = keylen + ivlen;
28038032Speter	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
28138032Speter		free(state->enc_ciphername);
28238032Speter		state->enc_ciphername = NULL;
28338032Speter		return SSH_ERR_ALLOC_FAIL;
28438032Speter	}
28538032Speter	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
28638032Speter	return 0;
28738032Speter}
28838032Speter
28938032Speterint
29038032Spetersshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
29138032Speter{
29238032Speter	struct ssh_xmss_state *state = k->xmss_state;
29338032Speter	int r;
29438032Speter
29538032Speter	if (state == NULL || state->enc_keyiv == NULL ||
29638032Speter	    state->enc_ciphername == NULL)
29738032Speter		return SSH_ERR_INVALID_ARGUMENT;
29838032Speter	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
29938032Speter	    (r = sshbuf_put_string(b, state->enc_keyiv,
30038032Speter	    state->enc_keyiv_len)) != 0)
30138032Speter		return r;
30238032Speter	return 0;
30338032Speter}
30438032Speter
30538032Speterint
30638032Spetersshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
30738032Speter{
30838032Speter	struct ssh_xmss_state *state = k->xmss_state;
30938032Speter	size_t len;
31038032Speter	int r;
31138032Speter
31238032Speter	if (state == NULL)
31338032Speter		return SSH_ERR_INVALID_ARGUMENT;
31438032Speter	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
31538032Speter	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
31638032Speter		return r;
31738032Speter	state->enc_keyiv_len = len;
31838032Speter	return 0;
31938032Speter}
32038032Speter
32138032Speterint
32238032Spetersshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
32338032Speter    enum sshkey_serialize_rep opts)
32438032Speter{
32538032Speter	struct ssh_xmss_state *state = k->xmss_state;
32638032Speter	u_char have_info = 1;
32738032Speter	u_int32_t idx;
32838032Speter	int r;
32938032Speter
33038032Speter	if (state == NULL)
33138032Speter		return SSH_ERR_INVALID_ARGUMENT;
33238032Speter	if (opts != SSHKEY_SERIALIZE_INFO)
33338032Speter		return 0;
33438032Speter	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
33538032Speter	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
33638032Speter	    (r = sshbuf_put_u32(b, idx)) != 0 ||
33738032Speter	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
33838032Speter		return r;
33938032Speter	return 0;
34038032Speter}
34138032Speter
34238032Speterint
34338032Spetersshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
34438032Speter{
34538032Speter	struct ssh_xmss_state *state = k->xmss_state;
34638032Speter	u_char have_info;
34738032Speter	int r;
34838032Speter
34938032Speter	if (state == NULL)
35038032Speter		return SSH_ERR_INVALID_ARGUMENT;
35138032Speter	/* optional */
35238032Speter	if (sshbuf_len(b) == 0)
35338032Speter		return 0;
35438032Speter	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
35538032Speter		return r;
35638032Speter	if (have_info != 1)
35738032Speter		return SSH_ERR_INVALID_ARGUMENT;
35838032Speter	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
35938032Speter	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
36038032Speter		return r;
36138032Speter	return 0;
36238032Speter}
36338032Speter
36438032Speterint
36538032Spetersshkey_xmss_generate_private_key(struct sshkey *k, int bits)
36638032Speter{
36738032Speter	int r;
36838032Speter	const char *name;
36938032Speter
37038032Speter	if (bits == 10) {
37138032Speter		name = XMSS_SHA2_256_W16_H10_NAME;
37238032Speter	} else if (bits == 16) {
37338032Speter		name = XMSS_SHA2_256_W16_H16_NAME;
37438032Speter	} else if (bits == 20) {
37538032Speter		name = XMSS_SHA2_256_W16_H20_NAME;
37638032Speter	} else {
37738032Speter		name = XMSS_DEFAULT_NAME;
37838032Speter	}
37938032Speter	if ((r = sshkey_xmss_init(k, name)) != 0 ||
38038032Speter	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
38138032Speter	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
38238032Speter		return r;
38338032Speter	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
38438032Speter	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
38538032Speter		return SSH_ERR_ALLOC_FAIL;
38638032Speter	}
38738032Speter	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
38838032Speter	    sshkey_xmss_params(k));
38938032Speter	return 0;
39038032Speter}
39138032Speter
39238032Speterint
39338032Spetersshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
39438032Speter    int *have_file, int printerror)
39538032Speter{
39638032Speter	struct sshbuf *b = NULL, *enc = NULL;
39738032Speter	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
39838032Speter	u_int32_t len;
39938032Speter	unsigned char buf[4], *data = NULL;
40038032Speter
40138032Speter	*have_file = 0;
40238032Speter	if ((fd = open(filename, O_RDONLY)) >= 0) {
40338032Speter		*have_file = 1;
40438032Speter		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
40538032Speter			PRINT("corrupt state file: %s", filename);
40638032Speter			goto done;
40738032Speter		}
40838032Speter		len = PEEK_U32(buf);
40938032Speter		if ((data = calloc(len, 1)) == NULL) {
41038032Speter			ret = SSH_ERR_ALLOC_FAIL;
41138032Speter			goto done;
41238032Speter		}
41338032Speter		if (atomicio(read, fd, data, len) != len) {
41438032Speter			PRINT("cannot read blob: %s", filename);
41538032Speter			goto done;
41638032Speter		}
41738032Speter		if ((enc = sshbuf_from(data, len)) == NULL) {
41838032Speter			ret = SSH_ERR_ALLOC_FAIL;
41938032Speter			goto done;
42038032Speter		}
42138032Speter		sshkey_xmss_free_bds(k);
42238032Speter		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
42338032Speter			ret = r;
42438032Speter			goto done;
42538032Speter		}
42638032Speter		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
42738032Speter			ret = r;
42838032Speter			goto done;
42938032Speter		}
43038032Speter		ret = 0;
43138032Speter	}
43238032Speterdone:
43338032Speter	if (fd != -1)
43438032Speter		close(fd);
43538032Speter	free(data);
43638032Speter	sshbuf_free(enc);
43738032Speter	sshbuf_free(b);
43838032Speter	return ret;
43938032Speter}
44038032Speter
/*
 * Prepare key k for making a signature by loading and locking its XMSS
 * state.
 *
 * If state->maxidx is non-zero, the key may sign only while the index
 * stored in k->xmss_sk is below maxidx.  No files are touched in that
 * case.
 *
 * Otherwise, with F = k->xmss_filename, the steps are:
 *   1. Take an exclusive flock() on "F.lock", retrying up to 10 times.
 *   2. Load "F.state".  If that fails, load "F.ostate" instead.
 *   3. If neither file exists, use the in-memory BDS state as-is,
 *      provided it has been initialised.
 * On success the lock descriptor is kept in state->lockfd, where
 * sshkey_xmss_update_state() closes it, and state->allow_update is set.
 *
 * printerror enables PRINT() diagnostics.  Returns 0 or an SSH_ERR_* code
 * (SSH_ERR_INVALID_ARGUMENT for a missing state/filename, an exhausted
 * maxidx, an uninitialised BDS state or an index about to wrap).
 */
int
sshkey_xmss_get_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx = 0;
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
	int lockfd = -1, have_state = 0, have_ostate = 0, tries = 0;
	int ret = SSH_ERR_INVALID_ARGUMENT, r;

	if (state == NULL)
		goto done;
	/*
	 * If maxidx is set, then we are allowed a limited number
	 * of signatures, but don't need to access the disk.
	 * Otherwise we need to deal with the on-disk state.
	 */
	if (state->maxidx) {
		/* xmss_sk always contains the current state */
		idx = PEEK_U32(k->xmss_sk);
		if (idx < state->maxidx) {
			state->allow_update = 1;
			return 0;
		}
		return SSH_ERR_INVALID_ARGUMENT;
	}
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
	    asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("cannot open/create: %s", lockfile);
		goto done;
	}
	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
		if (errno != EWOULDBLOCK) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("cannot lock: %s", lockfile);
			goto done;
		}
		if (++tries > 10) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("giving up on: %s", lockfile);
			goto done;
		}
		/* back off 100ms, 200ms, ..., 1s: about 5.5s in total */
		usleep(1000*100*tries);
	}
	/* XXX no longer const */
	if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
	    statefile, &have_state, printerror)) != 0) {
		/*
		 * Fall back to the .ostate backup.  sshkey_xmss_update_state()
		 * links the previous .state there before replacing .state, so
		 * it is one update behind.  Skip that index by forwarding one
		 * signature.  forward_state() requires allow_update, hence the
		 * temporary toggle.
		 */
		if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k),
		    ostatefile, &have_ostate, printerror)) == 0) {
			state->allow_update = 1;
			r = sshkey_xmss_forward_state(k, 1);
			state->idx = PEEK_U32(k->xmss_sk);
			state->allow_update = 0;
		}
	}
	if (!have_state && !have_ostate) {
		/* check that bds state is initialized */
		if (state->bds.auth == NULL)
			goto done;
		PRINT("start from scratch idx 0: %u", state->idx);
	} else if (r != 0) {
		/* a state file existed but could not be used */
		ret = r;
		goto done;
	}
	/* refuse to continue if the next signature would wrap the index */
	if (state->idx + 1 < state->idx) {
		PRINT("state wrap: %u", state->idx);
		goto done;
	}
	/* success: the lock now belongs to the state, not this function */
	state->have_state = have_state;
	state->lockfd = lockfd;
	state->allow_update = 1;
	lockfd = -1;
	ret = 0;
done:
	if (lockfd != -1)
		close(lockfd);
	free(lockfile);
	free(statefile);
	free(ostatefile);
	return ret;
}
53038032Speter
53138032Speterint
53238032Spetersshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
53338032Speter{
53438032Speter	struct ssh_xmss_state *state = k->xmss_state;
53538032Speter	u_char *sig = NULL;
53638032Speter	size_t required_siglen;
53738032Speter	unsigned long long smlen;
53838032Speter	u_char data;
53938032Speter	int ret, r;
54038032Speter
54138032Speter	if (state == NULL || !state->allow_update)
54238032Speter		return SSH_ERR_INVALID_ARGUMENT;
54338032Speter	if (reserve == 0)
54438032Speter		return SSH_ERR_INVALID_ARGUMENT;
54538032Speter	if (state->idx + reserve <= state->idx)
54638032Speter		return SSH_ERR_INVALID_ARGUMENT;
54738032Speter	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
54838032Speter		return r;
54938032Speter	if ((sig = malloc(required_siglen)) == NULL)
55038032Speter		return SSH_ERR_ALLOC_FAIL;
55138032Speter	while (reserve-- > 0) {
55238032Speter		state->idx = PEEK_U32(k->xmss_sk);
55338032Speter		smlen = required_siglen;
55438032Speter		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
55538032Speter		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
55638032Speter			r = SSH_ERR_INVALID_ARGUMENT;
55738032Speter			break;
558		}
559	}
560	free(sig);
561	return r;
562}
563
/*
 * Persist k's XMSS state after a signature, and release the lock taken
 * by sshkey_xmss_get_state().
 *
 * Nothing is written in three cases:
 *   - state->maxidx is set (signatures are bounded instead);
 *   - the index in k->xmss_sk has not moved;
 *   - the index moved by anything other than exactly one (refused).
 * Otherwise, with F = k->xmss_filename:
 *   1. Serialise and encrypt the state, write it to "F.nstate" as a
 *      4-byte length plus data, and fsync() it.
 *   2. If a .state file was loaded, hard-link "F.state" to "F.ostate"
 *      as a backup.
 *   3. rename() "F.nstate" over "F.state".
 * Every path after the allow_update check closes state->lockfd and
 * removes any leftover .nstate file.
 *
 * Returns 0 or an SSH_ERR_* code.
 * NOTE(review): state->idx is advanced before the file is written, so on
 * a write failure the in-memory index is already ahead of the disk.
 * NOTE(review): the directory is not fsync()ed after rename(), so the
 * rename may not be durable across a crash.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		/* skipped ahead or moved backwards: refuse to persist */
		PRINT("more than one signature happened: idx %u state %u",
		    idx, state->idx);
		goto done;
	}
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove a stale temp file so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("SERLIALIZE FAILED: %d", ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("ENCRYPT FAILED: %d", ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("open new state file: %s", nstatefile);
		goto done;
	}
	/* header: 4-byte length of the encrypted blob that follows */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file hdr: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file data: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("sync new state file: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("close new state file: %s", nstatefile);
		goto done;
	}
	/* keep the previous state as .ostate; get_state() falls back to it */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("backup state %s to %s", statefile, ostatefile);
			goto done;
		}
	}
	/* replace .state with the freshly written file in one step */
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("rename %s to %s", nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock acquired by sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
673
674int
675sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
676{
677	struct ssh_xmss_state *state = k->xmss_state;
678	treehash_inst *th;
679	u_int32_t i, node;
680	int r;
681
682	if (state == NULL)
683		return SSH_ERR_INVALID_ARGUMENT;
684	if (state->stack == NULL)
685		return SSH_ERR_INVALID_ARGUMENT;
686	state->stackoffset = state->bds.stackoffset;	/* copy back */
687	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
688	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
689	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
690	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
691	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
692	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
693	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
694	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
695	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
696	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
697		return r;
698	for (i = 0; i < num_treehash(state); i++) {
699		th = &state->treehash[i];
700		node = th->node - state->th_nodes;
701		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
702		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
703		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
704		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
705		    (r = sshbuf_put_u32(b, node)) != 0)
706			return r;
707	}
708	return 0;
709}
710
711int
712sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
713    enum sshkey_serialize_rep opts)
714{
715	struct ssh_xmss_state *state = k->xmss_state;
716	int r = SSH_ERR_INVALID_ARGUMENT;
717	u_char have_stack, have_filename, have_enc;
718
719	if (state == NULL)
720		return SSH_ERR_INVALID_ARGUMENT;
721	if ((r = sshbuf_put_u8(b, opts)) != 0)
722		return r;
723	switch (opts) {
724	case SSHKEY_SERIALIZE_STATE:
725		r = sshkey_xmss_serialize_state(k, b);
726		break;
727	case SSHKEY_SERIALIZE_FULL:
728		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
729			return r;
730		r = sshkey_xmss_serialize_state(k, b);
731		break;
732	case SSHKEY_SERIALIZE_SHIELD:
733		/* all of stack/filename/enc are optional */
734		have_stack = state->stack != NULL;
735		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
736			return r;
737		if (have_stack) {
738			state->idx = PEEK_U32(k->xmss_sk);	/* update */
739			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
740				return r;
741		}
742		have_filename = k->xmss_filename != NULL;
743		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
744			return r;
745		if (have_filename &&
746		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
747			return r;
748		have_enc = state->enc_keyiv != NULL;
749		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
750			return r;
751		if (have_enc &&
752		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
753			return r;
754		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
755		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
756			return r;
757		break;
758	case SSHKEY_SERIALIZE_DEFAULT:
759		r = 0;
760		break;
761	default:
762		r = SSH_ERR_INVALID_ARGUMENT;
763		break;
764	}
765	return r;
766}
767
768int
769sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
770{
771	struct ssh_xmss_state *state = k->xmss_state;
772	treehash_inst *th;
773	u_int32_t i, lh, node;
774	size_t ls, lsl, la, lk, ln, lr;
775	char *magic;
776	int r = SSH_ERR_INTERNAL_ERROR;
777
778	if (state == NULL)
779		return SSH_ERR_INVALID_ARGUMENT;
780	if (k->xmss_sk == NULL)
781		return SSH_ERR_INVALID_ARGUMENT;
782	if ((state->treehash = calloc(num_treehash(state),
783	    sizeof(treehash_inst))) == NULL)
784		return SSH_ERR_ALLOC_FAIL;
785	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
786	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
787	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
788	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
789	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
790	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
791	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
792	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
793	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
794	    (r = sshbuf_get_u32(b, &lh)) != 0)
795		goto out;
796	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
797		r = SSH_ERR_INVALID_ARGUMENT;
798		goto out;
799	}
800	/* XXX check stackoffset */
801	if (ls != num_stack(state) ||
802	    lsl != num_stacklevels(state) ||
803	    la != num_auth(state) ||
804	    lk != num_keep(state) ||
805	    ln != num_th_nodes(state) ||
806	    lr != num_retain(state) ||
807	    lh != num_treehash(state)) {
808		r = SSH_ERR_INVALID_ARGUMENT;
809		goto out;
810	}
811	for (i = 0; i < num_treehash(state); i++) {
812		th = &state->treehash[i];
813		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
814		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
815		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
816		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
817		    (r = sshbuf_get_u32(b, &node)) != 0)
818			goto out;
819		if (node < num_th_nodes(state))
820			th->node = &state->th_nodes[node];
821	}
822	POKE_U32(k->xmss_sk, state->idx);
823	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
824	    state->stacklevels, state->auth, state->keep, state->treehash,
825	    state->retain, 0);
826	/* success */
827	r = 0;
828 out:
829	free(magic);
830	return r;
831}
832
833int
834sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
835{
836	struct ssh_xmss_state *state = k->xmss_state;
837	enum sshkey_serialize_rep opts;
838	u_char have_state, have_stack, have_filename, have_enc;
839	int r;
840
841	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
842		return r;
843
844	opts = have_state;
845	switch (opts) {
846	case SSHKEY_SERIALIZE_DEFAULT:
847		r = 0;
848		break;
849	case SSHKEY_SERIALIZE_SHIELD:
850		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
851			return r;
852		if (have_stack &&
853		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
854			return r;
855		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
856			return r;
857		if (have_filename &&
858		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
859			return r;
860		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
861			return r;
862		if (have_enc &&
863		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
864			return r;
865		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
866		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
867			return r;
868		break;
869	case SSHKEY_SERIALIZE_STATE:
870		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
871			return r;
872		break;
873	case SSHKEY_SERIALIZE_FULL:
874		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
875		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
876			return r;
877		break;
878	default:
879		r = SSH_ERR_INVALID_FORMAT;
880		break;
881	}
882	return r;
883}
884
885int
886sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
887   struct sshbuf **retp)
888{
889	struct ssh_xmss_state *state = k->xmss_state;
890	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
891	struct sshcipher_ctx *ciphercontext = NULL;
892	const struct sshcipher *cipher;
893	u_char *cp, *key, *iv = NULL;
894	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
895	int r = SSH_ERR_INTERNAL_ERROR;
896
897	if (retp != NULL)
898		*retp = NULL;
899	if (state == NULL ||
900	    state->enc_keyiv == NULL ||
901	    state->enc_ciphername == NULL)
902		return SSH_ERR_INTERNAL_ERROR;
903	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
904		r = SSH_ERR_INTERNAL_ERROR;
905		goto out;
906	}
907	blocksize = cipher_blocksize(cipher);
908	keylen = cipher_keylen(cipher);
909	ivlen = cipher_ivlen(cipher);
910	authlen = cipher_authlen(cipher);
911	if (state->enc_keyiv_len != keylen + ivlen) {
912		r = SSH_ERR_INVALID_FORMAT;
913		goto out;
914	}
915	key = state->enc_keyiv;
916	if ((encrypted = sshbuf_new()) == NULL ||
917	    (encoded = sshbuf_new()) == NULL ||
918	    (padded = sshbuf_new()) == NULL ||
919	    (iv = malloc(ivlen)) == NULL) {
920		r = SSH_ERR_ALLOC_FAIL;
921		goto out;
922	}
923
924	/* replace first 4 bytes of IV with index to ensure uniqueness */
925	memcpy(iv, key + keylen, ivlen);
926	POKE_U32(iv, state->idx);
927
928	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
929	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
930		goto out;
931
932	/* padded state will be encrypted */
933	if ((r = sshbuf_putb(padded, b)) != 0)
934		goto out;
935	i = 0;
936	while (sshbuf_len(padded) % blocksize) {
937		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
938			goto out;
939	}
940	encrypted_len = sshbuf_len(padded);
941
942	/* header including the length of state is used as AAD */
943	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
944		goto out;
945	aadlen = sshbuf_len(encoded);
946
947	/* concat header and state */
948	if ((r = sshbuf_putb(encoded, padded)) != 0)
949		goto out;
950
951	/* reserve space for encryption of encoded data plus auth tag */
952	/* encrypt at offset addlen */
953	if ((r = sshbuf_reserve(encrypted,
954	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
955	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
956	    iv, ivlen, 1)) != 0 ||
957	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
958	    encrypted_len, aadlen, authlen)) != 0)
959		goto out;
960
961	/* success */
962	r = 0;
963 out:
964	if (retp != NULL) {
965		*retp = encrypted;
966		encrypted = NULL;
967	}
968	sshbuf_free(padded);
969	sshbuf_free(encoded);
970	sshbuf_free(encrypted);
971	cipher_free(ciphercontext);
972	free(iv);
973	return r;
974}
975
976int
977sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
978   struct sshbuf **retp)
979{
980	struct ssh_xmss_state *state = k->xmss_state;
981	struct sshbuf *copy = NULL, *decrypted = NULL;
982	struct sshcipher_ctx *ciphercontext = NULL;
983	const struct sshcipher *cipher = NULL;
984	u_char *key, *iv = NULL, *dp;
985	size_t keylen, ivlen, authlen, aadlen;
986	u_int blocksize, encrypted_len, index;
987	int r = SSH_ERR_INTERNAL_ERROR;
988
989	if (retp != NULL)
990		*retp = NULL;
991	if (state == NULL ||
992	    state->enc_keyiv == NULL ||
993	    state->enc_ciphername == NULL)
994		return SSH_ERR_INTERNAL_ERROR;
995	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
996		r = SSH_ERR_INVALID_FORMAT;
997		goto out;
998	}
999	blocksize = cipher_blocksize(cipher);
1000	keylen = cipher_keylen(cipher);
1001	ivlen = cipher_ivlen(cipher);
1002	authlen = cipher_authlen(cipher);
1003	if (state->enc_keyiv_len != keylen + ivlen) {
1004		r = SSH_ERR_INTERNAL_ERROR;
1005		goto out;
1006	}
1007	key = state->enc_keyiv;
1008
1009	if ((copy = sshbuf_fromb(encoded)) == NULL ||
1010	    (decrypted = sshbuf_new()) == NULL ||
1011	    (iv = malloc(ivlen)) == NULL) {
1012		r = SSH_ERR_ALLOC_FAIL;
1013		goto out;
1014	}
1015
1016	/* check magic */
1017	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1018	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1019		r = SSH_ERR_INVALID_FORMAT;
1020		goto out;
1021	}
1022	/* parse public portion */
1023	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1024	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1025	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1026		goto out;
1027
1028	/* check size of encrypted key blob */
1029	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1030		r = SSH_ERR_INVALID_FORMAT;
1031		goto out;
1032	}
1033	/* check that an appropriate amount of auth data is present */
1034	if (sshbuf_len(encoded) < authlen ||
1035	    sshbuf_len(encoded) - authlen < encrypted_len) {
1036		r = SSH_ERR_INVALID_FORMAT;
1037		goto out;
1038	}
1039
1040	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1041
1042	/* replace first 4 bytes of IV with index to ensure uniqueness */
1043	memcpy(iv, key + keylen, ivlen);
1044	POKE_U32(iv, index);
1045
1046	/* decrypt private state of key */
1047	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1048	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
1049	    iv, ivlen, 0)) != 0 ||
1050	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1051	    encrypted_len, aadlen, authlen)) != 0)
1052		goto out;
1053
1054	/* there should be no trailing data */
1055	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1056		goto out;
1057	if (sshbuf_len(encoded) != 0) {
1058		r = SSH_ERR_INVALID_FORMAT;
1059		goto out;
1060	}
1061
1062	/* remove AAD */
1063	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1064		goto out;
1065	/* XXX encrypted includes unchecked padding */
1066
1067	/* success */
1068	r = 0;
1069	if (retp != NULL) {
1070		*retp = decrypted;
1071		decrypted = NULL;
1072	}
1073 out:
1074	cipher_free(ciphercontext);
1075	sshbuf_free(copy);
1076	sshbuf_free(decrypted);
1077	free(iv);
1078	return r;
1079}
1080
1081u_int32_t
1082sshkey_xmss_signatures_left(const struct sshkey *k)
1083{
1084	struct ssh_xmss_state *state = k->xmss_state;
1085	u_int32_t idx;
1086
1087	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1088	    state->maxidx) {
1089		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1090		if (idx < state->maxidx)
1091			return state->maxidx - idx;
1092	}
1093	return 0;
1094}
1095
1096int
1097sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1098{
1099	struct ssh_xmss_state *state = k->xmss_state;
1100
1101	if (sshkey_type_plain(k->type) != KEY_XMSS)
1102		return SSH_ERR_INVALID_ARGUMENT;
1103	if (maxsign == 0)
1104		return 0;
1105	if (state->idx + maxsign < state->idx)
1106		return SSH_ERR_INVALID_ARGUMENT;
1107	state->maxidx = state->idx + maxsign;
1108	return 0;
1109}
1110