1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4 *
5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6 */
7
#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha2.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <linux/string.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"
23
/*
 * Back-end selection.  The generic aes_* names used throughout this file are
 * mapped onto either the ce_* entry points (ARMv8 Crypto Extensions) or the
 * neon_* entry points (plain NEON), depending on whether
 * USE_V8_CRYPTO_EXTENSIONS is defined when this file is compiled.  MODE is
 * appended to every cra_driver_name and PRIO is used as cra_priority, so the
 * CE variant (300) registers with a higher priority than the NEON one (200).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xctr_encrypt	ce_aes_xctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
/*
 * NOTE(review): aes_expandkey is not remapped in this branch, so the NEON
 * build presumably uses the generic aes_expandkey() from <crypto/aes.h>.
 */
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xctr_encrypt	neon_aes_xctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif
/*
 * The plain ecb/cbc/ctr/xts/xctr modes are only provided (and aliased) by the
 * CE build, or by the NEON build when CONFIG_CRYPTO_AES_ARM64_BS (presumably
 * the bit-sliced NEON driver) is disabled.  The same condition guards the
 * corresponding entries in aes_algs[].
 */
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif
/* Algorithms that both builds always register. */
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
75
/* defined in aes-modes.S */
/*
 * Conventions, as used by the callers in this file:
 *  - rk / rk1 is an expanded AES key schedule; callers pass
 *    rounds = 6 + key_length / 4.
 *  - "blocks" arguments count whole AES_BLOCK_SIZE blocks, whereas "bytes"
 *    arguments are byte counts that may end in a partial block.
 *  - The callers pass the same iv/ctr buffer across successive walk steps
 *    without updating it themselves, so these routines presumably write the
 *    chaining value back in place.
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

/* Final 17..32 bytes of a cts(cbc(aes)) request (ciphertext stealing). */
asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 ctr[]);

/* byte_ctr: number of bytes of the request already processed. */
asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				 int rounds, int bytes, u8 ctr[], int byte_ctr);

/*
 * rk2 is the tweak key schedule.  first is 1 on the first call for a request
 * and 0 afterwards (presumably selects whether the initial tweak must be
 * computed from iv).
 */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

/* rk2 is the key schedule expanded from the SHA-256 digest of the key. */
asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

/*
 * CBC-MAC update of dg over in[].  Returns the number of blocks left
 * unprocessed; mac_do_update() calls it again until that reaches 0.
 */
asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			      int blocks, u8 dg[], int enc_before,
			      int enc_after);
115
/*
 * xts(aes) tfm context: key1 is expanded from the first half of the XTS key
 * and drives the data path, key2 from the second half and is passed to the
 * assembler as the tweak schedule (rk2).
 */
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};
120
/*
 * essiv(cbc(aes),sha256) tfm context: key1 is expanded from the user key,
 * key2 from the SHA-256 digest of that key, computed with @hash (allocated in
 * essiv_cbc_init_tfm()).
 */
struct crypto_aes_essiv_cbc_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
	struct crypto_shash *hash;
};
126
/*
 * MAC tfm context.  consts[] is a trailing area holding two AES_BLOCK_SIZE
 * subkeys for cmac/xcbc (their cra_ctxsize reserves 2 * AES_BLOCK_SIZE extra);
 * cbcmac reserves no extra space and does not use it.
 */
struct mac_tfm_ctx {
	struct crypto_aes_ctx key;
	u8 __aligned(8) consts[];
};
131
/*
 * Per-request MAC state: dg is the running chaining value, with the first
 * len bytes (0..AES_BLOCK_SIZE) of not-yet-encrypted input XORed into it.
 */
