// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

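/* the following routines are implemented by the bit sliced NEON assembly */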
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

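/*
 * rk holds the key schedule converted to bit sliced form by
 * aesbs_convert_key(); the buffer is sized for the largest (AES-256,
 * 14 round) schedule.
 */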
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

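/*
 * CBC and CTR also keep the ordinary AES encryption key schedule (enc), which
 * is used by the plain NEON fallback routines for CBC encryption and for the
 * CTR tail.
 */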
struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

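/*
 * XTS keeps the tweak key schedule (twkey) for encrypting the IV, and an
 * ordinary AES context (cts) for the plain NEON code that handles any
 * remainder of fewer than eight blocks, including the ciphertext stealing
 * tail.
 */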
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

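	/* 10, 12 or 14 rounds for 128, 192 or 256 bit keys */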
	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

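/*
 * Process an ECB request with the given bit sliced NEON routine. Intermediate
 * walk steps are rounded down to a multiple of the eight block walksize; only
 * the final step may cover fewer than eight blocks.
 */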
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

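/*
 * CBC encryption is inherently sequential (each block depends on the previous
 * ciphertext block), so the 8-way bit sliced code cannot be used here and the
 * ordinary NEON implementation is called instead.
 */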
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

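/*
 * Process the bulk of the data in multiples of eight blocks using the bit
 * sliced code. Any remainder of fewer than eight blocks, including a partial
 * final block, is handled by the plain NEON CTR routine; a partial final
 * block is copied into (and the result copied back out of) a block sized
 * stack buffer so the NEON code never accesses memory past the end of the
 * request.
 */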
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_neon_begin();
		if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
			dst += blocks * AES_BLOCK_SIZE;
			src += blocks * AES_BLOCK_SIZE;
		}
		if (nbytes && walk.nbytes == walk.total) {
			u8 buf[AES_BLOCK_SIZE];
			u8 *d = dst;

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);

			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				memcpy(d, dst, nbytes);

			nbytes = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

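/*
 * XTS takes a double length key: the second half becomes the tweak key
 * schedule (twkey), while the first half is expanded both into an ordinary
 * AES context (cts) for the plain NEON tail handling and into the bit sliced
 * schedule used for the bulk of the data.
 */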
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

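/*
 * Encrypt or decrypt an XTS request: multiples of eight blocks go through the
 * bit sliced routine, while any remaining whole blocks and the ciphertext
 * stealing tail are handled by the plain NEON XTS code. If the leftover past
 * the last full walk step would be shorter than one block, the last two
 * blocks are split off and processed in a separate walk so that ciphertext
 * stealing is always done in a single step.
 */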
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (blocks >= 8) {
			if (first == 1)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 2;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0) {
			if (encrypt)
				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			else
				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			nbytes = first = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

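/*
 * All four modes use a walksize of eight blocks, matching the 8-way
 * interleave of the bit sliced NEON code. For CTR, decryption is the same
 * operation as encryption, so .decrypt points at ctr_encrypt as well.
 */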
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

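/* only register the algorithms when Advanced SIMD is actually available */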
static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);