/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 NEON
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>

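/*
 * Bulk-processing routines implemented in NEON assembly
 * (sm4-neon-core.S).  Each call handles 'nblocks' full SM4 blocks; the
 * CBC and CTR variants also read and update the IV/counter block, which
 * is how chaining works across skcipher_walk steps.
 */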
asmlinkage void sm4_neon_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			       unsigned int nblocks);
asmlinkage void sm4_neon_cbc_dec(const u32 *rkey_dec, u8 *dst, const u8 *src,
				 u8 *iv, unsigned int nblocks);
asmlinkage void sm4_neon_ctr_crypt(const u32 *rkey_enc, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nblocks);

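/* Expand the supplied key into the encryption and decryption round keys. */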
static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

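/*
 * Walk the request and process as many full blocks per step as
 * possible.  The NEON unit may only be touched between
 * kernel_neon_begin() and kernel_neon_end(), so that window is kept as
 * narrow as possible.
 */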
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_crypt(rkey, dst, src, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

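/* Thin wrappers that select the encryption or decryption round keys. */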
static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

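/*
 * CBC encryption is inherently serial (each block chains in the
 * previous ciphertext block), so NEON parallelism buys nothing here;
 * encrypt block by block with the generic helper instead.
 */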
static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

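/*
 * CBC decryption has no inter-block dependency on the plaintext side,
 * so full blocks can be decrypted in parallel by the NEON routine,
 * which also leaves the last ciphertext block in walk.iv for chaining.
 */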
static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
					 walk.iv, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

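/*
 * CTR mode: full blocks are handled by the NEON routine, which also
 * advances the counter in walk.iv.  A partial tail (only possible at
 * the very end of the request) is encrypted with a single keystream
 * block generated in software.
 */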
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
					   walk.iv, nblocks);

			kernel_neon_end();

			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

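/*
 * Priority 200 makes these implementations preferred over the generic C
 * versions (priority 100).  Callers reach them through the normal
 * skcipher API; an illustrative (not kernel-internal) usage sketch:
 *
 *	tfm = crypto_alloc_skcipher("ctr(sm4)", 0, 0);
 *	crypto_skcipher_setkey(tfm, key, SM4_KEY_SIZE);
 */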
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-neon",
			.cra_priority		= 200,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}
};

static int __init sm4_init(void)
{
	return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

/*
 * Only register the algorithms on CPUs that actually implement
 * Advanced SIMD (NEON); this is what <linux/cpufeature.h> is for.
 */
module_cpu_feature_match(ASIMD, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CTR using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("sm4-neon");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");