struct mac_desc_ctx {
	unsigned int len;
	u8 dg[AES_BLOCK_SIZE];
};
136
137static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
138			       unsigned int key_len)
139{
140	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
141
142	return aes_expandkey(ctx, in_key, key_len);
143}
144
145static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
146				      const u8 *in_key, unsigned int key_len)
147{
148	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
149	int ret;
150
151	ret = xts_verify_key(tfm, in_key, key_len);
152	if (ret)
153		return ret;
154
155	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
156	if (!ret)
157		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
158				    key_len / 2);
159	return ret;
160}
161
162static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
163					    const u8 *in_key,
164					    unsigned int key_len)
165{
166	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
167	u8 digest[SHA256_DIGEST_SIZE];
168	int ret;
169
170	ret = aes_expandkey(&ctx->key1, in_key, key_len);
171	if (ret)
172		return ret;
173
174	crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
175
176	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
177}
178
179static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
180{
181	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
182	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
183	int err, rounds = 6 + ctx->key_length / 4;
184	struct skcipher_walk walk;
185	unsigned int blocks;
186
187	err = skcipher_walk_virt(&walk, req, false);
188
189	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
190		kernel_neon_begin();
191		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
192				ctx->key_enc, rounds, blocks);
193		kernel_neon_end();
194		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
195	}
196	return err;
197}
198
199static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
200{
201	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
202	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
203	int err, rounds = 6 + ctx->key_length / 4;
204	struct skcipher_walk walk;
205	unsigned int blocks;
206
207	err = skcipher_walk_virt(&walk, req, false);
208
209	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
210		kernel_neon_begin();
211		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
212				ctx->key_dec, rounds, blocks);
213		kernel_neon_end();
214		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
215	}
216	return err;
217}
218
219static int cbc_encrypt_walk(struct skcipher_request *req,
220			    struct skcipher_walk *walk)
221{
222	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
223	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
224	int err = 0, rounds = 6 + ctx->key_length / 4;
225	unsigned int blocks;
226
227	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
228		kernel_neon_begin();
229		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
230				ctx->key_enc, rounds, blocks, walk->iv);
231		kernel_neon_end();
232		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
233	}
234	return err;
235}
236
237static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
238{
239	struct skcipher_walk walk;
240	int err;
241
242	err = skcipher_walk_virt(&walk, req, false);
243	if (err)
244		return err;
245	return cbc_encrypt_walk(req, &walk);
246}
247
248static int cbc_decrypt_walk(struct skcipher_request *req,
249			    struct skcipher_walk *walk)
250{
251	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
252	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
253	int err = 0, rounds = 6 + ctx->key_length / 4;
254	unsigned int blocks;
255
256	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
257		kernel_neon_begin();
258		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
259				ctx->key_dec, rounds, blocks, walk->iv);
260		kernel_neon_end();
261		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
262	}
263	return err;
264}
265
266static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
267{
268	struct skcipher_walk walk;
269	int err;
270
271	err = skcipher_walk_virt(&walk, req, false);
272	if (err)
273		return err;
274	return cbc_decrypt_walk(req, &walk);
275}
276
/*
 * cts(cbc(aes)) encryption.
 *
 * Requests shorter than one block are rejected with -EINVAL, and a request of
 * exactly one block is processed as plain CBC.  Otherwise all blocks except
 * the final two (the last of which may be partial) go through plain CBC via
 * cbc_encrypt_walk().  The remaining 17..32 bytes are then handed to
 * aes_cbc_cts_encrypt() in a single call, which performs the ciphertext
 * stealing.
 */
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	/* number of leading blocks that can be processed as plain CBC */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	/* on-stack subrequest sharing req's flags; no completion callback */
	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		/* a single block is plain CBC; there is nothing to steal */
		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* advance src/dst past the CBC-processed prefix */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	/*
	 * The tail is processed in one call using walk.nbytes.  This relies on
	 * the single walk step covering the whole 17..32 byte tail (walksize is
	 * 2 * AES_BLOCK_SIZE for this algorithm) -- TODO confirm.
	 */
	kernel_neon_begin();
	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
333
/*
 * cts(cbc(aes)) decryption; mirrors cts_cbc_encrypt().
 *
 * Requests shorter than one block are rejected with -EINVAL, and exactly one
 * block is plain CBC.  Otherwise all blocks except the final two go through
 * cbc_decrypt_walk(), and the remaining 17..32 bytes are handed to
 * aes_cbc_cts_decrypt() in a single call.
 */
static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	/* number of leading blocks that can be processed as plain CBC */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	/* on-stack subrequest sharing req's flags; no completion callback */
	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		/* a single block is plain CBC; there is nothing to steal */
		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* advance src/dst past the CBC-processed prefix */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	/* tail processed in one call; see the walksize note in cts_cbc_encrypt() */
	kernel_neon_begin();
	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
