1// SPDX-License-Identifier: GPL-2.0
2#include "bcachefs.h"
3#include "checksum.h"
4#include "errcode.h"
5#include "super.h"
6#include "super-io.h"
7
8#include <linux/crc32c.h>
9#include <linux/crypto.h>
10#include <linux/xxhash.h>
11#include <linux/key.h>
12#include <linux/random.h>
13#include <linux/scatterlist.h>
14#include <crypto/algapi.h>
15#include <crypto/chacha.h>
16#include <crypto/hash.h>
17#include <crypto/poly1305.h>
18#include <crypto/skcipher.h>
19#include <keys/user-type.h>
20
/*
 * struct bch2_checksum_state abstracts the running state of a checksum
 * computed over data spread across several pages: data can be fed in piece
 * by piece without the checksum algorithm losing its state in between.
 * For native checksum algorithms (like crc), a seed value is enough;
 * for hash-like algorithms (like xxhash), the full hash state needs to be stored.
 */
27
28struct bch2_checksum_state {
29	union {
30		u64 seed;
31		struct xxh64_state h64state;
32	};
33	unsigned int type;
34};
35
36static void bch2_checksum_init(struct bch2_checksum_state *state)
37{
38	switch (state->type) {
39	case BCH_CSUM_none:
40	case BCH_CSUM_crc32c:
41	case BCH_CSUM_crc64:
42		state->seed = 0;
43		break;
44	case BCH_CSUM_crc32c_nonzero:
45		state->seed = U32_MAX;
46		break;
47	case BCH_CSUM_crc64_nonzero:
48		state->seed = U64_MAX;
49		break;
50	case BCH_CSUM_xxhash:
51		xxh64_reset(&state->h64state, 0);
52		break;
53	default:
54		BUG();
55	}
56}
57
58static u64 bch2_checksum_final(const struct bch2_checksum_state *state)
59{
60	switch (state->type) {
61	case BCH_CSUM_none:
62	case BCH_CSUM_crc32c:
63	case BCH_CSUM_crc64:
64		return state->seed;
65	case BCH_CSUM_crc32c_nonzero:
66		return state->seed ^ U32_MAX;
67	case BCH_CSUM_crc64_nonzero:
68		return state->seed ^ U64_MAX;
69	case BCH_CSUM_xxhash:
70		return xxh64_digest(&state->h64state);
71	default:
72		BUG();
73	}
74}
75
76static void bch2_checksum_update(struct bch2_checksum_state *state, const void *data, size_t len)
77{
78	switch (state->type) {
79	case BCH_CSUM_none:
80		return;
81	case BCH_CSUM_crc32c_nonzero:
82	case BCH_CSUM_crc32c:
83		state->seed = crc32c(state->seed, data, len);
84		break;
85	case BCH_CSUM_crc64_nonzero:
86	case BCH_CSUM_crc64:
87		state->seed = crc64_be(state->seed, data, len);
88		break;
89	case BCH_CSUM_xxhash:
90		xxh64_update(&state->h64state, data, len);
91		break;
92	default:
93		BUG();
94	}
95}
96
97static inline int do_encrypt_sg(struct crypto_sync_skcipher *tfm,
98				struct nonce nonce,
99				struct scatterlist *sg, size_t len)
100{
101	SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
102	int ret;
103
104	skcipher_request_set_sync_tfm(req, tfm);
105	skcipher_request_set_crypt(req, sg, sg, len, nonce.d);
106
107	ret = crypto_skcipher_encrypt(req);
108	if (ret)
109		pr_err("got error %i from crypto_skcipher_encrypt()", ret);
110
111	return ret;
112}
113
114static inline int do_encrypt(struct crypto_sync_skcipher *tfm,
115			      struct nonce nonce,
116			      void *buf, size_t len)
117{
118	if (!is_vmalloc_addr(buf)) {
119		struct scatterlist sg;
120
121		sg_init_table(&sg, 1);
122		sg_set_page(&sg,
123			    is_vmalloc_addr(buf)
124			    ? vmalloc_to_page(buf)
125			    : virt_to_page(buf),
126			    len, offset_in_page(buf));
127		return do_encrypt_sg(tfm, nonce, &sg, len);
128	} else {
129		unsigned pages = buf_pages(buf, len);
130		struct scatterlist *sg;
131		size_t orig_len = len;
132		int ret, i;
133
134		sg = kmalloc_array(pages, sizeof(*sg), GFP_KERNEL);
135		if (!sg)
136			return -BCH_ERR_ENOMEM_do_encrypt;
137
138		sg_init_table(sg, pages);
139
140		for (i = 0; i < pages; i++) {
141			unsigned offset = offset_in_page(buf);
142			unsigned pg_len = min_t(size_t, len, PAGE_SIZE - offset);
143
144			sg_set_page(sg + i, vmalloc_to_page(buf), pg_len, offset);
145			buf += pg_len;
146			len -= pg_len;
147		}
148
149		ret = do_encrypt_sg(tfm, nonce, sg, orig_len);
150		kfree(sg);
151		return ret;
152	}
153}
154
155int bch2_chacha_encrypt_key(struct bch_key *key, struct nonce nonce,
156			    void *buf, size_t len)
157{
158	struct crypto_sync_skcipher *chacha20 =
159		crypto_alloc_sync_skcipher("chacha20", 0, 0);
160	int ret;
161
162	ret = PTR_ERR_OR_ZERO(chacha20);
163	if (ret) {
164		pr_err("error requesting chacha20 cipher: %s", bch2_err_str(ret));
165		return ret;
166	}
167
168	ret = crypto_skcipher_setkey(&chacha20->base,
169				     (void *) key, sizeof(*key));
170	if (ret) {
171		pr_err("error from crypto_skcipher_setkey(): %s", bch2_err_str(ret));
172		goto err;
173	}
174
175	ret = do_encrypt(chacha20, nonce, buf, len);
176err:
177	crypto_free_sync_skcipher(chacha20);
178	return ret;
179}
180
181static int gen_poly_key(struct bch_fs *c, struct shash_desc *desc,
182			struct nonce nonce)
183{
184	u8 key[POLY1305_KEY_SIZE];
185	int ret;
186
187	nonce.d[3] ^= BCH_NONCE_POLY;
188
189	memset(key, 0, sizeof(key));
190	ret = do_encrypt(c->chacha20, nonce, key, sizeof(key));
191	if (ret)
192		return ret;
193
194	desc->tfm = c->poly1305;
195	crypto_shash_init(desc);
196	crypto_shash_update(desc, key, sizeof(key));
197	return 0;
198}
199
200struct bch_csum bch2_checksum(struct bch_fs *c, unsigned type,
201			      struct nonce nonce, const void *data, size_t len)
202{
203	switch (type) {
204	case BCH_CSUM_none:
205	case BCH_CSUM_crc32c_nonzero:
206	case BCH_CSUM_crc64_nonzero:
207	case BCH_CSUM_crc32c:
208	case BCH_CSUM_xxhash:
209	case BCH_CSUM_crc64: {
210		struct bch2_checksum_state state;
211
212		state.type = type;
213
214		bch2_checksum_init(&state);
215		bch2_checksum_update(&state, data, len);
216
217		return (struct bch_csum) { .lo = cpu_to_le64(bch2_checksum_final(&state)) };
218	}
219
220	case BCH_CSUM_chacha20_poly1305_80:
221	case BCH_CSUM_chacha20_poly1305_128: {
222		SHASH_DESC_ON_STACK(desc, c->poly1305);
223		u8 digest[POLY1305_DIGEST_SIZE];
224		struct bch_csum ret = { 0 };
225
226		gen_poly_key(c, desc, nonce);
227
228		crypto_shash_update(desc, data, len);
229		crypto_shash_final(desc, digest);
230
231		memcpy(&ret, digest, bch_crc_bytes[type]);
232		return ret;
233	}
234	default:
235		BUG();
236	}
237}
238
239int bch2_encrypt(struct bch_fs *c, unsigned type,
240		  struct nonce nonce, void *data, size_t len)
241{
242	if (!bch2_csum_type_is_encryption(type))
243		return 0;
244
245	return do_encrypt(c->chacha20, nonce, data, len);
246}
247
248static struct bch_csum __bch2_checksum_bio(struct bch_fs *c, unsigned type,
249					   struct nonce nonce, struct bio *bio,
250					   struct bvec_iter *iter)
251{
252	struct bio_vec bv;
253
254	switch (type) {
255	case BCH_CSUM_none:
256		return (struct bch_csum) { 0 };
257	case BCH_CSUM_crc32c_nonzero:
258	case BCH_CSUM_crc64_nonzero:
259	case BCH_CSUM_crc32c:
260	case BCH_CSUM_xxhash:
261	case BCH_CSUM_crc64: {
262		struct bch2_checksum_state state;
263
264		state.type = type;
265		bch2_checksum_init(&state);
266
267#ifdef CONFIG_HIGHMEM
268		__bio_for_each_segment(bv, bio, *iter, *iter) {
269			void *p = kmap_local_page(bv.bv_page) + bv.bv_offset;
270
271			bch2_checksum_update(&state, p, bv.bv_len);
272			kunmap_local(p);
273		}
274#else
275		__bio_for_each_bvec(bv, bio, *iter, *iter)
276			bch2_checksum_update(&state, page_address(bv.bv_page) + bv.bv_offset,
277				bv.bv_len);
278#endif
279		return (struct bch_csum) { .lo = cpu_to_le64(bch2_checksum_final(&state)) };
280	}
281
282	case BCH_CSUM_chacha20_poly1305_80:
283	case BCH_CSUM_chacha20_poly1305_128: {
284		SHASH_DESC_ON_STACK(desc, c->poly1305);
285		u8 digest[POLY1305_DIGEST_SIZE];
286		struct bch_csum ret = { 0 };
287
288		gen_poly_key(c, desc, nonce);
289
290#ifdef CONFIG_HIGHMEM
291		__bio_for_each_segment(bv, bio, *iter, *iter) {
292			void *p = kmap_local_page(bv.bv_page) + bv.bv_offset;
293
294			crypto_shash_update(desc, p, bv.bv_len);
295			kunmap_local(p);
296		}
297#else
298		__bio_for_each_bvec(bv, bio, *iter, *iter)
299			crypto_shash_update(desc,
300				page_address(bv.bv_page) + bv.bv_offset,
301				bv.bv_len);
302#endif
303		crypto_shash_final(desc, digest);
304
305		memcpy(&ret, digest, bch_crc_bytes[type]);
306		return ret;
307	}
308	default:
309		BUG();
310	}
311}
312
313struct bch_csum bch2_checksum_bio(struct bch_fs *c, unsigned type,
314				  struct nonce nonce, struct bio *bio)
315{
316	struct bvec_iter iter = bio->bi_iter;
317
318	return __bch2_checksum_bio(c, type, nonce, bio, &iter);
319}
320
321int __bch2_encrypt_bio(struct bch_fs *c, unsigned type,
322		     struct nonce nonce, struct bio *bio)
323{
324	struct bio_vec bv;
325	struct bvec_iter iter;
326	struct scatterlist sgl[16], *sg = sgl;
327	size_t bytes = 0;
328	int ret = 0;
329
330	if (!bch2_csum_type_is_encryption(type))
331		return 0;
332
333	sg_init_table(sgl, ARRAY_SIZE(sgl));
334
335	bio_for_each_segment(bv, bio, iter) {
336		if (sg == sgl + ARRAY_SIZE(sgl)) {
337			sg_mark_end(sg - 1);
338
339			ret = do_encrypt_sg(c->chacha20, nonce, sgl, bytes);
340			if (ret)
341				return ret;
342
343			nonce = nonce_add(nonce, bytes);
344			bytes = 0;
345
346			sg_init_table(sgl, ARRAY_SIZE(sgl));
347			sg = sgl;
348		}
349
350		sg_set_page(sg++, bv.bv_page, bv.bv_len, bv.bv_offset);
351		bytes += bv.bv_len;
352	}
353
354	sg_mark_end(sg - 1);
355	return do_encrypt_sg(c->chacha20, nonce, sgl, bytes);
356}
357
358struct bch_csum bch2_checksum_merge(unsigned type, struct bch_csum a,
359				    struct bch_csum b, size_t b_len)
360{
361	struct bch2_checksum_state state;
362
363	state.type = type;
364	bch2_checksum_init(&state);
365	state.seed = le64_to_cpu(a.lo);
366
367	BUG_ON(!bch2_checksum_mergeable(type));
368
369	while (b_len) {
370		unsigned page_len = min_t(unsigned, b_len, PAGE_SIZE);
371
372		bch2_checksum_update(&state,
373				page_address(ZERO_PAGE(0)), page_len);
374		b_len -= page_len;
375	}
376	a.lo = cpu_to_le64(bch2_checksum_final(&state));
377	a.lo ^= b.lo;
378	a.hi ^= b.hi;
379	return a;
380}
381
/*
 * Recompute checksums for pieces of an uncompressed extent (presumably
 * used when an extent is split or trimmed — confirm against callers).
 *
 * @bio holds the extent's data, checksummed as a whole by @crc_old with
 * the nonce derived from @version and @crc_old. The data is divided into
 * three consecutive pieces: @len_a sectors, @len_b sectors, and whatever
 * remains. New checksums of type @new_csum_type are computed for the first
 * two and written to *@crc_a and *@crc_b; a NULL pointer skips that piece.
 *
 * Before anything is written, @crc_old is re-verified against the data.
 * When the old and new types are the same mergeable type, this is done by
 * merging the per-piece checksums. Otherwise the old checksum is
 * recomputed over the whole bio. On mismatch (unless c->opts.no_data_io is
 * set) the error is logged and -EIO is returned, so data that no longer
 * matches its old checksum is not given fresh, valid-looking ones.
 *
 * Returns 0 on success.
 */
