// SPDX-License-Identifier: GPL-2.0-only
/*
 * AES using the RISC-V vector crypto extensions.  Includes the bare block
 * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes.
 *
 * Copyright (C) 2023 VRULL GmbH
 * Author: Heiko Stuebner <heiko.stuebner@vrull.eu>
 *
 * Copyright (C) 2023 SiFive, Inc.
 * Author: Jerry Shih <jerry.shih@sifive.com>
 *
 * Copyright 2024 Google LLC
 */

#include <asm/simd.h>
#include <asm/vector.h>
#include <crypto/aes.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/linkage.h>
#include <linux/module.h>

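/*
 * The following routines are implemented in assembly using the vector crypto
 * extensions named in their suffixes.  They are only called in this file
 * between kernel_vector_begin() and kernel_vector_end().
 */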
asmlinkage void aes_encrypt_zvkned(const struct crypto_aes_ctx *key,
				   const u8 in[AES_BLOCK_SIZE],
				   u8 out[AES_BLOCK_SIZE]);
asmlinkage void aes_decrypt_zvkned(const struct crypto_aes_ctx *key,
				   const u8 in[AES_BLOCK_SIZE],
				   u8 out[AES_BLOCK_SIZE]);

asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
				       const u8 *in, u8 *out, size_t len);
asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
				       const u8 *in, u8 *out, size_t len);

asmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key,
				       const u8 *in, u8 *out, size_t len,
				       u8 iv[AES_BLOCK_SIZE]);
asmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key,
				       const u8 *in, u8 *out, size_t len,
				       u8 iv[AES_BLOCK_SIZE]);

asmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
					 const u8 *in, u8 *out, size_t len,
					 const u8 iv[AES_BLOCK_SIZE], bool enc);

asmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key,
					    const u8 *in, u8 *out, size_t len,
					    u8 iv[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg(
			const struct crypto_aes_ctx *key,
			const u8 *in, u8 *out, size_t len,
			u8 tweak[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg(
			const struct crypto_aes_ctx *key,
			const u8 *in, u8 *out, size_t len,
			u8 tweak[AES_BLOCK_SIZE]);

static int riscv64_aes_setkey(struct crypto_aes_ctx *ctx,
			      const u8 *key, unsigned int keylen)
{
	/*
	 * For now we just use the generic key expansion, for these reasons:
	 *
	 * - zvkned's key expansion instructions don't support AES-192.
	 *   So, non-zvkned fallback code would be needed anyway.
	 *
	 * - Users of AES in Linux usually don't change keys frequently.
	 *   So, key expansion isn't performance-critical.
	 *
	 * - For single-block AES exposed as a "cipher" algorithm, it's
	 *   necessary to use struct crypto_aes_ctx and initialize its 'key_dec'
	 *   field with the round keys for the Equivalent Inverse Cipher.  This
	 *   is because with "cipher", decryption can be requested from a
	 *   context where the vector unit isn't usable, necessitating a
	 *   fallback to aes_decrypt().  But, zvkned can only generate and use
	 *   the normal round keys.  Of course, it's preferable to not have
	 *   special code just for "cipher", as e.g. XTS also uses a
	 *   single-block AES encryption.  It's simplest to just use
	 *   struct crypto_aes_ctx and aes_expandkey() everywhere.
	 */
	return aes_expandkey(ctx, key, keylen);
}

static int riscv64_aes_setkey_cipher(struct crypto_tfm *tfm,
				     const u8 *key, unsigned int keylen)
{
	struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);

	return riscv64_aes_setkey(ctx, key, keylen);
}

static int riscv64_aes_setkey_skcipher(struct crypto_skcipher *tfm,
				       const u8 *key, unsigned int keylen)
{
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

	return riscv64_aes_setkey(ctx, key, keylen);
}

/* Bare AES, without a mode of operation */

static void riscv64_aes_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
{
	const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);

	if (crypto_simd_usable()) {
		kernel_vector_begin();
		aes_encrypt_zvkned(ctx, src, dst);
		kernel_vector_end();
	} else {
		aes_encrypt(ctx, dst, src);
	}
}

static void riscv64_aes_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
{
	const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);

	if (crypto_simd_usable()) {
		kernel_vector_begin();
		aes_decrypt_zvkned(ctx, src, dst);
		kernel_vector_end();
	} else {
		aes_decrypt(ctx, dst, src);
	}
}

/* AES-ECB */

static inline int riscv64_aes_ecb_crypt(struct skcipher_request *req, bool enc)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
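	/*
	 * Each iteration handles the full blocks available in one step of the
	 * walk; any partial remainder is handed back via skcipher_walk_done()
	 * and revisited in the next step.
	 */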
	while ((nbytes = walk.nbytes) != 0) {
		kernel_vector_begin();
		if (enc)
			aes_ecb_encrypt_zvkned(ctx, walk.src.virt.addr,
					       walk.dst.virt.addr,
					       nbytes & ~(AES_BLOCK_SIZE - 1));
		else
			aes_ecb_decrypt_zvkned(ctx, walk.src.virt.addr,
					       walk.dst.virt.addr,
					       nbytes & ~(AES_BLOCK_SIZE - 1));
		kernel_vector_end();
		err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1));
	}

	return err;
}

static int riscv64_aes_ecb_encrypt(struct skcipher_request *req)
{
	return riscv64_aes_ecb_crypt(req, true);
}

static int riscv64_aes_ecb_decrypt(struct skcipher_request *req)
{
	return riscv64_aes_ecb_crypt(req, false);
}

/* AES-CBC */

static int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	while ((nbytes = walk.nbytes) != 0) {
		kernel_vector_begin();
		if (enc)
			aes_cbc_encrypt_zvkned(ctx, walk.src.virt.addr,
					       walk.dst.virt.addr,
					       nbytes & ~(AES_BLOCK_SIZE - 1),
					       walk.iv);
		else
			aes_cbc_decrypt_zvkned(ctx, walk.src.virt.addr,
					       walk.dst.virt.addr,
					       nbytes & ~(AES_BLOCK_SIZE - 1),
					       walk.iv);
		kernel_vector_end();
		err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1));
	}

	return err;
}

static int riscv64_aes_cbc_encrypt(struct skcipher_request *req)
{
	return riscv64_aes_cbc_crypt(req, true);
}

static int riscv64_aes_cbc_decrypt(struct skcipher_request *req)
{
	return riscv64_aes_cbc_crypt(req, false);
}

/* AES-CBC-CTS */

static int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	unsigned int cbc_len;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	/*
	 * If the full message is available in one step, decrypt it in one call
	 * to the CBC-CTS assembly function.  This reduces overhead, especially
	 * on short messages.  Otherwise, fall back to doing CBC up to the last
	 * two blocks, then invoke CTS just for the ciphertext stealing.
	 */
	if (unlikely(walk.nbytes != req->cryptlen)) {
		cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1,
				     AES_BLOCK_SIZE);
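		/*
		 * For example, with req->cryptlen == 100 this does CBC on the
		 * first round_down(100 - 16 - 1, 16) = 80 bytes and CTS on the
		 * remaining 20 (16 + 4) bytes.
		 */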
		skcipher_walk_abort(&walk);
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_len, req->iv);
		err = riscv64_aes_cbc_crypt(&subreq, enc);
		if (err)
			return err;
		dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len);
		skcipher_request_set_crypt(&subreq, src, dst,
					   req->cryptlen - cbc_len, req->iv);
		err = skcipher_walk_virt(&walk, &subreq, false);
		if (err)
			return err;
	}
	kernel_vector_begin();
	aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr,
				 walk.nbytes, req->iv, enc);
	kernel_vector_end();
	return skcipher_walk_done(&walk, 0);
}

static int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req)
{
	return riscv64_aes_cbc_cts_crypt(req, true);
}

static int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req)
{
	return riscv64_aes_cbc_cts_crypt(req, false);
}

/* AES-CTR */

static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int nbytes, p1_nbytes;
	struct skcipher_walk walk;
	u32 ctr32, nblocks;
	int err;

	/* Get the low 32-bit word of the 128-bit big endian counter. */
	ctr32 = get_unaligned_be32(req->iv + 12);

	err = skcipher_walk_virt(&walk, req, false);
	while ((nbytes = walk.nbytes) != 0) {
		if (nbytes < walk.total) {
			/* Not the end yet, so keep the length block-aligned. */
			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
			nblocks = nbytes / AES_BLOCK_SIZE;
		} else {
			/* It's the end, so include any final partial block. */
			nblocks = DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
		}
		ctr32 += nblocks;

		kernel_vector_begin();
		if (ctr32 >= nblocks) {
			/* The low 32-bit word of the counter won't overflow. */
			aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr,
						    walk.dst.virt.addr, nbytes,
						    req->iv);
		} else {
			/*
			 * The low 32-bit word of the counter will overflow.
			 * The assembly doesn't handle this case, so split the
			 * operation into two at the point where the overflow
			 * will occur.  After the first part, add the carry bit.
			 */
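			/*
			 * For example, if ctr32 was 0xfffffffe and nblocks is
			 * 5, the wrapped ctr32 is 3: the first call handles
			 * (5 - 3) * 16 = 32 bytes, crypto_inc() propagates the
			 * carry into the upper 96 bits of the IV, and the
			 * second call handles the remaining bytes.
			 */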
			p1_nbytes = min_t(unsigned int, nbytes,
					  (nblocks - ctr32) * AES_BLOCK_SIZE);
			aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr,
						    walk.dst.virt.addr,
						    p1_nbytes, req->iv);
			crypto_inc(req->iv, 12);

			if (ctr32) {
				aes_ctr32_crypt_zvkned_zvkb(
					ctx,
					walk.src.virt.addr + p1_nbytes,
					walk.dst.virt.addr + p1_nbytes,
					nbytes - p1_nbytes, req->iv);
			}
		}
		kernel_vector_end();

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}

/* AES-XTS */

struct riscv64_aes_xts_ctx {
	struct crypto_aes_ctx ctx1;
	struct crypto_aes_ctx ctx2;
};

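/*
 * The XTS key is the concatenation of two equal-length AES keys: the first
 * half (ctx1) encrypts the data blocks, and the second half (ctx2) encrypts
 * the IV to produce the initial tweak.
 */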
static int riscv64_aes_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
				  unsigned int keylen)
{
	struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	return xts_verify_key(tfm, key, keylen) ?:
	       riscv64_aes_setkey(&ctx->ctx1, key, keylen / 2) ?:
	       riscv64_aes_setkey(&ctx->ctx2, key + keylen / 2, keylen / 2);
}

static int riscv64_aes_xts_crypt(struct skcipher_request *req, bool enc)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* Encrypt the IV with the tweak key to get the first tweak. */
	kernel_vector_begin();
	aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
	kernel_vector_end();

	err = skcipher_walk_virt(&walk, req, false);

	/*
	 * If the message length isn't divisible by the AES block size and the
	 * full message isn't available in one step of the scatterlist walk,
	 * then separate off the last full block and the partial block.  This
	 * ensures that they are processed in the same call to the assembly
	 * function, which is required for ciphertext stealing.
	 */
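	/*
	 * For example, a 70-byte request seen across multiple walk steps is
	 * split so that the loop below handles the first 70 - 6 - 16 = 48
	 * bytes and the final ciphertext-stealing call handles the last
	 * 16 + 6 bytes.
	 */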
	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail - AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	while (walk.nbytes) {
		unsigned int nbytes = walk.nbytes;

		if (nbytes < walk.total)
			nbytes = round_down(nbytes, AES_BLOCK_SIZE);

		kernel_vector_begin();
		if (enc)
			aes_xts_encrypt_zvkned_zvbb_zvkg(
				&ctx->ctx1, walk.src.virt.addr,
				walk.dst.virt.addr, nbytes, req->iv);
		else
			aes_xts_decrypt_zvkned_zvbb_zvkg(
				&ctx->ctx1, walk.src.virt.addr,
				walk.dst.virt.addr, nbytes, req->iv);
		kernel_vector_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* Do ciphertext stealing with the last full block and partial block. */

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	kernel_vector_begin();
	if (enc)
		aes_xts_encrypt_zvkned_zvbb_zvkg(
			&ctx->ctx1, walk.src.virt.addr,
			walk.dst.virt.addr, walk.nbytes, req->iv);
	else
		aes_xts_decrypt_zvkned_zvbb_zvkg(
			&ctx->ctx1, walk.src.virt.addr,
			walk.dst.virt.addr, walk.nbytes, req->iv);
	kernel_vector_end();

	return skcipher_walk_done(&walk, 0);
}

static int riscv64_aes_xts_encrypt(struct skcipher_request *req)
{
	return riscv64_aes_xts_crypt(req, true);
}

static int riscv64_aes_xts_decrypt(struct skcipher_request *req)
{
	return riscv64_aes_xts_crypt(req, false);
}

/* Algorithm definitions */

static struct crypto_alg riscv64_zvkned_aes_cipher_alg = {
	.cra_flags = CRYPTO_ALG_TYPE_CIPHER,
	.cra_blocksize = AES_BLOCK_SIZE,
	.cra_ctxsize = sizeof(struct crypto_aes_ctx),
	.cra_priority = 300,
	.cra_name = "aes",
	.cra_driver_name = "aes-riscv64-zvkned",
	.cra_cipher = {
		.cia_min_keysize = AES_MIN_KEY_SIZE,
		.cia_max_keysize = AES_MAX_KEY_SIZE,
		.cia_setkey = riscv64_aes_setkey_cipher,
		.cia_encrypt = riscv64_aes_encrypt,
		.cia_decrypt = riscv64_aes_decrypt,
	},
	.cra_module = THIS_MODULE,
};

static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
	{
		.setkey = riscv64_aes_setkey_skcipher,
		.encrypt = riscv64_aes_ecb_encrypt,
		.decrypt = riscv64_aes_ecb_decrypt,
		.min_keysize = AES_MIN_KEY_SIZE,
		.max_keysize = AES_MAX_KEY_SIZE,
		.walksize = 8 * AES_BLOCK_SIZE, /* matches LMUL=8 */
		.base = {
			.cra_blocksize = AES_BLOCK_SIZE,
			.cra_ctxsize = sizeof(struct crypto_aes_ctx),
			.cra_priority = 300,
			.cra_name = "ecb(aes)",
			.cra_driver_name = "ecb-aes-riscv64-zvkned",
			.cra_module = THIS_MODULE,
		},
	}, {
		.setkey = riscv64_aes_setkey_skcipher,
		.encrypt = riscv64_aes_cbc_encrypt,
		.decrypt = riscv64_aes_cbc_decrypt,
		.min_keysize = AES_MIN_KEY_SIZE,
		.max_keysize = AES_MAX_KEY_SIZE,
		.ivsize = AES_BLOCK_SIZE,
		.base = {
			.cra_blocksize = AES_BLOCK_SIZE,
			.cra_ctxsize = sizeof(struct crypto_aes_ctx),
			.cra_priority = 300,
			.cra_name = "cbc(aes)",
			.cra_driver_name = "cbc-aes-riscv64-zvkned",
			.cra_module = THIS_MODULE,
		},
	}, {
		.setkey = riscv64_aes_setkey_skcipher,
		.encrypt = riscv64_aes_cbc_cts_encrypt,
		.decrypt = riscv64_aes_cbc_cts_decrypt,
		.min_keysize = AES_MIN_KEY_SIZE,
		.max_keysize = AES_MAX_KEY_SIZE,
		.ivsize = AES_BLOCK_SIZE,
		.walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
		.base = {
			.cra_blocksize = AES_BLOCK_SIZE,
			.cra_ctxsize = sizeof(struct crypto_aes_ctx),
			.cra_priority = 300,
			.cra_name = "cts(cbc(aes))",
			.cra_driver_name = "cts-cbc-aes-riscv64-zvkned",
			.cra_module = THIS_MODULE,
		},
	}
};

static struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = {
	.setkey = riscv64_aes_setkey_skcipher,
	.encrypt = riscv64_aes_ctr_crypt,
	.decrypt = riscv64_aes_ctr_crypt,
	.min_keysize = AES_MIN_KEY_SIZE,
	.max_keysize = AES_MAX_KEY_SIZE,
	.ivsize = AES_BLOCK_SIZE,
	.chunksize = AES_BLOCK_SIZE,
	.walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
	.base = {
		.cra_blocksize = 1,
		.cra_ctxsize = sizeof(struct crypto_aes_ctx),
		.cra_priority = 300,
		.cra_name = "ctr(aes)",
		.cra_driver_name = "ctr-aes-riscv64-zvkned-zvkb",
		.cra_module = THIS_MODULE,
	},
};

static struct skcipher_alg riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg = {
	.setkey = riscv64_aes_xts_setkey,
	.encrypt = riscv64_aes_xts_encrypt,
	.decrypt = riscv64_aes_xts_decrypt,
	.min_keysize = 2 * AES_MIN_KEY_SIZE,
	.max_keysize = 2 * AES_MAX_KEY_SIZE,
	.ivsize = AES_BLOCK_SIZE,
	.chunksize = AES_BLOCK_SIZE,
	.walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
	.base = {
		.cra_blocksize = AES_BLOCK_SIZE,
		.cra_ctxsize = sizeof(struct riscv64_aes_xts_ctx),
		.cra_priority = 300,
		.cra_name = "xts(aes)",
		.cra_driver_name = "xts-aes-riscv64-zvkned-zvbb-zvkg",
		.cra_module = THIS_MODULE,
	},
};

static inline bool riscv64_aes_xts_supported(void)
{
	return riscv_isa_extension_available(NULL, ZVBB) &&
	       riscv_isa_extension_available(NULL, ZVKG) &&
	       riscv_vector_vlen() < 2048 /* Implementation limitation */;
}

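/*
 * All algorithms require zvkned and VLEN >= 128.  The CTR algorithm
 * additionally requires zvkb, and the XTS algorithm additionally requires
 * zvbb and zvkg (plus the VLEN limitation checked above).
 */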
static int __init riscv64_aes_mod_init(void)
{
	int err = -ENODEV;

	if (riscv_isa_extension_available(NULL, ZVKNED) &&
	    riscv_vector_vlen() >= 128) {
		err = crypto_register_alg(&riscv64_zvkned_aes_cipher_alg);
		if (err)
			return err;

		err = crypto_register_skciphers(
			riscv64_zvkned_aes_skcipher_algs,
			ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
		if (err)
			goto unregister_zvkned_cipher_alg;

		if (riscv_isa_extension_available(NULL, ZVKB)) {
			err = crypto_register_skcipher(
				&riscv64_zvkned_zvkb_aes_skcipher_alg);
			if (err)
				goto unregister_zvkned_skcipher_algs;
		}

		if (riscv64_aes_xts_supported()) {
			err = crypto_register_skcipher(
				&riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg);
			if (err)
				goto unregister_zvkned_zvkb_skcipher_alg;
		}
	}

	return err;

unregister_zvkned_zvkb_skcipher_alg:
	if (riscv_isa_extension_available(NULL, ZVKB))
		crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
unregister_zvkned_skcipher_algs:
	crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
				    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
unregister_zvkned_cipher_alg:
	crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
	return err;
}

static void __exit riscv64_aes_mod_exit(void)
{
	if (riscv64_aes_xts_supported())
		crypto_unregister_skcipher(&riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg);
	if (riscv_isa_extension_available(NULL, ZVKB))
		crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
	crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
				    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
	crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
}

module_init(riscv64_aes_mod_init);
module_exit(riscv64_aes_mod_exit);

MODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)");
MODULE_AUTHOR("Jerry Shih <jerry.shih@sifive.com>");
MODULE_LICENSE("GPL");
MODULE_ALIAS_CRYPTO("aes");
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");