390
391static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
392{
393	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
394
395	ctx->hash = crypto_alloc_shash("sha256", 0, 0);
396
397	return PTR_ERR_OR_ZERO(ctx->hash);
398}
399
400static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
401{
402	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
403
404	crypto_free_shash(ctx->hash);
405}
406
407static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
408{
409	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
410	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
411	int err, rounds = 6 + ctx->key1.key_length / 4;
412	struct skcipher_walk walk;
413	unsigned int blocks;
414
415	err = skcipher_walk_virt(&walk, req, false);
416
417	blocks = walk.nbytes / AES_BLOCK_SIZE;
418	if (blocks) {
419		kernel_neon_begin();
420		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
421				      ctx->key1.key_enc, rounds, blocks,
422				      req->iv, ctx->key2.key_enc);
423		kernel_neon_end();
424		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
425	}
426	return err ?: cbc_encrypt_walk(req, &walk);
427}
428
429static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
430{
431	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
432	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
433	int err, rounds = 6 + ctx->key1.key_length / 4;
434	struct skcipher_walk walk;
435	unsigned int blocks;
436
437	err = skcipher_walk_virt(&walk, req, false);
438
439	blocks = walk.nbytes / AES_BLOCK_SIZE;
440	if (blocks) {
441		kernel_neon_begin();
442		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
443				      ctx->key1.key_dec, rounds, blocks,
444				      req->iv, ctx->key2.key_enc);
445		kernel_neon_end();
446		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
447	}
448	return err ?: cbc_decrypt_walk(req, &walk);
449}
450
/*
 * xctr(aes) encryption; also used for decryption (see aes_algs[]).
 *
 * Every walk step except the last is trimmed to a whole number of blocks; the
 * last step (nbytes == walk.total) is processed including its partial tail.
 * byte_ctr accumulates the number of bytes already processed and is passed to
 * the assembler alongside walk.iv -- presumably so it can derive the XCTR
 * block counter; confirm against aes-modes.S.
 */
static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int byte_ctr = 0;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 buf[AES_BLOCK_SIZE];

		/*
		 * If given less than 16 bytes, we must copy the partial block
		 * into a temporary buffer of 16 bytes to avoid out of bounds
		 * reads and writes.  Furthermore, this code is somewhat unusual
		 * in that it expects the end of the data to be at the end of
		 * the temporary buffer, rather than the start of the data at
		 * the start of the temporary buffer.
		 */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - nbytes,
					   src, nbytes);
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
						 walk.iv, byte_ctr);
		kernel_neon_end();

		/* copy the in-place result for a short tail back out of buf */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - nbytes, nbytes);
		byte_ctr += nbytes;

		/* report the bytes trimmed off this step as not processed */
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}
496
/*
 * ctr(aes) encryption; also used for decryption (see aes_algs[]).
 *
 * Every walk step except the last is trimmed to a whole number of blocks; the
 * last step (nbytes == walk.total) is processed including its partial tail.
 * The counter block lives in walk.iv and is passed to the assembler each step.
 */
static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 buf[AES_BLOCK_SIZE];

		/*
		 * If given less than 16 bytes, we must copy the partial block
		 * into a temporary buffer of 16 bytes to avoid out of bounds
		 * reads and writes.  Furthermore, this code is somewhat unusual
		 * in that it expects the end of the data to be at the end of
		 * the temporary buffer, rather than the start of the data at
		 * the start of the temporary buffer.
		 */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - nbytes,
					   src, nbytes);
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
				walk.iv);
		kernel_neon_end();

		/* copy the in-place result for a short tail back out of buf */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - nbytes, nbytes);

		/* report the bytes trimmed off this step as not processed */
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}
540
/*
 * xts(aes) encryption, with ciphertext stealing when cryptlen is not a
 * multiple of AES_BLOCK_SIZE.
 *
 * If there is a partial tail and the first walk step does not cover the whole
 * request, the walk is aborted and restarted on an on-stack subrequest. That
 * subrequest covers only the leading whole blocks (all but the last full
 * block and the tail). The final AES_BLOCK_SIZE + tail bytes are then
 * processed by a separate aes_xts_encrypt() call. In every other case tail
 * is reset to 0 and the main loop handles the request on its own.
 */
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	/* XTS needs at least one full block */
	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* whole blocks preceding the final block + tail */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	/*
	 * first is 1 only for the first assembler call (presumably so the
	 * initial tweak is derived from walk.iv with key2).  Every step but
	 * the last is trimmed to whole blocks.
	 */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/*
	 * tail != 0 only on the subrequest path, so req == &subreq here and
	 * req->cryptlen is the length of the whole-block prefix.  first keeps
	 * its value from the loop: 0 if the loop ran, 1 if it did not.
	 */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_enc, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
612
/*
 * xts(aes) decryption; mirrors xts_encrypt().
 *
 * The same whole-block / final-block-plus-tail split is used, but with the
 * key1 decryption schedule.  key2's encryption schedule is still used for
 * the tweak.
 */
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	/* XTS needs at least one full block */
	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* whole blocks preceding the final block + tail */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	/* first is 1 only for the first assembler call */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* tail != 0 only on the subrequest path, so req == &subreq here */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;


	kernel_neon_begin();
	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_dec, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
