// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

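/*
 * The converted key schedule: enough room for up to 13 round keys expanded
 * to 8 * AES_BLOCK_SIZE bytes each in the bit-sliced representation, plus
 * 32 bytes for round keys that are kept in regular form.
 */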
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

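/*
 * Expand the key using the generic AES library and convert the encryption
 * round keys into the bit-sliced representation under NEON.
 */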
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));	/* wipe the expanded key off the stack */

	return 0;
}

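/*
 * ECB helper: walk the request and, for all but the final chunk, round the
 * block count down to a multiple of the walk stride (8 blocks), so partial
 * strides are only ever passed to the NEON code at the very end.
 */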
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

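/*
 * CBC encryption is inherently sequential and gains nothing from the
 * 8-way bit-sliced code, so hand the request off to the fallback cbc(aes)
 * skcipher allocated in cbc_init().
 */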
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

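/*
 * CBC decryption has no dependency between blocks on the block cipher side,
 * so it can be done 8 blocks at a time by the NEON code.
 */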
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

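/*
 * Allocate the fallback cbc(aes) skcipher used for encryption, and size
 * our request context so that a subrequest for it fits inside.
 */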
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

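/*
 * Keep the generic expanded key for the non-SIMD fallback path in addition
 * to converting it into the bit-sliced form used by the NEON code.
 */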
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

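/*
 * CTR encryption: intermediate chunks are rounded down to a whole number
 * of 8-block strides, while a trailing partial block is bounced in and out
 * of a stack buffer around the NEON call.
 */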
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

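/*
 * Fall back to the scalar AES library code when the NEON unit may not be
 * used, e.g. when the request is issued from a context where SIMD is not
 * usable.
 */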
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

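/*
 * The XTS key consists of two halves: the first half keys the bit-sliced
 * data path (and the CTS cipher), the second half keys the tweak cipher.
 */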
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

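/*
 * XTS with ciphertext stealing: full blocks are handled by the NEON code,
 * and a trailing partial block is processed afterwards with the plain AES
 * cipher, using scatterwalk copies to splice the stolen ciphertext back
 * into place.
 */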
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

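/*
 * The algorithms marked CRYPTO_ALG_INTERNAL are only reachable through the
 * simd wrappers created in aes_init(); the ctr-aes-neonbs-sync variant is
 * registered directly and falls back to scalar code when SIMD is unusable.
 */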
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

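/*
 * Register all algorithms, then create simd wrappers for the internal ones,
 * deriving the exposed names by stripping the "__" prefix from cra_name and
 * cra_driver_name.
 */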
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);