int bch2_rechecksum_bio(struct bch_fs *c, struct bio *bio,
			struct bversion version,
			struct bch_extent_crc_unpacked crc_old,
			struct bch_extent_crc_unpacked *crc_a,
			struct bch_extent_crc_unpacked *crc_b,
			unsigned len_a, unsigned len_b,
			unsigned new_csum_type)
{
	struct bvec_iter iter = bio->bi_iter;
	struct nonce nonce = extent_nonce(version, crc_old);
	struct bch_csum merged = { 0 };
	/* one entry per piece; lengths are in 512-byte sectors */
	struct crc_split {
		struct bch_extent_crc_unpacked	*crc;
		unsigned			len;
		unsigned			csum_type;
		struct bch_csum			csum;
	} splits[3] = {
		{ crc_a, len_a, new_csum_type, { 0 }},
		{ crc_b, len_b, new_csum_type, { 0 } },
		{ NULL,	 bio_sectors(bio) - len_a - len_b, new_csum_type, { 0 } },
	}, *i;
	bool mergeable = crc_old.csum_type == new_csum_type &&
		bch2_checksum_mergeable(new_csum_type);
	unsigned crc_nonce = crc_old.nonce;

	BUG_ON(len_a + len_b > bio_sectors(bio));
	BUG_ON(crc_old.uncompressed_size != bio_sectors(bio));
	BUG_ON(crc_is_compressed(crc_old));
	BUG_ON(bch2_csum_type_is_encryption(crc_old.csum_type) !=
	       bch2_csum_type_is_encryption(new_csum_type));

	/*
	 * Walk the pieces in order with one shared iterator: bi_size limits
	 * it to the current piece. Pieces without an output crc are only
	 * checksummed when their checksum is needed for merging; otherwise
	 * they are skipped. The nonce advances by each piece's byte length.
	 */
	for (i = splits; i < splits + ARRAY_SIZE(splits); i++) {
		iter.bi_size = i->len << 9;
		if (mergeable || i->crc)
			i->csum = __bch2_checksum_bio(c, i->csum_type,
						      nonce, bio, &iter);
		else
			bio_advance_iter(bio, &iter, i->len << 9);
		nonce = nonce_add(nonce, i->len << 9);
	}

	/*
	 * Rebuild the old whole-extent checksum to verify the data: fold the
	 * piece checksums left to right (starting from the all-zero value),
	 * or recompute it from scratch when merging isn't possible.
	 */
	if (mergeable)
		for (i = splits; i < splits + ARRAY_SIZE(splits); i++)
			merged = bch2_checksum_merge(new_csum_type, merged,
						     i->csum, i->len << 9);
	else
		merged = bch2_checksum_bio(c, crc_old.csum_type,
				extent_nonce(version, crc_old), bio);

	if (bch2_crc_cmp(merged, crc_old.csum) && !c->opts.no_data_io) {
		struct printbuf buf = PRINTBUF;
		prt_printf(&buf, "checksum error in %s() (memory corruption or bug?)\n"
			   "expected %0llx:%0llx got %0llx:%0llx (old type ",
			   __func__,
			   crc_old.csum.hi,
			   crc_old.csum.lo,
			   merged.hi,
			   merged.lo);
		bch2_prt_csum_type(&buf, crc_old.csum_type);
		prt_str(&buf, " new type ");
		bch2_prt_csum_type(&buf, new_csum_type);
		prt_str(&buf, ")");
		bch_err(c, "%s", buf.buf);
		printbuf_exit(&buf);
		return -EIO;
	}

	/*
	 * Verified: emit a crc descriptor for each requested piece. For
	 * encrypted types, each piece's nonce field is offset by the sectors
	 * of the pieces before it.
	 */
	for (i = splits; i < splits + ARRAY_SIZE(splits); i++) {
		if (i->crc)
			*i->crc = (struct bch_extent_crc_unpacked) {
				.csum_type		= i->csum_type,
				.compression_type	= crc_old.compression_type,
				.compressed_size	= i->len,
				.uncompressed_size	= i->len,
				.offset			= 0,
				.live_size		= i->len,
				.nonce			= crc_nonce,
				.csum			= i->csum,
			};

		if (bch2_csum_type_is_encryption(new_csum_type))
			crc_nonce += i->len;
	}

	return 0;
}
468
469/* BCH_SB_FIELD_crypt: */
470
471static int bch2_sb_crypt_validate(struct bch_sb *sb,
472				  struct bch_sb_field *f,
473				  struct printbuf *err)
474{
475	struct bch_sb_field_crypt *crypt = field_to_type(f, crypt);
476
477	if (vstruct_bytes(&crypt->field) < sizeof(*crypt)) {
478		prt_printf(err, "wrong size (got %zu should be %zu)",
479		       vstruct_bytes(&crypt->field), sizeof(*crypt));
480		return -BCH_ERR_invalid_sb_crypt;
481	}
482
483	if (BCH_CRYPT_KDF_TYPE(crypt)) {
484		prt_printf(err, "bad kdf type %llu", BCH_CRYPT_KDF_TYPE(crypt));
485		return -BCH_ERR_invalid_sb_crypt;
486	}
487
488	return 0;
489}
490
/*
 * Print the crypt superblock field: its KDF type, followed by the scrypt
 * N, r and p parameters, one per line.
 */