685
/*
 * skcipher algorithms registered by this module.  The ecb, cbc, ctr, xctr and
 * xts entries are compiled in under the same condition as the matching
 * MODULE_ALIAS_CRYPTO() lines near the top of the file.  cts(cbc(aes)) and
 * essiv(cbc(aes),sha256) are always present.
 */
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
	.base = {
		.cra_name		= "ecb(aes)",
		.cra_driver_name	= "ecb-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ecb_encrypt,
	.decrypt	= ecb_decrypt,
}, {
	.base = {
		.cra_name		= "cbc(aes)",
		.cra_driver_name	= "cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	/* stream mode: blocksize 1, and decryption is the same operation */
	.base = {
		.cra_name		= "ctr(aes)",
		.cra_driver_name	= "ctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	/* stream mode: blocksize 1, and decryption is the same operation */
	.base = {
		.cra_name		= "xctr(aes)",
		.cra_driver_name	= "xctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= xctr_encrypt,
	.decrypt	= xctr_encrypt,
}, {
	/* XTS takes two AES keys back to back, see xts_set_key() */
	.base = {
		.cra_name		= "xts(aes)",
		.cra_driver_name	= "xts-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= xts_set_key,
	.encrypt	= xts_encrypt,
	.decrypt	= xts_decrypt,
}, {
#endif
	/* walksize matches the up-to-two-block CTS tail handled in one call */
	.base = {
		.cra_name		= "cts(cbc(aes))",
		.cra_driver_name	= "cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
}, {
	/*
	 * PRIO + 1: presumably so this fused implementation is preferred over
	 * a generic essiv template built on top of cbc-aes-MODE -- confirm.
	 */
	.base = {
		.cra_name		= "essiv(cbc(aes),sha256)",
		.cra_driver_name	= "essiv-cbc-aes-sha256-" MODE,
		.cra_priority		= PRIO + 1,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= essiv_cbc_set_key,
	.encrypt	= essiv_cbc_encrypt,
	.decrypt	= essiv_cbc_decrypt,
	.init		= essiv_cbc_init_tfm,
	.exit		= essiv_cbc_exit_tfm,
} };
799
800static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
801			 unsigned int key_len)
802{
803	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
804
805	return aes_expandkey(&ctx->key, in_key, key_len);
806}
807
808static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
809{
810	u64 a = be64_to_cpu(x->a);
811	u64 b = be64_to_cpu(x->b);
812
813	y->a = cpu_to_be64((a << 1) | (b >> 63));
814	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
815}
816
817static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
818		       unsigned int key_len)
819{
820	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
821	be128 *consts = (be128 *)ctx->consts;
822	int rounds = 6 + key_len / 4;
823	int err;
824
825	err = cbcmac_setkey(tfm, in_key, key_len);
826	if (err)
827		return err;
828
829	/* encrypt the zero vector */
830	kernel_neon_begin();
831	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
832			rounds, 1);
833	kernel_neon_end();
834
835	cmac_gf128_mul_by_x(consts, consts);
836	cmac_gf128_mul_by_x(consts + 1, consts);
837
838	return 0;
839}
840
841static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
842		       unsigned int key_len)
843{
844	static u8 const ks[3][AES_BLOCK_SIZE] = {
845		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
846		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
847		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
848	};
849
850	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
851	int rounds = 6 + key_len / 4;
852	u8 key[AES_BLOCK_SIZE];
853	int err;
854
855	err = cbcmac_setkey(tfm, in_key, key_len);
856	if (err)
857		return err;
858
859	kernel_neon_begin();
860	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
861	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
862	kernel_neon_end();
863
864	return cbcmac_setkey(tfm, key, sizeof(key));
865}
866
867static int mac_init(struct shash_desc *desc)
868{
869	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
870
871	memset(ctx->dg, 0, AES_BLOCK_SIZE);
872	ctx->len = 0;
873
874	return 0;
875}
876
877static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
878			  u8 dg[], int enc_before, int enc_after)
879{
880	int rounds = 6 + ctx->key_length / 4;
881
882	if (crypto_simd_usable()) {
883		int rem;
884
885		do {
886			kernel_neon_begin();
887			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
888					     dg, enc_before, enc_after);
889			kernel_neon_end();
890			in += (blocks - rem) * AES_BLOCK_SIZE;
891			blocks = rem;
892			enc_before = 0;
893		} while (blocks);
894	} else {
895		if (enc_before)
896			aes_encrypt(ctx, dg, dg);
897
898		while (blocks--) {
899			crypto_xor(dg, in, AES_BLOCK_SIZE);
900			in += AES_BLOCK_SIZE;
901
902			if (blocks || enc_after)
903				aes_encrypt(ctx, dg, dg);
904		}
905	}
906}
907
/*
 * Absorb @len bytes of message into the running MAC state.
 *
 * ctx->dg holds the chaining value with the first ctx->len bytes (0..16) of
 * not-yet-encrypted input XORed into it.  A complete block
 * (ctx->len == AES_BLOCK_SIZE) is left unencrypted until more data arrives,
 * because cmac_final() mixes a subkey into the last block before encrypting
 * it.
 */
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	while (len > 0) {
		unsigned int l;

		/*
		 * Bulk path: dg holds no partial block (it is empty or holds
		 * one pending full block) and more than one block of data is
		 * available in total.  The pending block, if any, is encrypted
		 * first.  All whole blocks of @p are then absorbed, and the
		 * last of them is only encrypted if a partial tail follows.
		 */
		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
		    (ctx->len + len) > AES_BLOCK_SIZE) {

			int blocks = len / AES_BLOCK_SIZE;

			len %= AES_BLOCK_SIZE;

			mac_do_update(&tctx->key, p, blocks, ctx->dg,
				      (ctx->len != 0), (len != 0));

			p += blocks * AES_BLOCK_SIZE;

			if (!len) {
				/* last block stays pending for final() */
				ctx->len = AES_BLOCK_SIZE;
				break;
			}
			ctx->len = 0;
		}

		/* XOR (up to) the rest of a block into dg */
		l = min(len, AES_BLOCK_SIZE - ctx->len);

		/*
		 * Always true given the min() above; presumably kept to help
		 * the compiler bound the crypto_xor() length.
		 */
		if (l <= AES_BLOCK_SIZE) {
			crypto_xor(ctx->dg + ctx->len, p, l);
			ctx->len += l;
			len -= l;
			p += l;
		}
	}

	return 0;
}
947
948static int cbcmac_final(struct shash_desc *desc, u8 *out)
949{
950	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
951	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
952
953	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
954
955	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
956
957	return 0;
958}
959
960static int cmac_final(struct shash_desc *desc, u8 *out)
961{
962	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
963	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
964	u8 *consts = tctx->consts;
965
966	if (ctx->len != AES_BLOCK_SIZE) {
967		ctx->dg[ctx->len] ^= 0x80;
968		consts += AES_BLOCK_SIZE;
969	}
970
971	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
972
973	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
974
975	return 0;
976}
977
/*
 * MAC algorithms.  cmac and xcbc reserve 2 * AES_BLOCK_SIZE extra context
 * bytes for the two subkeys stored in mac_tfm_ctx.consts[] and share
 * cmac_final().  cbcmac needs no extra constants and uses blocksize 1.
 */
