/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

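/*
 * Bulk block-processing routines implemented in the AES-NI/AVX assembly
 * part of this module.
 */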
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
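		/* Process full eight-block batches with the wide AVX path. */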
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
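		/* Handle up to seven remaining blocks, four at a time at most. */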
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);
			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

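/*
 * CBC encryption is inherently serial (each block chains in the previous
 * ciphertext block), so it is done one block at a time with the generic
 * sm4_crypt_block() helper and without touching the FPU.
 */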
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

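		/*
		 * Bulk path: the assembly routine decrypts bsize bytes at a
		 * time and does the CBC chaining itself, leaving the IV for
		 * the next call in walk.iv.
		 */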
		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

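		/*
		 * Tail: decrypt the remaining blocks (up to eight per pass)
		 * into a temporary buffer, then XOR in the previous
		 * ciphertext blocks from the last block backwards, so that
		 * in-place (dst == src) requests do not overwrite ciphertext
		 * that is still needed as chaining input.
		 */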
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}

int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

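		/*
		 * Bulk path: the assembly routine generates and encrypts the
		 * counter blocks, XORs them into the data and advances the
		 * counter in walk.iv.
		 */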
		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

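		/*
		 * Handle the remaining full blocks: build the counter blocks
		 * by hand, encrypt them as keystream (up to eight at a time)
		 * and XOR the result into the data.
		 */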
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail: handle a final partial block with one more keystream block */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}

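/*
 * The algorithms below are marked CRYPTO_ALG_INTERNAL: users reach them
 * through the simd helper wrappers registered in sm4_init(), which only
 * invoke them when the FPU is usable and otherwise defer to an
 * asynchronous cryptd fallback.
 */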
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");