static void bch2_sb_crypt_to_text(struct printbuf *out, struct bch_sb *sb,
				  struct bch_sb_field *f)
{
	struct bch_sb_field_crypt *crypt = field_to_type(f, crypt);

	prt_printf(out, "KDF:               %llu", BCH_CRYPT_KDF_TYPE(crypt));
	prt_newline(out);
	prt_printf(out, "scrypt n:          %llu", BCH_KDF_SCRYPT_N(crypt));
	prt_newline(out);
	prt_printf(out, "scrypt r:          %llu", BCH_KDF_SCRYPT_R(crypt));
	prt_newline(out);
	prt_printf(out, "scrypt p:          %llu", BCH_KDF_SCRYPT_P(crypt));
	prt_newline(out);
}
505
506const struct bch_sb_field_ops bch_sb_field_ops_crypt = {
507	.validate	= bch2_sb_crypt_validate,
508	.to_text	= bch2_sb_crypt_to_text,
509};
510
511#ifdef __KERNEL__
512static int __bch2_request_key(char *key_description, struct bch_key *key)
513{
514	struct key *keyring_key;
515	const struct user_key_payload *ukp;
516	int ret;
517
518	keyring_key = request_key(&key_type_user, key_description, NULL);
519	if (IS_ERR(keyring_key))
520		return PTR_ERR(keyring_key);
521
522	down_read(&keyring_key->sem);
523	ukp = dereference_key_locked(keyring_key);
524	if (ukp->datalen == sizeof(*key)) {
525		memcpy(key, ukp->data, ukp->datalen);
526		ret = 0;
527	} else {
528		ret = -EINVAL;
529	}
530	up_read(&keyring_key->sem);
531	key_put(keyring_key);
532
533	return ret;
534}
535#else
536#include <keyutils.h>
537
538static int __bch2_request_key(char *key_description, struct bch_key *key)
539{
540	key_serial_t key_id;
541
542	key_id = request_key("user", key_description, NULL,
543			     KEY_SPEC_SESSION_KEYRING);
544	if (key_id >= 0)
545		goto got_key;
546
547	key_id = request_key("user", key_description, NULL,
548			     KEY_SPEC_USER_KEYRING);
549	if (key_id >= 0)
550		goto got_key;
551
552	key_id = request_key("user", key_description, NULL,
553			     KEY_SPEC_USER_SESSION_KEYRING);
554	if (key_id >= 0)
555		goto got_key;
556
557	return -errno;
558got_key:
559
560	if (keyctl_read(key_id, (void *) key, sizeof(*key)) != sizeof(*key))
561		return -1;
562
563	return 0;
564}
565
566#include "crypto.h"
567#endif
568
569int bch2_request_key(struct bch_sb *sb, struct bch_key *key)
570{
571	struct printbuf key_description = PRINTBUF;
572	int ret;
573
574	prt_printf(&key_description, "bcachefs:");
575	pr_uuid(&key_description, sb->user_uuid.b);
576
577	ret = __bch2_request_key(key_description.buf, key);
578	printbuf_exit(&key_description);
579
580#ifndef __KERNEL__
581	if (ret) {
582		char *passphrase = read_passphrase("Enter passphrase: ");
583		struct bch_encrypted_key sb_key;
584
585		bch2_passphrase_check(sb, passphrase,
586				      key, &sb_key);
587		ret = 0;
588	}
589#endif
590
591	/* stash with memfd, pass memfd fd to mount */
592
593	return ret;
594}
595
596#ifndef __KERNEL__
597int bch2_revoke_key(struct bch_sb *sb)
598{
599	key_serial_t key_id;
600	struct printbuf key_description = PRINTBUF;
601
602	prt_printf(&key_description, "bcachefs:");
603	pr_uuid(&key_description, sb->user_uuid.b);
604
605	key_id = request_key("user", key_description.buf, NULL, KEY_SPEC_USER_KEYRING);
606	printbuf_exit(&key_description);
607	if (key_id < 0)
608		return errno;
609
610	keyctl_revoke(key_id);
611
612	return 0;
613}
614#endif
615
616int bch2_decrypt_sb_key(struct bch_fs *c,
617			struct bch_sb_field_crypt *crypt,
618			struct bch_key *key)
619{
620	struct bch_encrypted_key sb_key = crypt->key;
621	struct bch_key user_key;
622	int ret = 0;
623
624	/* is key encrypted? */
625	if (!bch2_key_is_encrypted(&sb_key))
626		goto out;
627
628	ret = bch2_request_key(c->disk_sb.sb, &user_key);
629	if (ret) {
630		bch_err(c, "error requesting encryption key: %s", bch2_err_str(ret));
631		goto err;
632	}
633
634	/* decrypt real key: */
635	ret = bch2_chacha_encrypt_key(&user_key, bch2_sb_key_nonce(c),
636				      &sb_key, sizeof(sb_key));
637	if (ret)
638		goto err;
639
640	if (bch2_key_is_encrypted(&sb_key)) {
641		bch_err(c, "incorrect encryption key");
642		ret = -EINVAL;
643		goto err;
644	}
645out:
646	*key = sb_key.key;
647err:
648	memzero_explicit(&sb_key, sizeof(sb_key));
649	memzero_explicit(&user_key, sizeof(user_key));
650	return ret;
651}
652
653static int bch2_alloc_ciphers(struct bch_fs *c)
654{
655	int ret;
656
657	if (!c->chacha20)
658		c->chacha20 = crypto_alloc_sync_skcipher("chacha20", 0, 0);
659	ret = PTR_ERR_OR_ZERO(c->chacha20);
660
661	if (ret) {
662		bch_err(c, "error requesting chacha20 module: %s", bch2_err_str(ret));
663		return ret;
664	}
665
666	if (!c->poly1305)
667		c->poly1305 = crypto_alloc_shash("poly1305", 0, 0);
668	ret = PTR_ERR_OR_ZERO(c->poly1305);
669
670	if (ret) {
671		bch_err(c, "error requesting poly1305 module: %s", bch2_err_str(ret));
672		return ret;
673	}
674
675	return 0;
676}
677
678int bch2_disable_encryption(struct bch_fs *c)
679{
680	struct bch_sb_field_crypt *crypt;
681	struct bch_key key;
682	int ret = -EINVAL;
683
684	mutex_lock(&c->sb_lock);
685
686	crypt = bch2_sb_field_get(c->disk_sb.sb, crypt);
687	if (!crypt)
688		goto out;
689
690	/* is key encrypted? */
691	ret = 0;
692	if (bch2_key_is_encrypted(&crypt->key))
693		goto out;
694
695	ret = bch2_decrypt_sb_key(c, crypt, &key);
696	if (ret)
697		goto out;
698
699	crypt->key.magic	= cpu_to_le64(BCH_KEY_MAGIC);
700	crypt->key.key		= key;
701
702	SET_BCH_SB_ENCRYPTION_TYPE(c->disk_sb.sb, 0);
703	bch2_write_super(c);
704out:
705	mutex_unlock(&c->sb_lock);
706
707	return ret;
708}
709
710int bch2_enable_encryption(struct bch_fs *c, bool keyed)
711{
712	struct bch_encrypted_key key;
713	struct bch_key user_key;
714	struct bch_sb_field_crypt *crypt;
715	int ret = -EINVAL;
716
717	mutex_lock(&c->sb_lock);
718
719	/* Do we already have an encryption key? */
720	if (bch2_sb_field_get(c->disk_sb.sb, crypt))
721		goto err;
722
723	ret = bch2_alloc_ciphers(c);
724	if (ret)
725		goto err;
726
727	key.magic = cpu_to_le64(BCH_KEY_MAGIC);
728	get_random_bytes(&key.key, sizeof(key.key));
729
730	if (keyed) {
731		ret = bch2_request_key(c->disk_sb.sb, &user_key);
732		if (ret) {
733			bch_err(c, "error requesting encryption key: %s", bch2_err_str(ret));
734			goto err;
735		}
736
737		ret = bch2_chacha_encrypt_key(&user_key, bch2_sb_key_nonce(c),
738					      &key, sizeof(key));
739		if (ret)
740			goto err;
741	}
742
743	ret = crypto_skcipher_setkey(&c->chacha20->base,
744			(void *) &key.key, sizeof(key.key));
745	if (ret)
746		goto err;
747
748	crypt = bch2_sb_field_resize(&c->disk_sb, crypt,
749				     sizeof(*crypt) / sizeof(u64));
750	if (!crypt) {
751		ret = -BCH_ERR_ENOSPC_sb_crypt;
752		goto err;
753	}
754
755	crypt->key = key;
756
757	/* write superblock */
758	SET_BCH_SB_ENCRYPTION_TYPE(c->disk_sb.sb, 1);
759	bch2_write_super(c);
760err:
761	mutex_unlock(&c->sb_lock);
762	memzero_explicit(&user_key, sizeof(user_key));
763	memzero_explicit(&key, sizeof(key));
764	return ret;
765}
766
767void bch2_fs_encryption_exit(struct bch_fs *c)
768{
769	if (!IS_ERR_OR_NULL(c->poly1305))
770		crypto_free_shash(c->poly1305);
771	if (!IS_ERR_OR_NULL(c->chacha20))
772		crypto_free_sync_skcipher(c->chacha20);
773	if (!IS_ERR_OR_NULL(c->sha256))
774		crypto_free_shash(c->sha256);
775}
776
777int bch2_fs_encryption_init(struct bch_fs *c)
778{
779	struct bch_sb_field_crypt *crypt;
780	struct bch_key key;
781	int ret = 0;
782
783	c->sha256 = crypto_alloc_shash("sha256", 0, 0);
784	ret = PTR_ERR_OR_ZERO(c->sha256);
785	if (ret) {
786		bch_err(c, "error requesting sha256 module: %s", bch2_err_str(ret));
787		goto out;
788	}
789
790	crypt = bch2_sb_field_get(c->disk_sb.sb, crypt);
791	if (!crypt)
792		goto out;
793
794	ret = bch2_alloc_ciphers(c);
795	if (ret)
796		goto out;
797
798	ret = bch2_decrypt_sb_key(c, crypt, &key);
799	if (ret)
800		goto out;
801
802	ret = crypto_skcipher_setkey(&c->chacha20->base,
803			(void *) &key.key, sizeof(key.key));
804	if (ret)
805		goto out;
806out:
807	memzero_explicit(&key, sizeof(key));
808	return ret;
809}
810