static struct shash_alg mac_algs[] = { {
	.base.cra_name		= "cmac(aes)",
	.base.cra_driver_name	= "cmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= cmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "xcbc(aes)",
	.base.cra_driver_name	= "xcbc-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cmac_final,
	.setkey			= xcbc_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "cbcmac(aes)",
	.base.cra_driver_name	= "cbcmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.final			= cbcmac_final,
	.setkey			= cbcmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
} };
1023
1024static void aes_exit(void)
1025{
1026	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1027	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1028}
1029
1030static int __init aes_init(void)
1031{
1032	int err;
1033
1034	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1035	if (err)
1036		return err;
1037
1038	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1039	if (err)
1040		goto unregister_ciphers;
1041
1042	return 0;
1043
1044unregister_ciphers:
1045	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1046	return err;
1047}
1048
#ifdef USE_V8_CRYPTO_EXTENSIONS
/* CE build: init is tied to the CPU's AES feature via module_cpu_feature_match() */
module_cpu_feature_match(AES, aes_init);
/*
 * Exported in the CRYPTO_INTERNAL namespace, presumably for other in-tree
 * arm64 crypto code -- confirm users.
 */
EXPORT_SYMBOL_NS(ce_aes_mac_update, CRYPTO_INTERNAL);
#else
module_init(aes_init);
/*
 * NEON entry points exported for reuse by other modules (presumably the
 * bit-sliced driver) -- confirm users.
 */
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);
1061