1/* $NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $ */ 2/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */ 3/* 4 * Copyright (c) 2017 Markus Friedl. All rights reserved. 5 * 6 * Redistribution and use in source and binary forms, with or without 7 * modification, are permitted provided that the following conditions 8 * are met: 9 * 1. Redistributions of source code must retain the above copyright 10 * notice, this list of conditions and the following disclaimer. 11 * 2. Redistributions in binary form must reproduce the above copyright 12 * notice, this list of conditions and the following disclaimer in the 13 * documentation and/or other materials provided with the distribution. 14 * 15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 17 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 19 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 20 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 24 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 */ 26#include "includes.h" 27__RCSID("$NetBSD: sshkey-xmss.c,v 1.10 2023/08/03 07:59:32 mrg Exp $"); 28 29#include <sys/types.h> 30#include <sys/uio.h> 31 32#include <stdio.h> 33#include <string.h> 34#include <unistd.h> 35#include <fcntl.h> 36#include <errno.h> 37 38#include "ssh2.h" 39#include "ssherr.h" 40#include "sshbuf.h" 41#include "cipher.h" 42#include "sshkey.h" 43#include "sshkey-xmss.h" 44#include "atomicio.h" 45#include "log.h" 46 47#include "xmss_fast.h" 48 49/* opaque internal XMSS state */ 50#define XMSS_MAGIC "xmss-state-v1" 51#define XMSS_CIPHERNAME "aes256-gcm@openssh.com" 52struct ssh_xmss_state { 53 xmss_params params; 54 u_int32_t n, w, h, k; 55 56 bds_state bds; 57 u_char *stack; 58 u_int32_t stackoffset; 59 u_char *stacklevels; 60 u_char *auth; 61 u_char *keep; 62 u_char *th_nodes; 63 u_char *retain; 64 treehash_inst *treehash; 65 66 u_int32_t idx; /* state read from file */ 67 u_int32_t maxidx; /* restricted # of signatures */ 68 int have_state; /* .state file exists */ 69 int lockfd; /* locked in sshkey_xmss_get_state() */ 70 u_char allow_update; /* allow sshkey_xmss_update_state() */ 71 char *enc_ciphername;/* encrypt state with cipher */ 72 u_char *enc_keyiv; /* encrypt state with key */ 73 u_int32_t enc_keyiv_len; /* length of enc_keyiv */ 74}; 75 76int sshkey_xmss_init_bds_state(struct sshkey *); 77int sshkey_xmss_init_enc_key(struct sshkey *, const char *); 78void sshkey_xmss_free_bds(struct sshkey *); 79int sshkey_xmss_get_state_from_file(struct sshkey *, const char *, 80 int *, int); 81int sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *, 82 struct sshbuf **); 83int sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *, 84 struct sshbuf **); 85int sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *); 86int sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *); 87 88#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \ 89 0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0) 90 91int 92sshkey_xmss_init(struct sshkey *key, const char *name) 93{ 94 struct ssh_xmss_state *state; 95 96 if (key->xmss_state != NULL) 97 return SSH_ERR_INVALID_FORMAT; 98 if (name == NULL) 99 return SSH_ERR_INVALID_FORMAT; 100 state = calloc(sizeof(struct ssh_xmss_state), 1); 101 if (state == NULL) 102 return SSH_ERR_ALLOC_FAIL; 103 if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) { 104 state->n = 32; 105 state->w = 16; 106 state->h = 10; 107 } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) { 108 state->n = 32; 109 state->w = 16; 110 state->h = 16; 111 } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) { 112 state->n = 32; 113 state->w = 16; 114 state->h = 20; 115 } else { 116 free(state); 117 return SSH_ERR_KEY_TYPE_UNKNOWN; 118 } 119 if ((key->xmss_name = strdup(name)) == NULL) { 120 free(state); 121 return SSH_ERR_ALLOC_FAIL; 122 } 123 state->k = 2; /* XXX hardcoded */ 124 state->lockfd = -1; 125 if (xmss_set_params(&state->params, state->n, state->h, state->w, 126 state->k) != 0) { 127 free(state); 128 return SSH_ERR_INVALID_FORMAT; 129 } 130 key->xmss_state = state; 131 return 0; 132} 133 134void 135sshkey_xmss_free_state(struct sshkey *key) 136{ 137 struct ssh_xmss_state *state = key->xmss_state; 138 139 sshkey_xmss_free_bds(key); 140 if (state) { 141 if (state->enc_keyiv) { 142 explicit_bzero(state->enc_keyiv, state->enc_keyiv_len); 143 free(state->enc_keyiv); 144 } 145 free(state->enc_ciphername); 146 free(state); 147 } 148 key->xmss_state = NULL; 149} 150 151#define SSH_XMSS_K2_MAGIC "k=2" 152#define num_stack(x) ((x->h+1)*(x->n)) 153#define num_stacklevels(x) (x->h+1) 154#define num_auth(x) ((x->h)*(x->n)) 155#define num_keep(x) ((x->h >> 1)*(x->n)) 156#define num_th_nodes(x) ((x->h - x->k)*(x->n)) 157#define num_retain(x) (((1ULL << x->k) - x->k - 1) * (x->n)) 158#define num_treehash(x) ((x->h) - (x->k)) 159 160int 161sshkey_xmss_init_bds_state(struct sshkey *key) 162{ 163 struct ssh_xmss_state *state = key->xmss_state; 164 u_int32_t i; 165 166 state->stackoffset = 0; 167 if ((state->stack = calloc(num_stack(state), 1)) == NULL || 168 (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL || 169 (state->auth = calloc(num_auth(state), 1)) == NULL || 170 (state->keep = calloc(num_keep(state), 1)) == NULL || 171 (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL || 172 (state->retain = calloc(num_retain(state), 1)) == NULL || 173 (state->treehash = calloc(num_treehash(state), 174 sizeof(treehash_inst))) == NULL) { 175 sshkey_xmss_free_bds(key); 176 return SSH_ERR_ALLOC_FAIL; 177 } 178 for (i = 0; i < state->h - state->k; i++) 179 state->treehash[i].node = &state->th_nodes[state->n*i]; 180 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 181 state->stacklevels, state->auth, state->keep, state->treehash, 182 state->retain, 0); 183 return 0; 184} 185 186void 187sshkey_xmss_free_bds(struct sshkey *key) 188{ 189 struct ssh_xmss_state *state = key->xmss_state; 190 191 if (state == NULL) 192 return; 193 free(state->stack); 194 free(state->stacklevels); 195 free(state->auth); 196 free(state->keep); 197 free(state->th_nodes); 198 free(state->retain); 199 free(state->treehash); 200 state->stack = NULL; 201 state->stacklevels = NULL; 202 state->auth = NULL; 203 state->keep = NULL; 204 state->th_nodes = NULL; 205 state->retain = NULL; 206 state->treehash = NULL; 207} 208 209void * 210sshkey_xmss_params(const struct sshkey *key) 211{ 212 struct ssh_xmss_state *state = key->xmss_state; 213 214 if (state == NULL) 215 return NULL; 216 return &state->params; 217} 218 219void * 220sshkey_xmss_bds_state(const struct sshkey *key) 221{ 222 struct ssh_xmss_state *state = key->xmss_state; 223 224 if (state == NULL) 225 return NULL; 226 return &state->bds; 227} 228 229int 230sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp) 231{ 232 struct ssh_xmss_state *state = key->xmss_state; 233 234 if (lenp == NULL) 235 return SSH_ERR_INVALID_ARGUMENT; 236 if (state == NULL) 237 return SSH_ERR_INVALID_FORMAT; 238 *lenp = 4 + state->n + 239 state->params.wots_par.keysize + 240 state->h * state->n; 241 return 0; 242} 243 244size_t 245sshkey_xmss_pklen(const struct sshkey *key) 246{ 247 struct ssh_xmss_state *state = key->xmss_state; 248 249 if (state == NULL) 250 return 0; 251 return state->n * 2; 252} 253 254size_t 255sshkey_xmss_sklen(const struct sshkey *key) 256{ 257 struct ssh_xmss_state *state = key->xmss_state; 258 259 if (state == NULL) 260 return 0; 261 return state->n * 4 + 4; 262} 263 264int 265sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername) 266{ 267 struct ssh_xmss_state *state = k->xmss_state; 268 const struct sshcipher *cipher; 269 size_t keylen = 0, ivlen = 0; 270 271 if (state == NULL) 272 return SSH_ERR_INVALID_ARGUMENT; 273 if ((cipher = cipher_by_name(ciphername)) == NULL) 274 return SSH_ERR_INTERNAL_ERROR; 275 if ((state->enc_ciphername = strdup(ciphername)) == NULL) 276 return SSH_ERR_ALLOC_FAIL; 277 keylen = cipher_keylen(cipher); 278 ivlen = cipher_ivlen(cipher); 279 state->enc_keyiv_len = keylen + ivlen; 280 if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) { 281 free(state->enc_ciphername); 282 state->enc_ciphername = NULL; 283 return SSH_ERR_ALLOC_FAIL; 284 } 285 arc4random_buf(state->enc_keyiv, state->enc_keyiv_len); 286 return 0; 287} 288 289int 290sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b) 291{ 292 struct ssh_xmss_state *state = k->xmss_state; 293 int r; 294 295 if (state == NULL || state->enc_keyiv == NULL || 296 state->enc_ciphername == NULL) 297 return SSH_ERR_INVALID_ARGUMENT; 298 if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 || 299 (r = sshbuf_put_string(b, state->enc_keyiv, 300 state->enc_keyiv_len)) != 0) 301 return r; 302 return 0; 303} 304 305int 306sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b) 307{ 308 struct ssh_xmss_state *state = k->xmss_state; 309 size_t len; 310 int r; 311 312 if (state == NULL) 313 return SSH_ERR_INVALID_ARGUMENT; 314 if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 || 315 (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0) 316 return r; 317 state->enc_keyiv_len = len; 318 return 0; 319} 320 321int 322sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b, 323 enum sshkey_serialize_rep opts) 324{ 325 struct ssh_xmss_state *state = k->xmss_state; 326 u_char have_info = 1; 327 u_int32_t idx; 328 int r; 329 330 if (state == NULL) 331 return SSH_ERR_INVALID_ARGUMENT; 332 if (opts != SSHKEY_SERIALIZE_INFO) 333 return 0; 334 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 335 if ((r = sshbuf_put_u8(b, have_info)) != 0 || 336 (r = sshbuf_put_u32(b, idx)) != 0 || 337 (r = sshbuf_put_u32(b, state->maxidx)) != 0) 338 return r; 339 return 0; 340} 341 342int 343sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b) 344{ 345 struct ssh_xmss_state *state = k->xmss_state; 346 u_char have_info; 347 int r; 348 349 if (state == NULL) 350 return SSH_ERR_INVALID_ARGUMENT; 351 /* optional */ 352 if (sshbuf_len(b) == 0) 353 return 0; 354 if ((r = sshbuf_get_u8(b, &have_info)) != 0) 355 return r; 356 if (have_info != 1) 357 return SSH_ERR_INVALID_ARGUMENT; 358 if ((r = sshbuf_get_u32(b, &state->idx)) != 0 || 359 (r = sshbuf_get_u32(b, &state->maxidx)) != 0) 360 return r; 361 return 0; 362} 363 364int 365sshkey_xmss_generate_private_key(struct sshkey *k, int bits) 366{ 367 int r; 368 const char *name; 369 370 if (bits == 10) { 371 name = XMSS_SHA2_256_W16_H10_NAME; 372 } else if (bits == 16) { 373 name = XMSS_SHA2_256_W16_H16_NAME; 374 } else if (bits == 20) { 375 name = XMSS_SHA2_256_W16_H20_NAME; 376 } else { 377 name = XMSS_DEFAULT_NAME; 378 } 379 if ((r = sshkey_xmss_init(k, name)) != 0 || 380 (r = sshkey_xmss_init_bds_state(k)) != 0 || 381 (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0) 382 return r; 383 if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL || 384 (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) { 385 return SSH_ERR_ALLOC_FAIL; 386 } 387 xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k), 388 sshkey_xmss_params(k)); 389 return 0; 390} 391 392int 393sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename, 394 int *have_file, int printerror) 395{ 396 struct sshbuf *b = NULL, *enc = NULL; 397 int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1; 398 u_int32_t len; 399 unsigned char buf[4], *data = NULL; 400 401 *have_file = 0; 402 if ((fd = open(filename, O_RDONLY)) >= 0) { 403 *have_file = 1; 404 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) { 405 PRINT("corrupt state file: %s", filename); 406 goto done; 407 } 408 len = PEEK_U32(buf); 409 if ((data = calloc(len, 1)) == NULL) { 410 ret = SSH_ERR_ALLOC_FAIL; 411 goto done; 412 } 413 if (atomicio(read, fd, data, len) != len) { 414 PRINT("cannot read blob: %s", filename); 415 goto done; 416 } 417 if ((enc = sshbuf_from(data, len)) == NULL) { 418 ret = SSH_ERR_ALLOC_FAIL; 419 goto done; 420 } 421 sshkey_xmss_free_bds(k); 422 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) { 423 ret = r; 424 goto done; 425 } 426 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) { 427 ret = r; 428 goto done; 429 } 430 ret = 0; 431 } 432done: 433 if (fd != -1) 434 close(fd); 435 free(data); 436 sshbuf_free(enc); 437 sshbuf_free(b); 438 return ret; 439} 440 441int 442sshkey_xmss_get_state(const struct sshkey *k, int printerror) 443{ 444 struct ssh_xmss_state *state = k->xmss_state; 445 u_int32_t idx = 0; 446 char *filename = NULL; 447 char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL; 448 int lockfd = -1, have_state = 0, have_ostate = 0, tries = 0; 449 int ret = SSH_ERR_INVALID_ARGUMENT, r; 450 451 if (state == NULL) 452 goto done; 453 /* 454 * If maxidx is set, then we are allowed a limited number 455 * of signatures, but don't need to access the disk. 456 * Otherwise we need to deal with the on-disk state. 457 */ 458 if (state->maxidx) { 459 /* xmss_sk always contains the current state */ 460 idx = PEEK_U32(k->xmss_sk); 461 if (idx < state->maxidx) { 462 state->allow_update = 1; 463 return 0; 464 } 465 return SSH_ERR_INVALID_ARGUMENT; 466 } 467 if ((filename = k->xmss_filename) == NULL) 468 goto done; 469 if (asprintf(&lockfile, "%s.lock", filename) == -1 || 470 asprintf(&statefile, "%s.state", filename) == -1 || 471 asprintf(&ostatefile, "%s.ostate", filename) == -1) { 472 ret = SSH_ERR_ALLOC_FAIL; 473 goto done; 474 } 475 if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) { 476 ret = SSH_ERR_SYSTEM_ERROR; 477 PRINT("cannot open/create: %s", lockfile); 478 goto done; 479 } 480 while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) { 481 if (errno != EWOULDBLOCK) { 482 ret = SSH_ERR_SYSTEM_ERROR; 483 PRINT("cannot lock: %s", lockfile); 484 goto done; 485 } 486 if (++tries > 10) { 487 ret = SSH_ERR_SYSTEM_ERROR; 488 PRINT("giving up on: %s", lockfile); 489 goto done; 490 } 491 usleep(1000*100*tries); 492 } 493 /* XXX no longer const */ 494 if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k), 495 statefile, &have_state, printerror)) != 0) { 496 if ((r = sshkey_xmss_get_state_from_file(__UNCONST(k), 497 ostatefile, &have_ostate, printerror)) == 0) { 498 state->allow_update = 1; 499 r = sshkey_xmss_forward_state(k, 1); 500 state->idx = PEEK_U32(k->xmss_sk); 501 state->allow_update = 0; 502 } 503 } 504 if (!have_state && !have_ostate) { 505 /* check that bds state is initialized */ 506 if (state->bds.auth == NULL) 507 goto done; 508 PRINT("start from scratch idx 0: %u", state->idx); 509 } else if (r != 0) { 510 ret = r; 511 goto done; 512 } 513 if (state->idx + 1 < state->idx) { 514 PRINT("state wrap: %u", state->idx); 515 goto done; 516 } 517 state->have_state = have_state; 518 state->lockfd = lockfd; 519 state->allow_update = 1; 520 lockfd = -1; 521 ret = 0; 522done: 523 if (lockfd != -1) 524 close(lockfd); 525 free(lockfile); 526 free(statefile); 527 free(ostatefile); 528 return ret; 529} 530 531int 532sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve) 533{ 534 struct ssh_xmss_state *state = k->xmss_state; 535 u_char *sig = NULL; 536 size_t required_siglen; 537 unsigned long long smlen; 538 u_char data; 539 int ret, r; 540 541 if (state == NULL || !state->allow_update) 542 return SSH_ERR_INVALID_ARGUMENT; 543 if (reserve == 0) 544 return SSH_ERR_INVALID_ARGUMENT; 545 if (state->idx + reserve <= state->idx) 546 return SSH_ERR_INVALID_ARGUMENT; 547 if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0) 548 return r; 549 if ((sig = malloc(required_siglen)) == NULL) 550 return SSH_ERR_ALLOC_FAIL; 551 while (reserve-- > 0) { 552 state->idx = PEEK_U32(k->xmss_sk); 553 smlen = required_siglen; 554 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k), 555 sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) { 556 r = SSH_ERR_INVALID_ARGUMENT; 557 break; 558 } 559 } 560 free(sig); 561 return r; 562} 563 564int 565sshkey_xmss_update_state(const struct sshkey *k, int printerror) 566{ 567 struct ssh_xmss_state *state = k->xmss_state; 568 struct sshbuf *b = NULL, *enc = NULL; 569 u_int32_t idx = 0; 570 unsigned char buf[4]; 571 char *filename = NULL; 572 char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL; 573 int fd = -1; 574 int ret = SSH_ERR_INVALID_ARGUMENT; 575 576 if (state == NULL || !state->allow_update) 577 return ret; 578 if (state->maxidx) { 579 /* no update since the number of signatures is limited */ 580 ret = 0; 581 goto done; 582 } 583 idx = PEEK_U32(k->xmss_sk); 584 if (idx == state->idx) { 585 /* no signature happened, no need to update */ 586 ret = 0; 587 goto done; 588 } else if (idx != state->idx + 1) { 589 PRINT("more than one signature happened: idx %u state %u", 590 idx, state->idx); 591 goto done; 592 } 593 state->idx = idx; 594 if ((filename = k->xmss_filename) == NULL) 595 goto done; 596 if (asprintf(&statefile, "%s.state", filename) == -1 || 597 asprintf(&ostatefile, "%s.ostate", filename) == -1 || 598 asprintf(&nstatefile, "%s.nstate", filename) == -1) { 599 ret = SSH_ERR_ALLOC_FAIL; 600 goto done; 601 } 602 unlink(nstatefile); 603 if ((b = sshbuf_new()) == NULL) { 604 ret = SSH_ERR_ALLOC_FAIL; 605 goto done; 606 } 607 if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) { 608 PRINT("SERLIALIZE FAILED: %d", ret); 609 goto done; 610 } 611 if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) { 612 PRINT("ENCRYPT FAILED: %d", ret); 613 goto done; 614 } 615 if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) { 616 ret = SSH_ERR_SYSTEM_ERROR; 617 PRINT("open new state file: %s", nstatefile); 618 goto done; 619 } 620 POKE_U32(buf, sshbuf_len(enc)); 621 if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) { 622 ret = SSH_ERR_SYSTEM_ERROR; 623 PRINT("write new state file hdr: %s", nstatefile); 624 close(fd); 625 goto done; 626 } 627 if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) != 628 sshbuf_len(enc)) { 629 ret = SSH_ERR_SYSTEM_ERROR; 630 PRINT("write new state file data: %s", nstatefile); 631 close(fd); 632 goto done; 633 } 634 if (fsync(fd) == -1) { 635 ret = SSH_ERR_SYSTEM_ERROR; 636 PRINT("sync new state file: %s", nstatefile); 637 close(fd); 638 goto done; 639 } 640 if (close(fd) == -1) { 641 ret = SSH_ERR_SYSTEM_ERROR; 642 PRINT("close new state file: %s", nstatefile); 643 goto done; 644 } 645 if (state->have_state) { 646 unlink(ostatefile); 647 if (link(statefile, ostatefile)) { 648 ret = SSH_ERR_SYSTEM_ERROR; 649 PRINT("backup state %s to %s", statefile, ostatefile); 650 goto done; 651 } 652 } 653 if (rename(nstatefile, statefile) == -1) { 654 ret = SSH_ERR_SYSTEM_ERROR; 655 PRINT("rename %s to %s", nstatefile, statefile); 656 goto done; 657 } 658 ret = 0; 659done: 660 if (state->lockfd != -1) { 661 close(state->lockfd); 662 state->lockfd = -1; 663 } 664 if (nstatefile) 665 unlink(nstatefile); 666 free(statefile); 667 free(ostatefile); 668 free(nstatefile); 669 sshbuf_free(b); 670 sshbuf_free(enc); 671 return ret; 672} 673 674int 675sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b) 676{ 677 struct ssh_xmss_state *state = k->xmss_state; 678 treehash_inst *th; 679 u_int32_t i, node; 680 int r; 681 682 if (state == NULL) 683 return SSH_ERR_INVALID_ARGUMENT; 684 if (state->stack == NULL) 685 return SSH_ERR_INVALID_ARGUMENT; 686 state->stackoffset = state->bds.stackoffset; /* copy back */ 687 if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 || 688 (r = sshbuf_put_u32(b, state->idx)) != 0 || 689 (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 || 690 (r = sshbuf_put_u32(b, state->stackoffset)) != 0 || 691 (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 || 692 (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 || 693 (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 || 694 (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 || 695 (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 || 696 (r = sshbuf_put_u32(b, num_treehash(state))) != 0) 697 return r; 698 for (i = 0; i < num_treehash(state); i++) { 699 th = &state->treehash[i]; 700 node = th->node - state->th_nodes; 701 if ((r = sshbuf_put_u32(b, th->h)) != 0 || 702 (r = sshbuf_put_u32(b, th->next_idx)) != 0 || 703 (r = sshbuf_put_u32(b, th->stackusage)) != 0 || 704 (r = sshbuf_put_u8(b, th->completed)) != 0 || 705 (r = sshbuf_put_u32(b, node)) != 0) 706 return r; 707 } 708 return 0; 709} 710 711int 712sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b, 713 enum sshkey_serialize_rep opts) 714{ 715 struct ssh_xmss_state *state = k->xmss_state; 716 int r = SSH_ERR_INVALID_ARGUMENT; 717 u_char have_stack, have_filename, have_enc; 718 719 if (state == NULL) 720 return SSH_ERR_INVALID_ARGUMENT; 721 if ((r = sshbuf_put_u8(b, opts)) != 0) 722 return r; 723 switch (opts) { 724 case SSHKEY_SERIALIZE_STATE: 725 r = sshkey_xmss_serialize_state(k, b); 726 break; 727 case SSHKEY_SERIALIZE_FULL: 728 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 729 return r; 730 r = sshkey_xmss_serialize_state(k, b); 731 break; 732 case SSHKEY_SERIALIZE_SHIELD: 733 /* all of stack/filename/enc are optional */ 734 have_stack = state->stack != NULL; 735 if ((r = sshbuf_put_u8(b, have_stack)) != 0) 736 return r; 737 if (have_stack) { 738 state->idx = PEEK_U32(k->xmss_sk); /* update */ 739 if ((r = sshkey_xmss_serialize_state(k, b)) != 0) 740 return r; 741 } 742 have_filename = k->xmss_filename != NULL; 743 if ((r = sshbuf_put_u8(b, have_filename)) != 0) 744 return r; 745 if (have_filename && 746 (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0) 747 return r; 748 have_enc = state->enc_keyiv != NULL; 749 if ((r = sshbuf_put_u8(b, have_enc)) != 0) 750 return r; 751 if (have_enc && 752 (r = sshkey_xmss_serialize_enc_key(k, b)) != 0) 753 return r; 754 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 || 755 (r = sshbuf_put_u8(b, state->allow_update)) != 0) 756 return r; 757 break; 758 case SSHKEY_SERIALIZE_DEFAULT: 759 r = 0; 760 break; 761 default: 762 r = SSH_ERR_INVALID_ARGUMENT; 763 break; 764 } 765 return r; 766} 767 768int 769sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b) 770{ 771 struct ssh_xmss_state *state = k->xmss_state; 772 treehash_inst *th; 773 u_int32_t i, lh, node; 774 size_t ls, lsl, la, lk, ln, lr; 775 char *magic; 776 int r = SSH_ERR_INTERNAL_ERROR; 777 778 if (state == NULL) 779 return SSH_ERR_INVALID_ARGUMENT; 780 if (k->xmss_sk == NULL) 781 return SSH_ERR_INVALID_ARGUMENT; 782 if ((state->treehash = calloc(num_treehash(state), 783 sizeof(treehash_inst))) == NULL) 784 return SSH_ERR_ALLOC_FAIL; 785 if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 || 786 (r = sshbuf_get_u32(b, &state->idx)) != 0 || 787 (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 || 788 (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 || 789 (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 || 790 (r = sshbuf_get_string(b, &state->auth, &la)) != 0 || 791 (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 || 792 (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 || 793 (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 || 794 (r = sshbuf_get_u32(b, &lh)) != 0) 795 goto out; 796 if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) { 797 r = SSH_ERR_INVALID_ARGUMENT; 798 goto out; 799 } 800 /* XXX check stackoffset */ 801 if (ls != num_stack(state) || 802 lsl != num_stacklevels(state) || 803 la != num_auth(state) || 804 lk != num_keep(state) || 805 ln != num_th_nodes(state) || 806 lr != num_retain(state) || 807 lh != num_treehash(state)) { 808 r = SSH_ERR_INVALID_ARGUMENT; 809 goto out; 810 } 811 for (i = 0; i < num_treehash(state); i++) { 812 th = &state->treehash[i]; 813 if ((r = sshbuf_get_u32(b, &th->h)) != 0 || 814 (r = sshbuf_get_u32(b, &th->next_idx)) != 0 || 815 (r = sshbuf_get_u32(b, &th->stackusage)) != 0 || 816 (r = sshbuf_get_u8(b, &th->completed)) != 0 || 817 (r = sshbuf_get_u32(b, &node)) != 0) 818 goto out; 819 if (node < num_th_nodes(state)) 820 th->node = &state->th_nodes[node]; 821 } 822 POKE_U32(k->xmss_sk, state->idx); 823 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset, 824 state->stacklevels, state->auth, state->keep, state->treehash, 825 state->retain, 0); 826 /* success */ 827 r = 0; 828 out: 829 free(magic); 830 return r; 831} 832 833int 834sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b) 835{ 836 struct ssh_xmss_state *state = k->xmss_state; 837 enum sshkey_serialize_rep opts; 838 u_char have_state, have_stack, have_filename, have_enc; 839 int r; 840 841 if ((r = sshbuf_get_u8(b, &have_state)) != 0) 842 return r; 843 844 opts = have_state; 845 switch (opts) { 846 case SSHKEY_SERIALIZE_DEFAULT: 847 r = 0; 848 break; 849 case SSHKEY_SERIALIZE_SHIELD: 850 if ((r = sshbuf_get_u8(b, &have_stack)) != 0) 851 return r; 852 if (have_stack && 853 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 854 return r; 855 if ((r = sshbuf_get_u8(b, &have_filename)) != 0) 856 return r; 857 if (have_filename && 858 (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0) 859 return r; 860 if ((r = sshbuf_get_u8(b, &have_enc)) != 0) 861 return r; 862 if (have_enc && 863 (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0) 864 return r; 865 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 || 866 (r = sshbuf_get_u8(b, &state->allow_update)) != 0) 867 return r; 868 break; 869 case SSHKEY_SERIALIZE_STATE: 870 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) 871 return r; 872 break; 873 case SSHKEY_SERIALIZE_FULL: 874 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 || 875 (r = sshkey_xmss_deserialize_state(k, b)) != 0) 876 return r; 877 break; 878 default: 879 r = SSH_ERR_INVALID_FORMAT; 880 break; 881 } 882 return r; 883} 884 885int 886sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b, 887 struct sshbuf **retp) 888{ 889 struct ssh_xmss_state *state = k->xmss_state; 890 struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL; 891 struct sshcipher_ctx *ciphercontext = NULL; 892 const struct sshcipher *cipher; 893 u_char *cp, *key, *iv = NULL; 894 size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen; 895 int r = SSH_ERR_INTERNAL_ERROR; 896 897 if (retp != NULL) 898 *retp = NULL; 899 if (state == NULL || 900 state->enc_keyiv == NULL || 901 state->enc_ciphername == NULL) 902 return SSH_ERR_INTERNAL_ERROR; 903 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 904 r = SSH_ERR_INTERNAL_ERROR; 905 goto out; 906 } 907 blocksize = cipher_blocksize(cipher); 908 keylen = cipher_keylen(cipher); 909 ivlen = cipher_ivlen(cipher); 910 authlen = cipher_authlen(cipher); 911 if (state->enc_keyiv_len != keylen + ivlen) { 912 r = SSH_ERR_INVALID_FORMAT; 913 goto out; 914 } 915 key = state->enc_keyiv; 916 if ((encrypted = sshbuf_new()) == NULL || 917 (encoded = sshbuf_new()) == NULL || 918 (padded = sshbuf_new()) == NULL || 919 (iv = malloc(ivlen)) == NULL) { 920 r = SSH_ERR_ALLOC_FAIL; 921 goto out; 922 } 923 924 /* replace first 4 bytes of IV with index to ensure uniqueness */ 925 memcpy(iv, key + keylen, ivlen); 926 POKE_U32(iv, state->idx); 927 928 if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 || 929 (r = sshbuf_put_u32(encoded, state->idx)) != 0) 930 goto out; 931 932 /* padded state will be encrypted */ 933 if ((r = sshbuf_putb(padded, b)) != 0) 934 goto out; 935 i = 0; 936 while (sshbuf_len(padded) % blocksize) { 937 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0) 938 goto out; 939 } 940 encrypted_len = sshbuf_len(padded); 941 942 /* header including the length of state is used as AAD */ 943 if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0) 944 goto out; 945 aadlen = sshbuf_len(encoded); 946 947 /* concat header and state */ 948 if ((r = sshbuf_putb(encoded, padded)) != 0) 949 goto out; 950 951 /* reserve space for encryption of encoded data plus auth tag */ 952 /* encrypt at offset addlen */ 953 if ((r = sshbuf_reserve(encrypted, 954 encrypted_len + aadlen + authlen, &cp)) != 0 || 955 (r = cipher_init(&ciphercontext, cipher, key, keylen, 956 iv, ivlen, 1)) != 0 || 957 (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded), 958 encrypted_len, aadlen, authlen)) != 0) 959 goto out; 960 961 /* success */ 962 r = 0; 963 out: 964 if (retp != NULL) { 965 *retp = encrypted; 966 encrypted = NULL; 967 } 968 sshbuf_free(padded); 969 sshbuf_free(encoded); 970 sshbuf_free(encrypted); 971 cipher_free(ciphercontext); 972 free(iv); 973 return r; 974} 975 976int 977sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded, 978 struct sshbuf **retp) 979{ 980 struct ssh_xmss_state *state = k->xmss_state; 981 struct sshbuf *copy = NULL, *decrypted = NULL; 982 struct sshcipher_ctx *ciphercontext = NULL; 983 const struct sshcipher *cipher = NULL; 984 u_char *key, *iv = NULL, *dp; 985 size_t keylen, ivlen, authlen, aadlen; 986 u_int blocksize, encrypted_len, index; 987 int r = SSH_ERR_INTERNAL_ERROR; 988 989 if (retp != NULL) 990 *retp = NULL; 991 if (state == NULL || 992 state->enc_keyiv == NULL || 993 state->enc_ciphername == NULL) 994 return SSH_ERR_INTERNAL_ERROR; 995 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) { 996 r = SSH_ERR_INVALID_FORMAT; 997 goto out; 998 } 999 blocksize = cipher_blocksize(cipher); 1000 keylen = cipher_keylen(cipher); 1001 ivlen = cipher_ivlen(cipher); 1002 authlen = cipher_authlen(cipher); 1003 if (state->enc_keyiv_len != keylen + ivlen) { 1004 r = SSH_ERR_INTERNAL_ERROR; 1005 goto out; 1006 } 1007 key = state->enc_keyiv; 1008 1009 if ((copy = sshbuf_fromb(encoded)) == NULL || 1010 (decrypted = sshbuf_new()) == NULL || 1011 (iv = malloc(ivlen)) == NULL) { 1012 r = SSH_ERR_ALLOC_FAIL; 1013 goto out; 1014 } 1015 1016 /* check magic */ 1017 if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) || 1018 memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) { 1019 r = SSH_ERR_INVALID_FORMAT; 1020 goto out; 1021 } 1022 /* parse public portion */ 1023 if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 || 1024 (r = sshbuf_get_u32(encoded, &index)) != 0 || 1025 (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0) 1026 goto out; 1027 1028 /* check size of encrypted key blob */ 1029 if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) { 1030 r = SSH_ERR_INVALID_FORMAT; 1031 goto out; 1032 } 1033 /* check that an appropriate amount of auth data is present */ 1034 if (sshbuf_len(encoded) < authlen || 1035 sshbuf_len(encoded) - authlen < encrypted_len) { 1036 r = SSH_ERR_INVALID_FORMAT; 1037 goto out; 1038 } 1039 1040 aadlen = sshbuf_len(copy) - sshbuf_len(encoded); 1041 1042 /* replace first 4 bytes of IV with index to ensure uniqueness */ 1043 memcpy(iv, key + keylen, ivlen); 1044 POKE_U32(iv, index); 1045 1046 /* decrypt private state of key */ 1047 if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 || 1048 (r = cipher_init(&ciphercontext, cipher, key, keylen, 1049 iv, ivlen, 0)) != 0 || 1050 (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy), 1051 encrypted_len, aadlen, authlen)) != 0) 1052 goto out; 1053 1054 /* there should be no trailing data */ 1055 if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0) 1056 goto out; 1057 if (sshbuf_len(encoded) != 0) { 1058 r = SSH_ERR_INVALID_FORMAT; 1059 goto out; 1060 } 1061 1062 /* remove AAD */ 1063 if ((r = sshbuf_consume(decrypted, aadlen)) != 0) 1064 goto out; 1065 /* XXX encrypted includes unchecked padding */ 1066 1067 /* success */ 1068 r = 0; 1069 if (retp != NULL) { 1070 *retp = decrypted; 1071 decrypted = NULL; 1072 } 1073 out: 1074 cipher_free(ciphercontext); 1075 sshbuf_free(copy); 1076 sshbuf_free(decrypted); 1077 free(iv); 1078 return r; 1079} 1080 1081u_int32_t 1082sshkey_xmss_signatures_left(const struct sshkey *k) 1083{ 1084 struct ssh_xmss_state *state = k->xmss_state; 1085 u_int32_t idx; 1086 1087 if (sshkey_type_plain(k->type) == KEY_XMSS && state && 1088 state->maxidx) { 1089 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx; 1090 if (idx < state->maxidx) 1091 return state->maxidx - idx; 1092 } 1093 return 0; 1094} 1095 1096int 1097sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign) 1098{ 1099 struct ssh_xmss_state *state = k->xmss_state; 1100 1101 if (sshkey_type_plain(k->type) != KEY_XMSS) 1102 return SSH_ERR_INVALID_ARGUMENT; 1103 if (maxsign == 0) 1104 return 0; 1105 if (state->idx + maxsign < state->idx) 1106 return SSH_ERR_INVALID_ARGUMENT; 1107 state->maxidx = state->idx + maxsign; 1108 return 0; 1109} 1110